/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

#include <ostream>
#include <string>
#include <unordered_set>
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using stream_executor::port::StatusOr;

namespace tensorflow {
namespace internal {
namespace {

using ::mlir::quant::ReducedPrecisionSupport;

// Op def string for the TFLite_Detection_PostProcess op.
const char kDetectionPostProcessOp[] =
    "name: 'TFLite_Detection_PostProcess' input_arg: { name: "
    "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
    "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
    "'anchors' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
    "'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
    "'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
    "type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
    "name: 'nms_iou_threshold' type: 'float'} attr : { name: "
    "'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
    "'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
    "type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
    "'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
    "name: 'use_regular_nms' type: 'bool' default_value { b : false }}";

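// Op def string for the UnidirectionalSequenceLstm op.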
const char kUnidirectionalSequenceLstmOp[] =
    "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
    "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
    "name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
    "'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
    "'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
    "'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
    "'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
    "'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
    "type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
    "input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
    "'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
    "type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
    "input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
    "name: 'InputCellStateTensor' type: DT_FLOAT } "
    "output_arg: { name: 'Concat' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

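// Op def string for the UnidirectionalSequenceRnn op.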
const char kUnidirectionalSequenceRnnOp[] =
    "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
    "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
    "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
    "name: 'Bias' type: DT_FLOAT} "
    "input_arg: { name: 'HiddenState' type: DT_FLOAT} "
    "output_arg: { name: "
    "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
    "DT_FLOAT} "
    "attr : { name: '_tflite_input_indices' type: 'list(int)'}";

// Converts a toco::IODataType to a tensorflow::DataType. Only contains the
// conversion mapping for constants defined in the TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
  switch (dtype) {
    case toco::IODataType::FLOAT:
      return DT_FLOAT;
    case toco::IODataType::FLOAT16:
      return DT_HALF;
    case toco::IODataType::FLOAT64:
      return DT_DOUBLE;
    case toco::IODataType::QUANTIZED_UINT8:
      return DT_QUINT8;
    case toco::IODataType::QUANTIZED_INT8:
      return DT_QINT8;
    case toco::IODataType::QUANTIZED_INT16:
      return DT_QINT16;
    case toco::IODataType::INT8:
      return DT_INT8;
    case toco::IODataType::INT16:
      return DT_INT16;
    case toco::IODataType::UINT16:
      return DT_UINT16;
    case toco::IODataType::INT32:
      return DT_INT32;
    case toco::IODataType::UINT32:
      return DT_UINT32;
    case toco::IODataType::INT64:
      return DT_INT64;
    case toco::IODataType::UINT8:
      return DT_UINT8;
    case toco::IODataType::UINT64:
      return DT_UINT64;
    case toco::IODataType::STRING:
      return DT_STRING;
    case toco::IODataType::BOOL:
      return DT_BOOL;
    case toco::IODataType::COMPLEX64:
      return DT_COMPLEX64;
    case toco::IODataType::COMPLEX128:
      return DT_COMPLEX128;
    case toco::IODataType::RESOURCE:
      return DT_RESOURCE;
    case toco::IODataType::VARIANT:
      return DT_VARIANT;
    default:
      return DT_INVALID;
  }
}

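// Converts the quantization stats (mean, std) of an input array to the
// corresponding floating-point (min, max) range. For a quantized type with
// range [qmin, qmax], the result is ((qmin - mean) / std, (qmax - mean) / std).
// For example, DT_QUINT8 stats of mean=127.5 and std=127.5 map to the range
// ((0 - 127.5) / 127.5, (255 - 127.5) / 127.5) = (-1.0, 1.0).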
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
                                                       DataType type) {
  // Only qint8 and quint8 are considered here.
  double qmin, qmax;
  if (type == DT_QUINT8) {
    qmin = 0.0;
    qmax = 255.0;
  } else if (type == DT_QINT8) {
    qmin = -128.0;
    qmax = 127.0;
  } else {
    return errors::InvalidArgument("Only int8 and uint8 are considered.");
  }
  return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}

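// Parses each entry of `extra_tf_opdefs` as a text-format OpDef proto and
// registers it with the global TF op registry. Ops that are already
// registered are skipped.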
Status RegisterCustomBuiltinOps(const std::vector<string>& extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      return errors::InvalidArgument("failed to parse extra OpDef");
    }
    // Make sure the op is not already registered. If it is, skip it.
    const OpRegistrationData* op_reg =
        tensorflow::OpRegistry::Global()->LookUp(opdef.name());
    if (op_reg) continue;

    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return OkStatus();
        });
  }
  return OkStatus();
}

}  // namespace

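// Registers the custom OpDefs from `toco_flags` together with the built-in
// TFLite custom ops (TFLite_Detection_PostProcess and the unidirectional
// sequence LSTM/RNN ops) in the global TF op registry.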
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
  // Register any custom OpDefs.
  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
                                      toco_flags.custom_opdefs().end());
  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
  return RegisterCustomBuiltinOps(extra_tf_opdefs);
}

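// Extracts the input-array names, dtypes, shapes, and quantization stats from
// `model_flags` and `toco_flags`, and fills in `quant_specs` accordingly.
// Returns an error if the stats cannot be converted to a valid quant spec.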
Status PopulateQuantizationSpecs(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::quant::QuantizationSpecs* quant_specs,
    std::vector<string>* node_names, std::vector<string>* node_dtypes,
    std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
    std::vector<llvm::Optional<double>>* node_mins,
    std::vector<llvm::Optional<double>>* node_maxs) {
  quant_specs->inference_input_type =
      ConvertIODataTypeToDataType(toco_flags.inference_input_type());
  tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  // If a non-float `inference_input_type` is set, use it to override
  // `inference_type`, since quantization has to be applied to satisfy it.
  if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs->inference_input_type;
  }

  for (auto& flag : model_flags.input_arrays()) {
    node_names->push_back(flag.name());
    // TOCO doesn't require `data_type` to be filled for every input.
    // If it's not filled, make it an empty string so the importer will use
    // the data type in the NodeDef.
    auto toco_data_type = flag.data_type();
    if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
      node_dtypes->push_back("");
    } else {
      node_dtypes->push_back(
          DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
    }
    if (flag.shape().unknown_rank()) {
      node_shapes->push_back(llvm::None);
    } else {
      node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                              flag.shape().dims().end()));
    }
    // Currently, only the QUINT8 and QINT8 inference types require input
    // stats.
    if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
      if (flag.has_mean_value() && flag.has_std_value()) {
        TF_ASSIGN_OR_RETURN(
            auto min_max, InputStatsToMinMax(flag.mean_value(),
                                             flag.std_value(), inference_type));
        node_mins->push_back(min_max.first);
        node_maxs->push_back(min_max.second);
      } else {
        node_mins->push_back(llvm::None);
        node_maxs->push_back(llvm::None);
      }
    }
  }

  if (mlir::quant::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
                                          inference_type, quant_specs)) {
    return errors::InvalidArgument("Failed to get input quant spec.");
  }

  // Extra flags related to post-training quantization. If post-training
  // quantization is enabled, `inference_type` and `inference_input_type` are
  // not used by the MLIR passes.
  if (toco_flags.post_training_quantize()) {
    quant_specs->weight_quantization = true;
    quant_specs->disable_per_channel =
        toco_flags.disable_per_channel_quantization();
    if (toco_flags.quantize_to_float16()) {
      quant_specs->inference_type = tensorflow::DT_HALF;
      quant_specs->inference_input_type = tensorflow::DT_HALF;
    } else {
      quant_specs->inference_type = tensorflow::DT_QINT8;
      quant_specs->inference_input_type = tensorflow::DT_QINT8;
    }
  } else {
    // These flags are incompatible with post_training_quantize() since only
    // QAT (quantization-aware training) models can provide the required
    // ranges.
    quant_specs->disable_infer_tensor_range =
        toco_flags.disable_infer_tensor_range();
    quant_specs->use_fake_quant_num_bits = toco_flags.use_fake_quant_num_bits();
  }

  // Add information about half-precision support if fp16 quantization applies.
  // TODO(b/195945955): Add e2e test for this.
  if (toco_flags.quantize_to_float16() || toco_flags.allow_bfloat16()) {
    ReducedPrecisionSupport mask = ReducedPrecisionSupport::None;
    if (toco_flags.quantize_to_float16()) {
      mask |= ReducedPrecisionSupport::Float16Inference;
    }
    if (toco_flags.allow_bfloat16()) {
      mask |= ReducedPrecisionSupport::Bfloat16Inference;
    }
    if (toco_flags.accumulation_type() == toco::IODataType::FLOAT16) {
      mask |= ReducedPrecisionSupport::Float16Accumulation;
    } else {
      mask |= ReducedPrecisionSupport::Float32Accumulation;
    }
    quant_specs->support_mask = mask;
  }

  // Other flags.
  if (toco_flags.has_default_ranges_min()) {
    quant_specs->default_ranges.first = toco_flags.default_ranges_min();
  }
  if (toco_flags.has_default_ranges_max()) {
    quant_specs->default_ranges.second = toco_flags.default_ranges_max();
  }
  if (toco_flags.enable_mlir_dynamic_range_quantizer()) {
    quant_specs->enable_mlir_dynamic_range_quantizer = true;
  }
  return OkStatus();
}

// Dumps the op graph of the `module` to `filename` in DOT format.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
  std::string error_message;
  auto output = mlir::openOutputFile(filename, &error_message);
  if (!error_message.empty()) {
    return errors::InvalidArgument("Failed to open file ", filename);
  }
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createPrintOpGraphPass(output->os()));
  if (failed(pm.run(module))) {
    return errors::Unknown("Failed to dump Op Graph from MLIR module.");
  }
  output->keep();
  return OkStatus();
}

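// Lowers `module` to the TFLite dialect and serializes it into `result` as a
// TFLite flatbuffer. If `dump_graphviz_dir` is set in `toco_flags`, the op
// graph is also dumped in DOT format before and after the transformations.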
Status ConvertMLIRToTFLiteFlatBuffer(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    mlir::OwningOpRef<mlir::ModuleOp> module,
    const mlir::TFL::PassConfig& pass_config,
    const std::unordered_set<std::string>& saved_model_tags, string* result,
    llvm::Optional<tensorflow::Session*> session) {
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        module.get(),
        // Rename once we enable the new converter feature flag.
        absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
  }

  mlir::TFL::PassConfig pass_config_copy = pass_config;
  pass_config_copy.outline_tf_while = true;
  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy,
      saved_model_tags, model_flags.saved_model_dir(), session, result);
  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(
        // Rename once we enable the new converter feature flag.
        module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
                                   "/toco_AFTER_TRANSFORMATIONS.dot")));
  }

  return status;
}

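// Logs a warning for every legacy TOCO flag that is set but not honored by
// the MLIR-based converter.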
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
                        const toco::TocoFlags& toco_flags) {
  if (toco_flags.output_format()) {
    LOG(WARNING) << "Ignored output_format.";
  }
  if (toco_flags.drop_control_dependency()) {
    LOG(WARNING) << "Ignored drop_control_dependency.";
  }
  if (toco_flags.reorder_across_fake_quant()) {
    LOG(WARNING) << "Ignored reorder_across_fake_quant.";
  }
  if (model_flags.change_concat_input_ranges()) {
    LOG(WARNING) << "Ignored change_concat_input_ranges.";
  }
  if (toco_flags.dump_graphviz_include_video()) {
    LOG(WARNING) << "Ignored dump_graphviz_include_video.";
  }
  if (model_flags.allow_nonexistent_arrays()) {
    LOG(WARNING) << "Allowing allow_nonexistent_arrays.";
  }
}

}  // namespace internal
}  // namespace tensorflow