xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_
18 
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/str_join.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
27 
28 namespace mlir {
29 namespace TFL {
30 
31 // A config that controls which passes get run as part TFLite converter.
32 struct PassConfig {
PassConfigPassConfig33   explicit PassConfig(quant::QuantizationSpecs specs)
34       : emit_builtin_tflite_ops(true),
35         lower_tensor_list_ops(false),
36         trim_functions_allowlist({}),
37         quant_specs(std::move(specs)),
38         form_clusters(false),
39         unfold_batch_matmul(true),
40         shape_inference(true),
41         runtime_verification(true),
42         enable_tflite_variables(false),
43         disable_variable_freezing(false),
44         unfold_large_splat_constant(false),
45         guarantee_all_funcs_one_use(false),
46         enable_hlo_to_tf_conversion(false),
47         enable_dynamic_update_slice(false),
48         preserve_assert_op(false) {}
49 
50   // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
51   // added, which produces TF Lite ops.
52   bool emit_builtin_tflite_ops;
53   // If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
54   // TF ops before legalization to TF Lite dialect.
55   bool lower_tensor_list_ops;
56   // The allowlist of functions that would be preserved after trimming.
57   llvm::ArrayRef<std::string> trim_functions_allowlist;
58   // All information about quantization.
59   quant::QuantizationSpecs quant_specs;
60   // If `form_clusters` is true , clusters are formed by grouping consecutive
61   // ops of the same device, under a `tf_device.launch` op.
62   bool form_clusters;
63   // if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
64   // of tfl.fully_connected ops.
65   bool unfold_batch_matmul;
66   // Whether to outline WhileOp at the end of the pipeline.
67   bool outline_tf_while = false;
68   // Whether to do shape inference.
69   bool shape_inference;
70   // Whether to do TFLite runtime verification.
71   bool runtime_verification;
72   // Whether to enable TFLite variables or not, this will allow
73   // mutable variables and produce ReadVariable/AssignVariable ops in TFLite.
74   bool enable_tflite_variables;
75   // Whether to disable the variable freezing pass or not.
76   // By default we freeze all variables and disallow mutable variables. When
77   // 'enable_tflite_variables' is true then we allow mutable variable only.
78   bool disable_variable_freezing;
79   // Whether to unfold large splat constant tensors and replace them with
80   // fill operation.
81   bool unfold_large_splat_constant;
82   // Whether to run the `GuaranteeAllFuncsOneUsePass` to ensure each function
83   // has a single use.
84   bool guarantee_all_funcs_one_use;
85   // Whether to enable the hlo to tf conversion.
86   bool enable_hlo_to_tf_conversion;
87   // Whether to enable to use DynamicUpdateSlice op.
88   bool enable_dynamic_update_slice;
89   // Whether to preserve AssertOp during legalization.
90   bool preserve_assert_op;
91 };
92 
93 inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
94                                      const PassConfig& pass_config) {
95   return os << "emit_builtin_tflite_ops: "
96             << pass_config.emit_builtin_tflite_ops
97             << "\nlower_tensor_list_ops: " << pass_config.lower_tensor_list_ops
98             << "\ntrim_functions_allowlist: "
99             << absl::StrJoin(pass_config.trim_functions_allowlist.vec(), ",")
100             << "\nform_clusters: " << pass_config.form_clusters
101             << "\nunfold_batch_matmul: " << pass_config.unfold_batch_matmul
102             << "\noutline_tf_while: " << pass_config.outline_tf_while
103             << "\nshape_inference: " << pass_config.shape_inference
104             << "\nruntime_verification: " << pass_config.runtime_verification
105             << "\nenable_tflite_variables: "
106             << pass_config.enable_tflite_variables
107             << "\ndisable_variable_freezing: "
108             << pass_config.disable_variable_freezing
109             << "\nunfold_large_splat_constant: "
110             << pass_config.unfold_large_splat_constant
111             << "\nguarantee_all_funcs_one_use: "
112             << pass_config.guarantee_all_funcs_one_use
113             << "\nenable_hlo_to_tf_conversion: "
114             << pass_config.enable_hlo_to_tf_conversion << "\n";
115 }
116 
117 }  // namespace TFL
118 }  // namespace mlir
119 
120 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_
121