/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies some clean-up steps after quantization.

#include <string>
#include <utility>

#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

//===----------------------------------------------------------------------===//
// The post-quantize passes.
//
namespace mlir {
namespace TFL {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

// Applies all the clean-up steps after quantization.
class PostQuantizePass : public PostQuantizePassBase<PostQuantizePass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizePass)

  // Constructor used by the PassRegistration. This will remove the adaptor
  // ops.
  explicit PostQuantizePass() { this->emit_quant_adaptor_ops_ = false; }

  // Constructor used when manually creating the pass.
  explicit PostQuantizePass(bool emit_quant_adaptor_ops,
                            const quant::CustomOpMap& custom_op_map)
      : custom_op_map_(custom_op_map) {
    // Set this flag to true if the inputs and outputs are in floating point.
    // The quant adaptor ops convert them to fixed point values (i.e. quantize)
    // before feeding them to the model and convert them back to floating point
    // (i.e. dequantize) as the output.
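    //
    // An illustrative sketch only (types abbreviated, not tied to a real
    // model): with adaptor ops emitted, the float interface is preserved by
    // wrapping the quantized body in a quantize/dequantize pair:
    //   %q = "tfl.quantize"(%float_arg) : (tensor<1xf32>) -> tensor<1x!qtype>
    //   ... quantized ops consuming %q ...
    //   %f = "tfl.dequantize"(%res) : (tensor<1x!qtype>) -> tensor<1xf32>
    //   return %f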
    this->emit_quant_adaptor_ops_ = emit_quant_adaptor_ops;
  }

  void runOnOperation() override;

 private:
  quant::CustomOpMap custom_op_map_;
};

// Cleans up unnecessary QDQ pattern for input/output ops.
class PostQuantizeRemoveQDQPass
    : public PostQuantizeRemoveQDQPassBase<PostQuantizeRemoveQDQPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizeRemoveQDQPass)

  void runOnOperation() override;
};

// TODO(fengliuai): migrate to use modify_io_nodes pass.
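//
// Rewrites the function signature so the boundary adaptor ops are folded into
// the argument and result types. A rough sketch (illustrative IR only, types
// abbreviated, `!qtype` standing in for a quantized tensor element type):
//
//   func @f(%arg0: tensor<1xf32>) -> tensor<1xf32> {
//     %q = "tfl.quantize"(%arg0) : (tensor<1xf32>) -> tensor<1x!qtype>
//     ...
//     %dq = "tfl.dequantize"(%r) : (tensor<1x!qtype>) -> tensor<1xf32>
//     return %dq
//   }
//
// becomes
//
//   func @f(%arg0: tensor<1x!qtype>) -> tensor<1x!qtype> {
//     ...
//     return %r
//   }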
void RemoveQuantizationAdaptorOps(func::FuncOp func) {
  mlir::OpBuilder builder(func.getBody());
  auto& bb = func.front();
  auto loc = func.getLoc();

  int num_args = bb.getNumArguments();
  llvm::SmallVector<Type, 4> input_types;
  input_types.reserve(num_args);
  // Edit the block arguments and create the new input ops in place to replace
  // the old input ops and quantize ops.
  for (int i = 0; i != num_args; ++i) {
    // The previous loop iteration may invalidate the insertion point, so we
    // have to reset it in each iteration.
    builder.setInsertionPointToStart(&bb);

    // In each iteration, a new argument is appended to the end of the list
    // and the current argument is erased, so here we always process the first
    // argument in the list.
    auto arg = bb.getArgument(0);

    auto remove_quantize_op = [&](QuantizeOp quantize_op) {
      auto quantize_output = quantize_op.output();
      auto quantize_type = quantize_output.getType();
      input_types.push_back(quantize_type);
      auto new_arg = bb.addArgument(quantize_type, loc);
      quantize_output.replaceAllUsesWith(new_arg);
      quantize_op.erase();
      arg.dropAllUses();
      bb.eraseArgument(0);
    };

    // This is looking for a pattern: arg -> tfl.quantize
    if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
      auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
      remove_quantize_op(quantize_op);
      continue;
    }

    // Make a copy of the current argument and append it to the end of the
    // list if the pattern isn't found.
    Type arg_type = arg.getType();
    input_types.push_back(arg_type);
    auto new_arg = bb.addArgument(arg_type, loc);
    arg.replaceAllUsesWith(new_arg);
    arg.dropAllUses();
    bb.eraseArgument(0);
  }

  // Edit the return ops and remove the dequantize ops in place.
  auto* terminator = bb.getTerminator();
  int num_return_operands = terminator->getNumOperands();
  llvm::SmallVector<Type, 4> output_types;
  output_types.reserve(num_return_operands);
  for (int i = 0; i != num_return_operands; ++i) {
    auto returned_value = terminator->getOperand(i);
    Operation* returned_op = returned_value.getDefiningOp();
    if (returned_op && returned_op->hasOneUse() &&
        llvm::isa<DequantizeOp>(returned_op)) {
      auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
      Value dequantized_result = dequantize_op.input();
      output_types.push_back(dequantized_result.getType());
      terminator->setOperand(i, dequantized_result);
      returned_op->erase();
    } else {
      output_types.push_back(returned_value.getType());
    }
  }
  auto new_func_type = builder.getFunctionType(input_types, output_types);
  func.setType(new_func_type);
}

enum RemoveVolatileOpsType {
  // Remove all volatile quant-dequant ops.
  kPreserveNone,
  // Preserve volatile quant-dequants for input and output ops.
  kPreserveInputsAndOutputs,
};

// Removes back-to-back quantize and dequantize ops with the volatile
// attribute.
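//
// A rough, illustrative sketch (not IR from a real model):
//
//   %q = "tfl.quantize"(%x) {qtype = ..., volatile}
//   %dq = "tfl.dequantize"(%q)
//
// Uses of %dq are replaced by %x directly, unless %q is a requantize op (its
// input is already quantized); in that case the dequantize is rewritten to
// consume %q's input instead, dropping the requantize from this use.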
template <RemoveVolatileOpsType remove_volatile_ops_type>
struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
  explicit RemoveVolatileOps(MLIRContext* context)
      : OpRewritePattern<DequantizeOp>(context, 1) {}

  LogicalResult matchAndRewrite(DequantizeOp op,
                                PatternRewriter& rewriter) const override {
    auto input_op = op.input().getDefiningOp();
    if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
      if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();

      if (remove_volatile_ops_type == kPreserveInputsAndOutputs) {
        // Don't remove leading and trailing QDQ for the PTQ workflow, so the
        // IO-modifying lib can work correctly.
        if (!q.input().getDefiningOp()) return failure();
        if (op->hasOneUse() &&
            op->user_begin()->hasTrait<OpTrait::IsTerminator>())
          return failure();
      }
      // If the quantize op is a requantize op, it is being used in other scale
      // adjustments and should be kept. Instead, move the dequantize op before
      // the requantize op to remove the unnecessary requantize.
      if (auto qtype = quant::QuantizedType::getQuantizedElementType(
              q.input().getType())) {
        rewriter.setInsertionPoint(op);
        rewriter.replaceOpWithNewOp<DequantizeOp>(op, op.output().getType(),
                                                  q.input());
        return success();
      }

      op.replaceAllUsesWith(q.input());
      return success();
    }
    return failure();
  }
};

// Folds constant quantized Transpose ops.
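//
// Roughly (illustrative IR only, types and attribute values abbreviated):
//
//   %w = "tfl.pseudo_qconst"() {qtype = ..., value = ...}
//       : () -> tensor<2x3x!qtype>
//   %perm = arith.constant dense<[1, 0]> : tensor<2xi32>
//   %t = "tfl.transpose"(%w, %perm) : (...) -> tensor<3x2x!qtype>
//
// is folded into a single tfl.pseudo_qconst holding the permuted values.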
struct FoldTransposeOp : public OpRewritePattern<TransposeOp> {
  explicit FoldTransposeOp(MLIRContext* context)
      : OpRewritePattern<TransposeOp>(context, 1) {}

  // Computes the permutation of a constant `input_tensor` according to `perm`.
  // The function recursively traverses the dimensions of the output tensor in
  // a row-major order and writes the value in the output tensor into
  // `new_values`.
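  //
  // Small worked example (not tied to any particular model): for an input of
  // shape [2, 3] and perm = [1, 0], the output shape is [3, 2] and
  // output[i][j] = input[j][i]; the recursion visits output indices
  // (0,0), (0,1), (1,0), (1,1), (2,0), (2,1) in row-major order and appends
  // the corresponding input elements to `new_values`.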
  void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
                          ArrayRef<int64_t> output_shape, int num_dimensions,
                          int output_axis, std::vector<uint64_t>* input_indices,
                          std::vector<Attribute>* new_values) const {
    // Refer to the implementation of `Transpose` function in
    // tensorflow/lite/kernels/internal/reference/reference_ops.h
    assert(output_axis < num_dimensions);
    const int input_axis = perm[output_axis];
    for (int i = 0; i < output_shape[output_axis]; ++i) {
      // Update the input indices on `input_axis`.
      assert(input_axis < input_indices->size());
      input_indices->operator[](input_axis) = static_cast<uint64_t>(i);
      // Write the value from `input_tensor` if it is the last axis or
      // recurse into the next axis.
      const bool is_last_axis = output_axis == num_dimensions - 1;
      if (is_last_axis) {
        new_values->push_back(
            input_tensor.getValues<Attribute>()[*input_indices]);
      } else {
        ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
                           output_axis + 1, input_indices, new_values);
      }
    }
  }

  LogicalResult matchAndRewrite(TransposeOp op,
                                PatternRewriter& rewriter) const override {
    Operation* def_op = op.input().getDefiningOp();
    auto qconst_op = llvm::dyn_cast_or_null<QConstOp>(def_op);
    if (qconst_op == nullptr) return failure();

    DenseIntElementsAttr perm_tensor;
    if (!matchPattern(op.perm(), m_Constant(&perm_tensor))) return failure();

    if (!(getElementTypeOrSelf(op.output().getType()))
             .isa<quant::UniformQuantizedType>())
      return failure();

    ElementsAttr input_tensor = qconst_op.value();

    assert(perm_tensor.getType().getRank() == 1);
    const int num_dimensions = input_tensor.getType().getRank();
    assert(perm_tensor.getType().getNumElements() == num_dimensions);

    ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
    auto output_type = op.output().getType().cast<ShapedType>();

    SmallVector<int32_t, 4> perm;
    SmallVector<int64_t, 4> output_shape;
    for (int i = 0; i < num_dimensions; ++i) {
      perm.push_back(perm_tensor.getValues<IntegerAttr>()[i].getInt());
      output_shape.push_back(input_shape[perm[i]]);

      // Check that the derived output shape matches the static shape.
      assert(!output_type.hasStaticShape() ||
             output_type.getShape()[i] == output_shape[i]);
    }

    std::vector<Attribute> new_values;
    new_values.reserve(input_tensor.getType().getNumElements());
    std::vector<uint64_t> input_indices(num_dimensions);
    ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
                       /*output_axis=*/0, &input_indices, &new_values);
    auto result_type =
        RankedTensorType::get(output_shape, output_type.getElementType());
    auto values_type = RankedTensorType::get(
        output_shape, output_type.getElementType()
                          .cast<quant::UniformQuantizedType>()
                          .getStorageType());
    rewriter.replaceOpWithNewOp<QConstOp>(
        op, TypeAttr::get(result_type),
        DenseIntElementsAttr::get(values_type, new_values));
    return success();
  }
};

// Folds constant quantized Reshape ops.
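//
// Roughly (illustrative IR only, types abbreviated):
//
//   %w = "tfl.pseudo_qconst"() {qtype = ..., value = ...}
//       : () -> tensor<2x3x!qtype>
//   %r = "tfl.reshape"(%w, %shape) : (...) -> tensor<6x!qtype>
//
// is folded into a single tfl.pseudo_qconst with the reshaped type, reusing
// the same underlying dense elements.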
struct FoldReshapeOp : public OpRewritePattern<ReshapeOp> {
  // Does not take ownership of context, which must refer to a valid value that
  // outlives this object.
  explicit FoldReshapeOp(MLIRContext* context)
      : OpRewritePattern<ReshapeOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter& rewriter) const override {
    Operation* def_op = op.input().getDefiningOp();
    auto qconst_op = llvm::dyn_cast_or_null<QConstOp>(def_op);
    if (qconst_op == nullptr) return failure();

    auto dense_elements =
        qconst_op.value().dyn_cast_or_null<DenseElementsAttr>();
    if (dense_elements == nullptr) return failure();

    // Handle per-tensor cases only.
    if (!(getElementTypeOrSelf(op.getType()))
             .isa<quant::UniformQuantizedType>()) {
      return failure();
    }

    // Remove identity reshape with both static result and input shape.
    auto result_type = op.getType().cast<ShapedType>();
    auto input_type = op.input().getType().cast<ShapedType>();

    // Constant folding: if the result type isn't static, try to derive it
    // from the second operand (the shape).
    if (!result_type.hasStaticShape()) {
      DenseIntElementsAttr shape_elements;
      if (!matchPattern(op.shape(), m_Constant(&shape_elements)))
        return failure();

      SmallVector<int64_t, 4> shape_data;
      for (const APInt& it : shape_elements.getValues<APInt>()) {
        shape_data.push_back(it.getSExtValue());
      }
      result_type =
          RankedTensorType::get(shape_data, input_type.getElementType());
    }
    auto values_type = RankedTensorType::get(
        result_type.getShape(), result_type.getElementType()
                                    .cast<quant::UniformQuantizedType>()
                                    .getStorageType());

    DenseElementsAttr reshaped_elements = dense_elements.reshape(values_type);
    rewriter.replaceOpWithNewOp<QConstOp>(op, TypeAttr::get(result_type),
                                          reshaped_elements);
    return success();
  }
};

// Removes operations with side effects (e.g. LSTM, SVDF) whose outputs are all
// unused. Such stateful ops are not erased by normal dead code elimination, so
// they are pruned explicitly here.
template <typename OpTy>
struct PruneUnusedOpsWithSideEffect : public OpRewritePattern<OpTy> {
 public:
  explicit PruneUnusedOpsWithSideEffect(
      MLIRContext* context, const quant::CustomOpMap& custom_op_map = {})
      : OpRewritePattern<OpTy>(context), custom_op_map(custom_op_map) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter& rewriter) const override {
    if (op.getOperation()->template hasTrait<OpTrait::IsTerminator>()) {
      return failure();
    }
    for (auto result : op.getOperation()->getOpResults()) {
      if (!result.use_empty()) {
        return failure();
      }
    }
    // Only remove a custom op if it is listed in the provided map and marked
    // there as having no side effect.
    auto custom_op = llvm::isa<CustomOp>(op);
    if (custom_op) {
      auto q = llvm::cast<CustomOp>(op);
      std::string op_name = q.custom_code().str();
      if ((custom_op_map.find(op_name) == custom_op_map.end()) ||
          !custom_op_map.find(op_name)->second.no_side_effect)
        return failure();
    }
    rewriter.eraseOp(op);
    return success();
  }
  quant::CustomOpMap custom_op_map;
};

#include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"

void PostQuantizePass::runOnOperation() {
  if (!enable_custom_op_no_side_effect_.empty()) {
    ParseCustomOpSpecs(enable_custom_op_no_side_effect_,
                       quant::CustomOpUpdateOptions::kNoSideEffect,
                       custom_op_map_);
  }

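  // Phase 1: prune unused stateful ops (LSTM, SVDF, and custom ops marked
  // side-effect-free in custom_op_map_) and fold trivial requantize ops.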
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();
  auto* ctx = func.getContext();
  TFL::populateWithGenerated(patterns);
  patterns.add<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
  patterns.add<PruneUnusedOpsWithSideEffect<TFL::LSTMOp>>(ctx);
  patterns.add<PruneUnusedOpsWithSideEffect<TFL::UnidirectionalSequenceLSTMOp>>(
      ctx);
  patterns.add<PruneUnusedOpsWithSideEffect<TFL::SVDFOp>>(ctx);
  patterns.add<PruneUnusedOpsWithSideEffect<TFL::CustomOp>>(ctx,
                                                            custom_op_map_);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  if (!emit_quant_adaptor_ops_) {
    RemoveQuantizationAdaptorOps(getOperation());
  }

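  // Phase 2: with the adaptor ops removed (when requested), fold constant
  // quantized Transpose/Reshape ops and drop volatile quantize/dequantize
  // pairs, preserving those on model inputs and outputs.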
  RewritePatternSet phase_2_patterns(&getContext());
  TFL::populateWithGenerated(phase_2_patterns);
  phase_2_patterns.add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
                       RemoveVolatileOps<kPreserveInputsAndOutputs>,
                       FoldTransposeOp, FoldReshapeOp>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}

void PostQuantizeRemoveQDQPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();
  auto* ctx = func.getContext();
  TFL::populateWithGenerated(patterns);
  patterns.add<RemoveVolatileOps<kPreserveNone>>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePostQuantizePass(
    bool emit_quant_adaptor_ops, const quant::CustomOpMap& custom_op_map) {
  return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops,
                                            custom_op_map);
}

std::unique_ptr<OperationPass<func::FuncOp>> CreatePostQuantizePass() {
  return std::make_unique<PostQuantizePass>();
}

// Creates an instance of the TensorFlow Lite dialect PostQuantizeRemoveQDQ
// pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePostQuantizeRemoveQDQPass() {
  return std::make_unique<PostQuantizeRemoveQDQPass>();
}

}  // namespace TFL
}  // namespace mlir