xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/ToolOutputFile.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
25 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Location.h"  // from @llvm-project
31 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
32 #include "mlir/IR/Operation.h"  // from @llvm-project
33 #include "mlir/IR/Types.h"  // from @llvm-project
34 #include "mlir/IR/Value.h"  // from @llvm-project
35 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
36 #include "mlir/Tools/mlir-translate/Translation.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
38 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
41 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
44 
45 using llvm::cl::opt;
46 
47 // Commandline flag to enable the control of flatbuffer import.
48 bool use_external_constant;
49 
50 // Commandline flag to enable graph pruning.
51 bool experimental_prune_unreachable_nodes_unconditionally;
52 
53 // NOLINTNEXTLINE
54 static opt<bool, true> use_external_constant_flag(
55     "use-external-constant",
56     llvm::cl::desc("Use external constant during flatbuffer import"),
57     llvm::cl::location(use_external_constant), llvm::cl::init(false));
58 
59 // TODO(b/147111261): After the importer supports generic custom ops, we should
60 // change the flag to a more lightwise flag, e.g.
61 // "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
62 // the operations.
63 // NOLINTNEXTLINE
64 static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
65     "experimental-prune-unreachable-nodes-unconditionally",
66     llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
67     llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
68     llvm::cl::init(false));
69 
70 // NOLINTNEXTLINE
71 static opt<std::string> input_arrays_flag(
72     "input-arrays",
73     llvm::cl::desc(
74         "List of input tensors, if different from the default inputs"),
75     llvm::cl::init(""));
76 
77 // NOLINTNEXTLINE
78 static opt<std::string> output_arrays_flag(
79     "output-arrays",
80     llvm::cl::desc(
81         "List of output tensors, if different from the default outputs"),
82     llvm::cl::init(""));
83 using llvm::cl::opt;
84 
85 // These command line flags enable control of the translation implementation.
86 bool emit_builtin_tflite_ops;
87 bool emit_custom_ops;
88 bool emit_select_tf_ops;
89 bool lower_tensor_list_ops;
90 bool strip_debug_info;
91 
92 // NOLINTNEXTLINE
93 static opt<bool, true> emit_builtin_tflite_ops_flag(
94     "emit-builtin-tflite-ops",
95     llvm::cl::desc(
96         "Emit TFLite built in operations in the generated TFLite model"),
97     llvm::cl::location(emit_builtin_tflite_ops), llvm::cl::init(true));
98 
99 // NOLINTNEXTLINE
100 static opt<bool, true> emit_select_tf_ops_flag(
101     "emit-select-tf-ops",
102     llvm::cl::desc(
103         "Emit Select TF operations (Flex ops) in the generated TFLite model"),
104     llvm::cl::location(emit_select_tf_ops), llvm::cl::init(false));
105 
106 // NOLINTNEXTLINE
107 static opt<bool, true> emit_custom_ops_flag(
108     "emit-custom-ops",
109     llvm::cl::desc("Emit Custom operations in the generated TFLite model"),
110     llvm::cl::location(emit_custom_ops), llvm::cl::init(false));
111 
112 // NOLINTNEXTLINE
113 static opt<bool, true> lower_tensor_list_ops_flag(
114     "lower-tensor-list-ops",
115     llvm::cl::desc("Lower the TensorList ops within the TFLite dialect"),
116     llvm::cl::location(lower_tensor_list_ops), llvm::cl::init(false));
117 
118 // NOLINTNEXTLINE
119 static opt<bool, true> strip_debug_info_flag(
120     "strip-debug-info", llvm::cl::desc("Strip debug info during export"),
121     llvm::cl::location(strip_debug_info), llvm::cl::init(false));
122 
123 namespace mlir {
124 namespace {
FlatBufferFileToMlirTrans(llvm::SourceMgr * source_mgr,MLIRContext * context,bool use_external_constant,bool experimental_prune_unreachable_nodes_unconditionally)125 static OwningOpRef<mlir::ModuleOp> FlatBufferFileToMlirTrans(
126     llvm::SourceMgr* source_mgr, MLIRContext* context,
127     bool use_external_constant,
128     bool experimental_prune_unreachable_nodes_unconditionally) {
129   const llvm::MemoryBuffer* input =
130       source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
131   std::string error;
132   auto loc =
133       mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0);
134 
135   // Parses input/output names from command line options.
136   std::vector<std::string> inputs;
137   std::vector<std::string> outputs;
138   // Use output parser since we only have tensor names.
139   if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
140     return emitError(loc, "parsing input array info failed ")
141                << input_arrays_flag,
142            nullptr;
143   }
144   if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
145     return emitError(loc, "parsing output array info failed ")
146                << output_arrays_flag,
147            nullptr;
148   }
149   return tflite::FlatBufferToMlir(
150       absl::string_view(input->getBufferStart(), input->getBufferSize()),
151       context, loc, use_external_constant, inputs, outputs,
152       experimental_prune_unreachable_nodes_unconditionally);
153 }
154 
MlirToFlatBufferFileTranslateFunction(ModuleOp module,llvm::raw_ostream & output)155 static LogicalResult MlirToFlatBufferFileTranslateFunction(
156     ModuleOp module, llvm::raw_ostream& output) {
157   std::string serialized_flatbuffer;
158   std::unique_ptr<tensorflow::OpOrArgNameMapper> op_or_arg_name_mapper;
159   if (strip_debug_info) {
160     op_or_arg_name_mapper =
161         std::make_unique<tensorflow::OpOrArgStripNameMapper>();
162   } else {
163     op_or_arg_name_mapper =
164         std::make_unique<tensorflow::OpOrArgLocNameMapper>();
165   }
166   tflite::FlatbufferExportOptions options;
167   options.toco_flags.set_force_select_tf_ops(!emit_builtin_tflite_ops);
168   options.toco_flags.set_enable_select_tf_ops(emit_select_tf_ops);
169   options.toco_flags.set_allow_custom_ops(emit_custom_ops);
170   options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
171   if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
172                                                  &serialized_flatbuffer))
173     return mlir::failure();
174 
175   output << serialized_flatbuffer;
176   return success();
177 }
178 }  // namespace
179 
180 static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
181     "tflite-flatbuffer-to-mlir",
__anona6bed5d70202(llvm::SourceMgr& source_mgr, MLIRContext* context) 182     [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
183       return FlatBufferFileToMlirTrans(
184           &source_mgr, context, use_external_constant,
185           experimental_prune_unreachable_nodes_unconditionally);
186     });
187 
188 static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate(
189     "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction,
__anona6bed5d70302(DialectRegistry& registry) 190     [](DialectRegistry& registry) {
191       registry.insert<quant::QuantizationDialect,
192                       quantfork::QuantizationForkDialect>();
193       mlir::RegisterAllTensorFlowDialects(registry);
194       registry.insert<TFL::TensorFlowLiteDialect>();
195       registry.insert<arith::ArithmeticDialect>();
196       registry.insert<func::FuncDialect>();
197     });
198 }  // namespace mlir
199