1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This header file defines common utils used by TFLite transformation 17 // passes to work with tf.FakeQuant* ops. 18 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_ 19 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_ 20 21 #include <string> 22 23 #include "mlir/IR/Attributes.h" // from @llvm-project 24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 25 #include "mlir/IR/MLIRContext.h" // from @llvm-project 26 #include "mlir/Support/LLVM.h" // from @llvm-project 27 #include "mlir/Support/LogicalResult.h" // from @llvm-project 28 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" 30 31 namespace mlir { 32 namespace TFL { 33 34 template <class TFFakeQuantOp> 35 struct FetchMinMaxAttrs { 36 using AttrType = FloatAttr; operatorFetchMinMaxAttrs37 bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, 38 AttrType &max_value) const { 39 min_value = tf_op.minAttr(); 40 max_value = tf_op.maxAttr(); 41 return true; // Successfully matched and fetched. 42 } 43 }; 44 45 template <class TFFakeQuantOp> 46 struct FetchConstantMinMaxInputs { 47 using AttrType = DenseFPElementsAttr; operatorFetchConstantMinMaxInputs48 bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, 49 AttrType &max_value) const { 50 Value min = tf_op.min(), max = tf_op.max(); 51 if (!matchPattern(min, m_Constant(&min_value))) { 52 return false; 53 } 54 if (!matchPattern(max, m_Constant(&max_value))) { 55 return false; 56 } 57 return true; // Successfully matched and fetched. 58 } 59 }; 60 61 // Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the 62 // tf.FakeQyantWithMinMax{Vars|VarsPerChannel|Args}Op 63 // before the op being constant folded. Since the constant 64 // folding logic will use a "arith.constant" op to replace the 65 // "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve 66 // the quantization parameters as a TypeAttr and "tfl.dequantize" op used to 67 // convert the output type to the next op. Here are the transformations: 68 // 69 // input min cst max cst input min cst max cst 70 // \ | | \ | | 71 // \ (tf.Identity) (tf.Identity) => \ (tf.Identity) (tf.Identity) 72 // \ | | \ | | 73 // tf.FakeQuantWithMinMaxVars tf.FakeQuantWithMinMaxVars 74 // | | 75 // tfl.quantize 76 // | 77 // tfl.dequantize 78 // | 79 // If the input is a constant, the result pattern will eventually converted to 80 // 81 // quant-emulated input 82 // | 83 // tfl.quantize 84 // | 85 // tfl.dequantize 86 // | 87 // 88 // 89 // Warns if the (most likely unwanted, currently not quite correctly handled) 90 // case of back-to-back tf.FakeQuant occurs 91 // 92 // tf.FakeQuant* 93 // | 94 // tf.FakeQuant* 95 // 96 template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax> 97 class InsertTFLQuantOpsAfterTFFakeQuantOp { 98 public: InsertTFLQuantOpsAfterTFFakeQuantOp(bool use_fake_quant_num_bits)99 explicit InsertTFLQuantOpsAfterTFFakeQuantOp(bool use_fake_quant_num_bits) 100 : use_fake_quant_num_bits_(use_fake_quant_num_bits) {} 101 102 FetchMinMax fetch_min_max_; 103 104 using FetchAttrType = typename FetchMinMax::AttrType; matchAndRewrite(TFFakeQuantOp tf_op,OpBuilder & rewriter)105 LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, 106 OpBuilder &rewriter) const { 107 // We don't want to insert quantize/dequantize if the quantize op exists. 108 auto res = tf_op.outputs(); 109 if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) { 110 return failure(); 111 } 112 113 // Extract the min/max constant values from the operands. We also consider 114 // a special case that there are tf.Identity ops between the min/max 115 // constants and the tf.FakeQuantWithMinMaxVarsOp. 116 117 FetchAttrType min_value, max_value; 118 if (!fetch_min_max_(tf_op, min_value, max_value)) { 119 return failure(); 120 } 121 122 int quant_dim = -1; 123 if (PerAxis) { 124 // This is a special case that the quant_dim is the last dimensions. 125 quant_dim = res.getType().template cast<ShapedType>().getRank() - 1; 126 } 127 // Use the min/max from the operands and the num_bits and narrow_range 128 // attribute to create the quantization parameter for the new quantize op. 129 rewriter.setInsertionPointAfter(tf_op.getOperation()); 130 IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); 131 BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); 132 Type res_type = tf_op.getType(); 133 TypeAttr qtype = quant::GetQuantizedTypeAttr( 134 rewriter, res_type, min_value, max_value, quant_dim, num_bits, 135 narrow_range, /*is_signed=*/false, /*legacy_float_scale=*/false, 136 use_fake_quant_num_bits_); 137 if (!qtype) { 138 return failure(); 139 } 140 141 // Finally, use the quantization parameter to create the quantize and 142 // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp 143 // and its users. 144 Value value = tf_op.outputs(); 145 auto quantize = rewriter.create<TFL::QuantizeOp>( 146 tf_op.getLoc(), qtype.getValue(), value, qtype); 147 auto dequantize = rewriter.create<TFL::DequantizeOp>( 148 tf_op.getLoc(), res_type, quantize.output()); 149 value.replaceAllUsesWith(dequantize); 150 quantize.getOperation()->replaceUsesOfWith(dequantize, value); 151 152 return success(); 153 } 154 155 bool use_fake_quant_num_bits_; 156 }; 157 158 // Removes the wrapper of the tf.FakeQuant* ops and creates the tfl.quantize 159 // and tfl.dequantize pairs before tf.FakeQuant* being foled. 160 LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext *ctx, 161 bool use_fake_quant_num_bits = false); 162 163 // Returns the names of all the considered tf.FakeQuant* ops. 164 std::vector<std::string> AllTfFakeQuantOps(); 165 166 } // namespace TFL 167 } // namespace mlir 168 169 #endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_ 170