1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ 17 #define TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ 18 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "absl/strings/str_join.h" 24 #include "llvm/ADT/ArrayRef.h" 25 #include "llvm/Support/raw_ostream.h" 26 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" 27 28 namespace mlir { 29 namespace TFL { 30 31 // A config that controls which passes get run as part TFLite converter. 32 struct PassConfig { PassConfigPassConfig33 explicit PassConfig(quant::QuantizationSpecs specs) 34 : emit_builtin_tflite_ops(true), 35 lower_tensor_list_ops(false), 36 trim_functions_allowlist({}), 37 quant_specs(std::move(specs)), 38 form_clusters(false), 39 unfold_batch_matmul(true), 40 shape_inference(true), 41 runtime_verification(true), 42 enable_tflite_variables(false), 43 disable_variable_freezing(false), 44 unfold_large_splat_constant(false), 45 guarantee_all_funcs_one_use(false), 46 enable_hlo_to_tf_conversion(false), 47 enable_dynamic_update_slice(false), 48 preserve_assert_op(false) {} 49 50 // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be 51 // added, which produces TF Lite ops. 52 bool emit_builtin_tflite_ops; 53 // If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic 54 // TF ops before legalization to TF Lite dialect. 55 bool lower_tensor_list_ops; 56 // The allowlist of functions that would be preserved after trimming. 57 llvm::ArrayRef<std::string> trim_functions_allowlist; 58 // All information about quantization. 59 quant::QuantizationSpecs quant_specs; 60 // If `form_clusters` is true , clusters are formed by grouping consecutive 61 // ops of the same device, under a `tf_device.launch` op. 62 bool form_clusters; 63 // if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set 64 // of tfl.fully_connected ops. 65 bool unfold_batch_matmul; 66 // Whether to outline WhileOp at the end of the pipeline. 67 bool outline_tf_while = false; 68 // Whether to do shape inference. 69 bool shape_inference; 70 // Whether to do TFLite runtime verification. 71 bool runtime_verification; 72 // Whether to enable TFLite variables or not, this will allow 73 // mutable variables and produce ReadVariable/AssignVariable ops in TFLite. 74 bool enable_tflite_variables; 75 // Whether to disable the variable freezing pass or not. 76 // By default we freeze all variables and disallow mutable variables. When 77 // 'enable_tflite_variables' is true then we allow mutable variable only. 78 bool disable_variable_freezing; 79 // Whether to unfold large splat constant tensors and replace them with 80 // fill operation. 81 bool unfold_large_splat_constant; 82 // Whether to run the `GuaranteeAllFuncsOneUsePass` to ensure each function 83 // has a single use. 84 bool guarantee_all_funcs_one_use; 85 // Whether to enable the hlo to tf conversion. 86 bool enable_hlo_to_tf_conversion; 87 // Whether to enable to use DynamicUpdateSlice op. 88 bool enable_dynamic_update_slice; 89 // Whether to preserve AssertOp during legalization. 90 bool preserve_assert_op; 91 }; 92 93 inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, 94 const PassConfig& pass_config) { 95 return os << "emit_builtin_tflite_ops: " 96 << pass_config.emit_builtin_tflite_ops 97 << "\nlower_tensor_list_ops: " << pass_config.lower_tensor_list_ops 98 << "\ntrim_functions_allowlist: " 99 << absl::StrJoin(pass_config.trim_functions_allowlist.vec(), ",") 100 << "\nform_clusters: " << pass_config.form_clusters 101 << "\nunfold_batch_matmul: " << pass_config.unfold_batch_matmul 102 << "\noutline_tf_while: " << pass_config.outline_tf_while 103 << "\nshape_inference: " << pass_config.shape_inference 104 << "\nruntime_verification: " << pass_config.runtime_verification 105 << "\nenable_tflite_variables: " 106 << pass_config.enable_tflite_variables 107 << "\ndisable_variable_freezing: " 108 << pass_config.disable_variable_freezing 109 << "\nunfold_large_splat_constant: " 110 << pass_config.unfold_large_splat_constant 111 << "\nguarantee_all_funcs_one_use: " 112 << pass_config.guarantee_all_funcs_one_use 113 << "\nenable_hlo_to_tf_conversion: " 114 << pass_config.enable_hlo_to_tf_conversion << "\n"; 115 } 116 117 } // namespace TFL 118 } // namespace mlir 119 120 #endif // TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ 121