/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Copied and modified from
// //third_party/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc
// This transformation pass applies quantization propagation on the TF dialect.

#include <algorithm>
#include <memory>
#include <utility>

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

//===----------------------------------------------------------------------===//
// The prepare-quantize-drq Pass.
//
namespace mlir {
namespace quant {

namespace {

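// Each quantization unit is a (user op, operand index) pair that identifies a
// quantizable use of a weight constant.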
using QuantizationUnits = llvm::SetVector<std::pair<Operation*, int>>;

// Applies prepare quantization on the model in the TF dialect for the dynamic
// range quantization case.
class PrepareQuantizeDRQPass
    : public PassWrapper<PrepareQuantizeDRQPass, OperationPass<func::FuncOp>> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<TF::TensorFlowDialect, ::mlir::quant::QuantizationDialect,
                    ::mlir::quantfork::QuantizationForkDialect>();
  }

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizeDRQPass)

  // Constructor used by the PassRegistration; enforces int8 quantization.
  // This is only used by tests.
  explicit PrepareQuantizeDRQPass() {
    quant_specs_.inference_type = tensorflow::DT_QINT8;
  }

  // Constructor used when manually creating the pass.
  explicit PrepareQuantizeDRQPass(const QuantizationSpecs& quant_specs)
      : quant_specs_(quant_specs) {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in the textual format
    // (on the command line, for example).
    return "quant-prepare-quantize-drq";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Prepare TF dialect for dynamic range quantization";
  }

  // The function might contain stats ops which are redundant for dynamic
  // range quantization and may cause conflicts while processing the function.
  // Therefore, this method preprocesses the function to remove all stats ops.
  void removeAllStatsOp(func::FuncOp func);

  void runOnOperation() override;

 private:
  QuantizationSpecs quant_specs_;
};

// If the weight is applicable to dynamic range quantization, insert Quantize
// and Dequantize ops with per-tensor scale.
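// A rough sketch of the rewrite in MLIR textual form, assuming signed int8
// weights (the shape and scale below are placeholders; the real quantized
// type is derived from the weight values):
//
//   %w  = arith.constant dense<...> : tensor<1024x512xf32>
//   %q  = "quantfork.qcast"(%w) : (tensor<1024x512xf32>)
//           -> tensor<1024x512x!quant.uniform<i8<-127:127>:f32, 0.01>>
//   %dq = "quantfork.dcast"(%q) : (...) -> tensor<1024x512xf32>
//
// The quantizable user is then rewired to consume %dq instead of %w.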
class PrepareDRQQuantizableOp : public OpRewritePattern<arith::ConstantOp> {
 public:
  explicit PrepareDRQQuantizableOp(MLIRContext* context,
                                   const quant::QuantizationSpecs& quant_specs)
      : OpRewritePattern<arith::ConstantOp>(context),
        quant_specs_(quant_specs) {}

  LogicalResult matchAndRewrite(arith::ConstantOp op,
                                PatternRewriter& rewriter) const override {
    QuantizationUnits quantizable_ops;

    // 1. Collect quantizable ops.
    if (!(getQuantizableOps(op, quantizable_ops))) {
      return failure();
    }

    // 2. Quantize collected ops. They are immediately quantized by inserting
    // a Q-DQ pair for int8.
    if (!(quantizeOps(rewriter, op, quantizable_ops))) {
      return failure();
    }

    return success();
  }

 private:
  // Marks users that are applicable for dynamic range quantization. The
  // criteria for determining quantizable ops differ by inference type.
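  // For example, with DT_QINT8 a constant feeding a quantizable operand of a
  // supported op (as reported by GetTFOpQuantSpec) is collected, while other
  // uses of the same constant are left in float.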
  bool getQuantizableOps(arith::ConstantOp op,
                         QuantizationUnits& quantizable_ops) const {
    // Non-float tensors do not need quantization.
    auto type = op.getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isF32()) return false;

    Value value = op.getResult();

    // Check whether dynamic range quantization can be applied.
    for (auto& use : value.getUses()) {
      Operation* user = use.getOwner();
      int operand_num = use.getOperandNumber();
      std::unique_ptr<OpQuantSpec> spec = GetTFOpQuantSpec(user);

      if (quant_specs_.inference_type == tensorflow::DT_QINT8 &&
          spec->quantizable_operands.contains(operand_num)) {
        quantizable_ops.insert({user, operand_num});
      }
    }

    return !quantizable_ops.empty();
  }

  // Apply per-tensor quantization for int8 dynamic range quantization.
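  // With a signed type and narrow range, the quantized values use the
  // symmetric range [-127, 127], which keeps the zero point at 0.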
  bool quantizeOpAsInt8(PatternRewriter& rewriter, arith::ConstantOp op,
                        std::pair<Operation*, int> quant_op) const {
    bool is_narrow_range = true;
    bool is_legacy_float = quant_specs_.legacy_float_scale;
    bool is_signed = quant_specs_.IsSignedInferenceType();
    int bit_width = quant_specs_.GetQuantizationTypeWidth();

    QuantizedType quant_type = nullptr;
    DenseFPElementsAttr attr;
    if (!matchPattern(op->getResult(0), m_Constant(&attr))) return false;

    quant_type = quant::GetUniformQuantizedTypeForWeight(
                     attr, is_narrow_range && is_signed, bit_width, is_signed,
                     is_narrow_range, is_legacy_float)
                     .template dyn_cast<quant::QuantizedType>();

    return insertQDQ(rewriter, op, quant_type, quant_op);
  }

  // Insert Quantize and Dequantize ops.
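  // Returns true if a new Q-DQ pair was created, and false if the quantized
  // type is null or an existing pair with the same quantized type was reused
  // (in which case the user is simply rewired to it).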
  bool insertQDQ(PatternRewriter& rewriter, arith::ConstantOp op,
                 QuantizedType quant_type,
                 std::pair<Operation*, int> quant_op) const {
    if (!quant_type) return false;

    Operation* quantize_op = quant_op.first;
    int quantize_operand_num = quant_op.second;

    Type expressed_type = op.getResult().getType();
    Type cast_type = quant_type.castFromExpressedType(expressed_type);

    // Insert a DQ-op if it does not exist yet. Otherwise, just rewire without
    // creating a new DQ-op.
    for (auto connected_op : op->getUsers()) {
      auto q_op =
          llvm::dyn_cast_or_null<quantfork::QuantizeCastOp>(connected_op);
      if (q_op && q_op.getType() == cast_type) {
        auto dq_op = llvm::cast<quantfork::DequantizeCastOp>(
            q_op.getResult().use_begin()->getOwner());
        quantize_op->setOperand(quantize_operand_num, dq_op);
        return false;
      }
    }
    rewriter.setInsertionPointAfter(op);
    auto q = rewriter.create<quantfork::QuantizeCastOp>(op->getLoc(), cast_type,
                                                        op.getResult());
    auto dq = rewriter.create<quantfork::DequantizeCastOp>(op->getLoc(),
                                                           expressed_type, q);
    quantize_op->setOperand(quantize_operand_num, dq.getResult());
    return true;
  }

  // For each filtered user, apply quantization.
  bool quantizeOps(PatternRewriter& rewriter, arith::ConstantOp op,
                   QuantizationUnits& quantizable_ops) const {
    bool quantized = false;

    for (auto& quant_op : quantizable_ops) {
      if (quant_specs_.inference_type == tensorflow::DT_QINT8) {
        quantized |= quantizeOpAsInt8(rewriter, op, quant_op);
      }
    }
    return quantized;
  }

 protected:
  quant::QuantizationSpecs quant_specs_;
};

// Removes all the stats ops which are redundant for dynamic range
// quantization.
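// A rough sketch of the effect (op syntax abbreviated): every
//   %0 = "quantfork.stats"(%arg0) {layerStats = ...}
// is erased, with its uses replaced by %arg0 directly.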
void PrepareQuantizeDRQPass::removeAllStatsOp(func::FuncOp func) {
  func.walk([&](quantfork::StatisticsOp stats_op) {
    stats_op.replaceAllUsesWith(stats_op.getArg());
    stats_op.erase();
  });
}

#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.inc"

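// `populateWithGenerated` below is declared in the generated file included
// above; it adds the TableGen-defined rewrite patterns to the set.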
void PrepareQuantizeDRQPass::runOnOperation() {
  func::FuncOp func = getOperation();
  MLIRContext* ctx = func.getContext();

  removeAllStatsOp(func);

  RewritePatternSet patterns(&getContext());
  populateWithGenerated(patterns);
  patterns.add<PrepareDRQQuantizableOp>(ctx, quant_specs_);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

// Creates an instance of the TensorFlow dialect PrepareQuantizeDRQ
// pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareQuantizeDRQPass() {
  return std::make_unique<PrepareQuantizeDRQPass>();
}

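// Registers the pass so it can be selected by its command-line argument,
// "quant-prepare-quantize-drq".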
static PassRegistration<PrepareQuantizeDRQPass> pass;

}  // namespace quant
}  // namespace mlir