/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

#include <ostream>
#include <string>
#include <unordered_set>
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using stream_executor::port::StatusOr;

namespace tensorflow {
namespace internal {
namespace {

using ::mlir::quant::ReducedPrecisionSupport;

// Op def string for TFLite_Detection_PostProcess Op.
const char kDetectionPostProcessOp[] =
    "name: 'TFLite_Detection_PostProcess' input_arg: { name: "
    "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
    "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
    "'anchors' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
    "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
    "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
    "name: 'nms_iou_threshold' type: 'float'} attr : { name: "
    "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
    "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
    "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
    "'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
    "name: 'use_regular_nms' type: 'bool' default_value { b : false }}";

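// Op def string for the custom UnidirectionalSequenceLstm op, registered
// below so the importer can resolve graphs that contain it.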
const char kUnidirectionalSequenceLstmOp[] =
    "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
    "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
    "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
    "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
    "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
    "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
    "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
    "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
    "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
    "name: 'InputCellStateTensor' type: DT_FLOAT } "
    "output_arg: { name: 'Concat' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

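// Op def string for the custom UnidirectionalSequenceRnn op, registered
// below for the same reason.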
const char kUnidirectionalSequenceRnnOp[] =
    "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
    "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
    "name: 'Bias' type: DT_FLOAT} "
    "input_arg: { name: 'HiddenState' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
    "DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
  switch (dtype) {
    case toco::IODataType::FLOAT:
      return DT_FLOAT;
    case toco::IODataType::FLOAT16:
      return DT_HALF;
    case toco::IODataType::FLOAT64:
      return DT_DOUBLE;
    case toco::IODataType::QUANTIZED_UINT8:
      return DT_QUINT8;
    case toco::IODataType::QUANTIZED_INT8:
      return DT_QINT8;
    case toco::IODataType::QUANTIZED_INT16:
      return DT_QINT16;
    case toco::IODataType::INT8:
      return DT_INT8;
    case toco::IODataType::INT16:
      return DT_INT16;
    case toco::IODataType::UINT16:
      return DT_UINT16;
    case toco::IODataType::INT32:
      return DT_INT32;
    case toco::IODataType::UINT32:
      return DT_UINT32;
    case toco::IODataType::INT64:
      return DT_INT64;
    case toco::IODataType::UINT8:
      return DT_UINT8;
    case toco::IODataType::UINT64:
      return DT_UINT64;
    case toco::IODataType::STRING:
      return DT_STRING;
    case toco::IODataType::BOOL:
      return DT_BOOL;
    case toco::IODataType::COMPLEX64:
      return DT_COMPLEX64;
    case toco::IODataType::COMPLEX128:
      return DT_COMPLEX128;
    case toco::IODataType::RESOURCE:
      return DT_RESOURCE;
    case toco::IODataType::VARIANT:
      return DT_VARIANT;
    default:
      return DT_INVALID;
  }
}

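// Converts input statistics (mean, std) into a (min, max) range for the
// given quantized type, via min = (qmin - mean) / std and
// max = (qmax - mean) / std. For example, mean = 127.5 and std = 127.5 with
// DT_QUINT8 (qmin = 0, qmax = 255) yields the range (-1.0, 1.0).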
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
                                                       DataType type) {
  // Only qint8 and quint8 are considered here.
  double qmin, qmax;
  if (type == DT_QUINT8) {
    qmin = 0.0;
    qmax = 255.0;
  } else if (type == DT_QINT8) {
    qmin = -128.0;
    qmax = 127.0;
  } else {
    return errors::InvalidArgument(
        "Only quantized int8 and uint8 types are supported.");
  }
  return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}

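// Parses each text-format OpDef in `extra_tf_opdefs` and registers it with
// the global op registry, skipping any op that is already registered.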
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      return errors::InvalidArgument("Failed to parse the extra OpDef");
    }
    // Make sure the op is not already registered; if it is, skip it.
    const OpRegistrationData* op_reg =
        tensorflow::OpRegistry::Global()->LookUp(opdef.name());
    if (op_reg) continue;

    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return OkStatus();
        });
  }
  return OkStatus();
}

}  // namespace

Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
  // Register any custom OpDefs.
  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
                                      toco_flags.custom_opdefs().end());
  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
  return RegisterCustomBuiltinOps(extra_tf_opdefs);
}

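// Populates `quant_specs` and the per-input `node_names`, `node_dtypes`,
// `node_shapes`, and `node_mins`/`node_maxs` vectors from the given model
// and converter flags.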
Status PopulateQuantizationSpecs(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::quant::QuantizationSpecs* quant_specs,
    std::vector<string>* node_names, std::vector<string>* node_dtypes,
    std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
    std::vector<llvm::Optional<double>>* node_mins,
    std::vector<llvm::Optional<double>>* node_maxs) {
  quant_specs->inference_input_type =
      ConvertIODataTypeToDataType(toco_flags.inference_input_type());
  tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  // If `inference_input_type` is a non-float type, it overrides
  // `inference_type`, because quantization has to be applied to satisfy the
  // requested input type.
  if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs->inference_input_type;
  }

  for (auto& flag : model_flags.input_arrays()) {
    node_names->push_back(flag.name());
    // TOCO doesn't require `data_type` to be filled for every input.
    // If it's not filled, make it an empty string so the importer will use
    // the data type in the NodeDef.
    auto toco_data_type = flag.data_type();
    if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
      node_dtypes->push_back("");
    } else {
      node_dtypes->push_back(
          DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
    }
    if (flag.shape().unknown_rank()) {
      node_shapes->push_back(llvm::None);
    } else {
      node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                              flag.shape().dims().end()));
    }
    // Currently, only the QUINT8 and QINT8 inference types require input
    // stats.
    if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
      if (flag.has_mean_value() && flag.has_std_value()) {
        TF_ASSIGN_OR_RETURN(
            auto min_max, InputStatsToMinMax(flag.mean_value(),
                                             flag.std_value(), inference_type));
        node_mins->push_back(min_max.first);
        node_maxs->push_back(min_max.second);
      } else {
        node_mins->push_back(llvm::None);
        node_maxs->push_back(llvm::None);
      }
    }
  }

  if (mlir::quant::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
                                          inference_type, quant_specs)) {
    return errors::InvalidArgument("Failed to get input quant spec.");
  }

  // Extra flags related to post-training quantization. If post-training
  // quantization is enabled, `inference_type` and `inference_input_type` are
  // not used by the MLIR passes.
  if (toco_flags.post_training_quantize()) {
    quant_specs->weight_quantization = true;
    quant_specs->disable_per_channel =
        toco_flags.disable_per_channel_quantization();
    if (toco_flags.quantize_to_float16()) {
      quant_specs->inference_type = tensorflow::DT_HALF;
      quant_specs->inference_input_type = tensorflow::DT_HALF;
    } else {
      quant_specs->inference_type = tensorflow::DT_QINT8;
      quant_specs->inference_input_type = tensorflow::DT_QINT8;
    }
  } else {
    // These flags are incompatible with post_training_quantize() because only
    // QAT models can provide the required ranges.
    quant_specs->disable_infer_tensor_range =
        toco_flags.disable_infer_tensor_range();
    quant_specs->use_fake_quant_num_bits = toco_flags.use_fake_quant_num_bits();
  }

  // Add information about half-precision support if fp16 quantization applies.
  // TODO(b/195945955): Add e2e test for this.
  if (toco_flags.quantize_to_float16() || toco_flags.allow_bfloat16()) {
    ReducedPrecisionSupport mask = ReducedPrecisionSupport::None;
    if (toco_flags.quantize_to_float16()) {
      mask |= ReducedPrecisionSupport::Float16Inference;
    }
    if (toco_flags.allow_bfloat16()) {
      mask |= ReducedPrecisionSupport::Bfloat16Inference;
    }
    if (toco_flags.accumulation_type() == toco::IODataType::FLOAT16) {
      mask |= ReducedPrecisionSupport::Float16Accumulation;
    } else {
      mask |= ReducedPrecisionSupport::Float32Accumulation;
    }
    quant_specs->support_mask = mask;
  }

  // Other flags.
  if (toco_flags.has_default_ranges_min()) {
    quant_specs->default_ranges.first = toco_flags.default_ranges_min();
  }
  if (toco_flags.has_default_ranges_max()) {
    quant_specs->default_ranges.second = toco_flags.default_ranges_max();
  }
  if (toco_flags.enable_mlir_dynamic_range_quantizer()) {
    quant_specs->enable_mlir_dynamic_range_quantizer = true;
  }
  return OkStatus();
}

// Dumps the op graph of the `module` to `filename` in DOT format.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
  std::string error_message;
  auto output = mlir::openOutputFile(filename, &error_message);
  if (!error_message.empty()) {
    return errors::InvalidArgument("Failed to open file ", filename, ": ",
                                   error_message);
  }
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createPrintOpGraphPass(output->os()));
  if (failed(pm.run(module))) {
    return errors::Unknown("Failed to dump Op Graph from MLIR module.");
  }
  output->keep();
  return OkStatus();
}

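// Runs the TF-to-TFLite conversion passes on `module` and serializes the
// result into `result` as a TFLite FlatBuffer, optionally dumping the op
// graph in DOT format before and after the transformations.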
Status ConvertMLIRToTFLiteFlatBuffer(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::OwningOpRef<mlir::ModuleOp> module,
    const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags, string* result,
    llvm::Optional<tensorflow::Session*> session) {
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        module.get(),
        // rename once we enable the new converter feature flag.
        absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
  }

  mlir::TFL::PassConfig pass_config_copy = pass_config;
  pass_config_copy.outline_tf_while = true;
  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy,
      saved_model_tags, model_flags.saved_model_dir(), session, result);
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        // rename once we enable the new converter feature flag.
        module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
                                   "/toco_AFTER_TRANSFORMATIONS.dot")));
  }

  return status;
}

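// Logs a warning for each conversion flag that the MLIR-based converter
// accepts but ignores.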
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
                        const toco::TocoFlags& toco_flags) {
  if (toco_flags.output_format()) {
    LOG(WARNING) << "Ignored output_format.";
  }
  if (toco_flags.drop_control_dependency()) {
    LOG(WARNING) << "Ignored drop_control_dependency.";
  }
  if (toco_flags.reorder_across_fake_quant()) {
    LOG(WARNING) << "Ignored reorder_across_fake_quant.";
  }
  if (model_flags.change_concat_input_ranges()) {
    LOG(WARNING) << "Ignored change_concat_input_ranges.";
  }
  if (toco_flags.dump_graphviz_include_video()) {
    LOG(WARNING) << "Ignored dump_graphviz_include_video.";
  }
  if (model_flags.allow_nonexistent_arrays()) {
    LOG(WARNING) << "Allowing allow_nonexistent_arrays.";
  }
}

}  // namespace internal
}  // namespace tensorflow