/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to work with tf.FakeQuant* ops.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_

#include <string>
#include <vector>

#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"

namespace mlir {
namespace TFL {

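// Fetches the quantization range from the `min`/`max` float attributes of a
// tf.FakeQuant* op that carries the range as attributes rather than operands
// (e.g. tf.FakeQuantWithMinMaxArgs).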
template <class TFFakeQuantOp>
struct FetchMinMaxAttrs {
  using AttrType = FloatAttr;
  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
                  AttrType &max_value) const {
    min_value = tf_op.minAttr();
    max_value = tf_op.maxAttr();
    return true;  // Successfully matched and fetched.
  }
};

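// Fetches the quantization range from the `min`/`max` operands of a
// tf.FakeQuant* op (e.g. tf.FakeQuantWithMinMaxVars{PerChannel}); the match
// fails unless both operands are (possibly tf.Identity-wrapped) constants.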
template <class TFFakeQuantOp>
struct FetchConstantMinMaxInputs {
  using AttrType = DenseFPElementsAttr;
  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
                  AttrType &max_value) const {
    Value min = tf_op.min(), max = tf_op.max();
    // The min/max operands may be wrapped in tf.Identity ops; look through
    // them to reach the underlying constants.
    if (auto id = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
      min = id.input();
    }
    if (auto id = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
      max = id.input();
    }
    if (!matchPattern(min, m_Constant(&min_value))) {
      return false;
    }
    if (!matchPattern(max, m_Constant(&max_value))) {
      return false;
    }
    return true;  // Successfully matched and fetched.
  }
};

// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args} op, before that op is
// constant folded. Since the constant folding logic will use an
// "arith.constant" op to replace the "tf.FakeQuantWithMinMaxVarsOp", the
// "tfl.quantize" op is used to preserve the quantization parameters as a
// TypeAttr, and the "tfl.dequantize" op converts the output back to the
// float type expected by the next op. Here are the transformations:
//
// input   min cst       max cst          input   min cst       max cst
//  \       |             |                \       |             |
//   \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
//    \     |             |                  \     |             |
//       tf.FakeQuantWithMinMaxVars       tf.FakeQuantWithMinMaxVars
//                   |                                 |
//                                                tfl.quantize
//                                                     |
//                                                tfl.dequantize
//                                                     |
// If the input is a constant, the result pattern will eventually be
// converted to
//
//            quant-emulated input
//                   |
//               tfl.quantize
//                   |
//              tfl.dequantize
//                   |
//
// Warns if the (most likely unwanted, and currently not quite correctly
// handled) case of back-to-back tf.FakeQuant ops occurs:
//
//             tf.FakeQuant*
//                   |
//             tf.FakeQuant*
//
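// For illustration, a sketch of the rewrite in MLIR assembly (the quantized
// element type is elided as <...>; the exact parameters depend on min/max,
// num_bits, and narrow_range):
//
//   %fq = "tf.FakeQuantWithMinMaxVars"(%in, %min, %max)
//       : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
//
// becomes
//
//   %fq = "tf.FakeQuantWithMinMaxVars"(%in, %min, %max)
//       : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
//   %q  = "tfl.quantize"(%fq) {qtype = tensor<8x!quant.uniform<...>>}
//       : (tensor<8xf32>) -> tensor<8x!quant.uniform<...>>
//   %dq = "tfl.dequantize"(%q)
//       : (tensor<8x!quant.uniform<...>>) -> tensor<8xf32>
//
// with all former users of %fq rewired to %dq.
//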
template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax>
class InsertTFLQuantOpsAfterTFFakeQuantOp {
 public:
  explicit InsertTFLQuantOpsAfterTFFakeQuantOp(bool use_fake_quant_num_bits)
      : use_fake_quant_num_bits_(use_fake_quant_num_bits) {}

  FetchMinMax fetch_min_max_;

  using FetchAttrType = typename FetchMinMax::AttrType;
  LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
                                OpBuilder &rewriter) const {
    // Skip the rewrite if the output has more than one use or is already
    // consumed by a quantize op.
    auto res = tf_op.outputs();
    if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) {
      return failure();
    }

    // Extract the min/max constant values from the operands. This also
    // handles the special case where tf.Identity ops sit between the min/max
    // constants and the tf.FakeQuantWithMinMaxVarsOp.
    FetchAttrType min_value, max_value;
    if (!fetch_min_max_(tf_op, min_value, max_value)) {
      return failure();
    }

    int quant_dim = -1;
    if (PerAxis) {
      // For per-axis quantization, the quantization dimension is the last
      // dimension of the result.
      quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
    }
    // Use the min/max from the operands and the num_bits and narrow_range
    // attributes to create the quantization parameter for the new quantize
    // op.
    rewriter.setInsertionPointAfter(tf_op.getOperation());
    IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits());
    BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
    Type res_type = tf_op.getType();
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/false, /*legacy_float_scale=*/false,
        use_fake_quant_num_bits_);
    if (!qtype) {
      return failure();
    }

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
    // and its users.
    Value value = tf_op.outputs();
    auto quantize = rewriter.create<TFL::QuantizeOp>(
        tf_op.getLoc(), qtype.getValue(), value, qtype);
    auto dequantize = rewriter.create<TFL::DequantizeOp>(
        tf_op.getLoc(), res_type, quantize.output());
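    // Note: replaceAllUsesWith below also rewires the new quantize op to
    // consume the dequantize result; the replaceUsesOfWith call then restores
    // the quantize op's operand to the original FakeQuant output, yielding
    // fake-quant -> quantize -> dequantize -> users.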
    value.replaceAllUsesWith(dequantize);
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);

    return success();
  }

  bool use_fake_quant_num_bits_;
};

// Removes the wrapper of the tf.FakeQuant* ops and creates the tfl.quantize
// and tfl.dequantize pairs before the tf.FakeQuant* ops are folded.
LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext *ctx,
                                  bool use_fake_quant_num_bits = false);
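
// Example usage (a minimal sketch; `MyPreparePass` is illustrative and not
// part of this header). A function pass would typically run the conversion
// before any folding so the QDQ pairs survive:
//
//   void MyPreparePass::runOnOperation() {
//     func::FuncOp func = getOperation();
//     if (failed(ConvertFakeQuantOps(func, &getContext(),
//                                    /*use_fake_quant_num_bits=*/false))) {
//       signalPassFailure();
//     }
//   }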

// Returns the names of all the considered tf.FakeQuant* ops.
std::vector<std::string> AllTfFakeQuantOps();

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_