xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass applies some clean up steps after quantization.
17 
18 #include <string>
19 #include <utility>
20 
21 #include "llvm/Support/Casting.h"
22 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
28 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
29 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
30 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
31 
32 //===----------------------------------------------------------------------===//
33 // The post-quantize Passes.
34 //
35 namespace mlir {
36 namespace TFL {
37 namespace {
38 #define GEN_PASS_CLASSES
39 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
40 
41 // Applies all the clean up steps after quantization.
42 class PostQuantizePass : public PostQuantizePassBase<PostQuantizePass> {
43  public:
44   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizePass)
45 
46   // Constructor used by the PassRegistration. This will remove the adaptor ops.
PostQuantizePass()47   explicit PostQuantizePass() { this->emit_quant_adaptor_ops_ = false; }
48 
49   // Constructor used by manually creating the pass.
PostQuantizePass(bool emit_quant_adaptor_ops,const quant::CustomOpMap & custom_op_map)50   explicit PostQuantizePass(bool emit_quant_adaptor_ops,
51                             const quant::CustomOpMap& custom_op_map)
52       : custom_op_map_(custom_op_map) {
53     // Set this flag to true if the inputs and outputs are in floating point.
54     // The quant adaptor ops convert them to fixed point values (i.e. quantize)
55     // before feeding them to the model and convert them back to floating point
56     // (i.e. dequantize) as the output.
57     this->emit_quant_adaptor_ops_ = emit_quant_adaptor_ops;
58   }
59 
60   void runOnOperation() override;
61 
62  private:
63   quant::CustomOpMap custom_op_map_;
64 };
65 
66 // Cleans up unnecessary QDQ pattern for input/output ops.
67 class PostQuantizeRemoveQDQPass
68     : public PostQuantizeRemoveQDQPassBase<PostQuantizeRemoveQDQPass> {
69  public:
70   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizeRemoveQDQPass)
71 
72   void runOnOperation() override;
73 };
74 
75 // TODO(fengliuai): migrate to use modify_io_nodes pass.
RemoveQuantizationAdaptorOps(func::FuncOp func)76 void RemoveQuantizationAdaptorOps(func::FuncOp func) {
77   mlir::OpBuilder builder(func.getBody());
78   auto& bb = func.front();
79   auto loc = func.getLoc();
80 
81   int num_args = bb.getNumArguments();
82   llvm::SmallVector<Type, 4> input_types;
83   input_types.reserve(num_args);
84   // Edit the block arguments and create the new input ops in place to replace
85   // the old input ops and quantize ops.
86   for (int i = 0; i != num_args; ++i) {
87     // Previous loop iteration may invalidate the insertion point so we have to
88     // reset insertion point each iteration.
89     builder.setInsertionPointToStart(&bb);
90 
91     // In each iteration, a new argument is appended to the end of the list
92     // and the current argument is erased, so here we always process the first
93     // argument in the list.
94     auto arg = bb.getArgument(0);
95 
96     auto remove_quantize_op = [&](QuantizeOp quantize_op) {
97       auto quantize_output = quantize_op.output();
98       auto quantize_type = quantize_output.getType();
99       input_types.push_back(quantize_type);
100       auto new_arg = bb.addArgument(quantize_type, loc);
101       quantize_output.replaceAllUsesWith(new_arg);
102       quantize_op.erase();
103       arg.dropAllUses();
104       bb.eraseArgument(0);
105     };
106 
107     // This is looking for a pattern: arg -> tfl.quantize
108     if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
109       auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
110       remove_quantize_op(quantize_op);
111       continue;
112     }
113 
114     // Make a copy of current argument and append it to the end of the list if
115     // the pattern isn't found.
116     Type arg_type = arg.getType();
117     input_types.push_back(arg_type);
118     auto new_arg = bb.addArgument(arg_type, loc);
119     arg.replaceAllUsesWith(new_arg);
120     arg.dropAllUses();
121     bb.eraseArgument(0);
122   }
123 
124   // Edit the return ops and remove the dequantize ops in place.
125   auto* terminator = bb.getTerminator();
126   int num_return_operands = terminator->getNumOperands();
127   llvm::SmallVector<Type, 4> output_types;
128   output_types.reserve(num_return_operands);
129   for (int i = 0; i != num_return_operands; ++i) {
130     auto returned_value = terminator->getOperand(i);
131     Operation* returned_op = returned_value.getDefiningOp();
132     if (returned_op && returned_op->hasOneUse() &&
133         llvm::isa<DequantizeOp>(returned_op)) {
134       auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
135       Value dequantized_result = dequantize_op.input();
136       output_types.push_back(dequantized_result.getType());
137       terminator->setOperand(i, dequantized_result);
138       returned_op->erase();
139     } else {
140       output_types.push_back(returned_value.getType());
141     }
142   }
143   auto new_func_type = builder.getFunctionType(input_types, output_types);
144   func.setType(new_func_type);
145 }
146 
// Selects which volatile quantize/dequantize pairs RemoveVolatileOps removes.
enum RemoveVolatileOpsType {
  // Remove every volatile quantize/dequantize pair.
  kPreserveNone = 0,
  // Keep the volatile pairs attached to the function inputs and outputs.
  kPreserveInputsAndOutputs = 1,
};
153 
154 // Remove the back-to-back quantize and dequantize ops with volatile attribute.
155 template <RemoveVolatileOpsType remove_volatile_ops_type>
156 struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
RemoveVolatileOpsmlir::TFL::__anonfa26bdbd0111::RemoveVolatileOps157   explicit RemoveVolatileOps(MLIRContext* context)
158       : OpRewritePattern<DequantizeOp>(context, 1) {}
159 
matchAndRewritemlir::TFL::__anonfa26bdbd0111::RemoveVolatileOps160   LogicalResult matchAndRewrite(DequantizeOp op,
161                                 PatternRewriter& rewriter) const override {
162     auto input_op = op.input().getDefiningOp();
163     if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
164       if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
165 
166       if (remove_volatile_ops_type == kPreserveInputsAndOutputs) {
167         // Don't remove leading and trailing QDQ for PTQ workflow, so the io
168         // modifying lib can work correctly.
169         if (!q.input().getDefiningOp()) return failure();
170         if (op->hasOneUse() &&
171             op->user_begin()->hasTrait<OpTrait::IsTerminator>())
172           return failure();
173       }
174       // If the quantize op is a requantize op, it is being used in other scale
175       // adjustments and should be kept. Instead, moving dequantize op before
176       // the requantize op to remove the unnecessary requantize op.
177       if (auto qtype = quant::QuantizedType::getQuantizedElementType(
178               q.input().getType())) {
179         rewriter.setInsertionPoint(op);
180         rewriter.replaceOpWithNewOp<DequantizeOp>(op, op.output().getType(),
181                                                   q.input());
182         return success();
183       }
184 
185       op.replaceAllUsesWith(q.input());
186       return success();
187     }
188     return failure();
189   }
190 };
191 
192 // Fold the constant quantized Transpose ops.
193 struct FoldTransposeOp : public OpRewritePattern<TransposeOp> {
FoldTransposeOpmlir::TFL::__anonfa26bdbd0111::FoldTransposeOp194   explicit FoldTransposeOp(MLIRContext* context)
195       : OpRewritePattern<TransposeOp>(context, 1) {}
196 
197   // Computes the permutation of a constant `input_tensor` according to `perm`.
198   // The function recursively traverses the dimensions of the output tensor in
199   // a row-major order and writes the value in the output tensor into
200   // `new_values`.
ComputePermutationmlir::TFL::__anonfa26bdbd0111::FoldTransposeOp201   void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
202                           ArrayRef<int64_t> output_shape, int num_dimensions,
203                           int output_axis, std::vector<uint64_t>* input_indices,
204                           std::vector<Attribute>* new_values) const {
205     // Refer to the implementation of `Transpose` function in
206     // tensorflow/lite/kernels/internal/reference/reference_ops.h
207     assert(output_axis < num_dimensions);
208     const int input_axis = perm[output_axis];
209     for (int i = 0; i < output_shape[output_axis]; ++i) {
210       // Update the input indices on `input_axis`.
211       assert(input_axis < input_indices->size());
212       input_indices->operator[](input_axis) = static_cast<uint64_t>(i);
213       // Write the value from `input_tensor` if it is the last axis or
214       // recurse into the next axis.
215       const bool is_last_axis = output_axis == num_dimensions - 1;
216       if (is_last_axis) {
217         new_values->push_back(
218             input_tensor.getValues<Attribute>()[*input_indices]);
219       } else {
220         ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
221                            output_axis + 1, input_indices, new_values);
222       }
223     }
224   }
225 
matchAndRewritemlir::TFL::__anonfa26bdbd0111::FoldTransposeOp226   LogicalResult matchAndRewrite(TransposeOp op,
227                                 PatternRewriter& rewriter) const override {
228     Operation* def_op = op.input().getDefiningOp();
229     auto qconst_op = llvm::dyn_cast_or_null<QConstOp>(def_op);
230     if (qconst_op == nullptr) return failure();
231 
232     DenseIntElementsAttr perm_tensor;
233     if (!matchPattern(op.perm(), m_Constant(&perm_tensor))) return failure();
234 
235     if (!(getElementTypeOrSelf(op.output().getType()))
236              .isa<quant::UniformQuantizedType>())
237       return failure();
238 
239     ElementsAttr input_tensor = qconst_op.value();
240 
241     assert(perm_tensor.getType().getRank() == 1);
242     const int num_dimensions = input_tensor.getType().getRank();
243     assert(perm_tensor.getType().getNumElements() == num_dimensions);
244 
245     ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
246     auto output_type = op.output().getType().cast<ShapedType>();
247 
248     SmallVector<int32_t, 4> perm;
249     SmallVector<int64_t, 4> output_shape;
250     for (int i = 0; i < num_dimensions; ++i) {
251       perm.push_back(perm_tensor.getValues<IntegerAttr>()[i].getInt());
252       output_shape.push_back(input_shape[perm[i]]);
253 
254       // Check that the derived output shape matches the static shape.
255       assert(!output_type.hasStaticShape() ||
256              output_type.getShape()[i] == output_shape[i]);
257     }
258 
259     std::vector<Attribute> new_values;
260     new_values.reserve(input_tensor.getType().getNumElements());
261     std::vector<uint64_t> input_indices(num_dimensions);
262     ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
263                        /*output_axis=*/0, &input_indices, &new_values);
264     auto result_type =
265         RankedTensorType::get(output_shape, output_type.getElementType());
266     auto values_type = RankedTensorType::get(
267         output_shape, output_type.getElementType()
268                           .cast<quant::UniformQuantizedType>()
269                           .getStorageType());
270     rewriter.replaceOpWithNewOp<QConstOp>(
271         op, TypeAttr::get(result_type),
272         DenseIntElementsAttr::get(values_type, new_values));
273     return success();
274   }
275 };
276 
277 // Fold constant quantized Reshape ops.
278 struct FoldReshapeOp : public OpRewritePattern<ReshapeOp> {
279   // Does not take ownership of context, which must refer to a valid value that
280   // outlives this object.
FoldReshapeOpmlir::TFL::__anonfa26bdbd0111::FoldReshapeOp281   explicit FoldReshapeOp(MLIRContext* context)
282       : OpRewritePattern<ReshapeOp>(context, /*benefit=*/1) {}
283 
matchAndRewritemlir::TFL::__anonfa26bdbd0111::FoldReshapeOp284   LogicalResult matchAndRewrite(ReshapeOp op,
285                                 PatternRewriter& rewriter) const override {
286     Operation* def_op = op.input().getDefiningOp();
287     auto qconst_op = llvm::dyn_cast_or_null<QConstOp>(def_op);
288     if (qconst_op == nullptr) return failure();
289 
290     auto dense_elements =
291         qconst_op.value().dyn_cast_or_null<DenseElementsAttr>();
292     if (dense_elements == nullptr) return failure();
293 
294     // Handle per tensor cases only.
295     if (!(getElementTypeOrSelf(op.getType()))
296              .isa<quant::UniformQuantizedType>()) {
297       return failure();
298     }
299 
300     // Remove identity reshape with both static result and input shape.
301     auto result_type = op.getType().cast<ShapedType>();
302     auto input_type = op.input().getType().cast<ShapedType>();
303 
304     // Constant folding
305     // If the result type isn't static, tries to derive the result type from
306     // the #2 operand.
307     if (!result_type.hasStaticShape()) {
308       DenseIntElementsAttr shape_elements;
309       if (!matchPattern(op.shape(), m_Constant(&shape_elements)))
310         return failure();
311 
312       SmallVector<int64_t, 4> shape_data;
313       for (const APInt& it : shape_elements.getValues<APInt>()) {
314         shape_data.push_back(it.getSExtValue());
315       }
316       result_type =
317           RankedTensorType::get(shape_data, input_type.getElementType());
318     }
319     auto values_type = RankedTensorType::get(
320         result_type.getShape(), result_type.getElementType()
321                                     .cast<quant::UniformQuantizedType>()
322                                     .getStorageType());
323 
324     DenseElementsAttr reshaped_elements = dense_elements.reshape(values_type);
325     rewriter.replaceOpWithNewOp<QConstOp>(op, TypeAttr::get(result_type),
326                                           reshaped_elements);
327     return success();
328   }
329 };
330 
331 // Removes operations with side effect (i.e. LSTM, SVDF) that have dangling
332 // output.
333 template <typename OpTy>
334 struct PruneUnusedOpsWithSideEffect : public OpRewritePattern<OpTy> {
335  public:
PruneUnusedOpsWithSideEffectmlir::TFL::__anonfa26bdbd0111::PruneUnusedOpsWithSideEffect336   explicit PruneUnusedOpsWithSideEffect(
337       MLIRContext* context, const quant::CustomOpMap& custom_op_map = {})
338       : OpRewritePattern<OpTy>(context), custom_op_map(custom_op_map) {}
339 
matchAndRewritemlir::TFL::__anonfa26bdbd0111::PruneUnusedOpsWithSideEffect340   LogicalResult matchAndRewrite(OpTy op,
341                                 PatternRewriter& rewriter) const override {
342     if (op.getOperation()->template hasTrait<OpTrait::IsTerminator>()) {
343       return failure();
344     }
345     for (auto result : op.getOperation()->getOpResults()) {
346       if (!result.use_empty()) {
347         return failure();
348       }
349     }
350     // Remove if the custom op is in the provided map and is NoSideEffect.
351     auto custom_op = llvm::isa<CustomOp>(op);
352     if (custom_op) {
353       auto q = llvm::cast<CustomOp>(op);
354       std::string op_name = q.custom_code().str();
355       if ((custom_op_map.find(op_name) == custom_op_map.end()) ||
356           !custom_op_map.find(op_name)->second.no_side_effect)
357         return failure();
358     }
359     rewriter.eraseOp(op);
360     return success();
361   }
362   quant::CustomOpMap custom_op_map;
363 };
364 
365 #include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
366 
runOnOperation()367 void PostQuantizePass::runOnOperation() {
368   if (!enable_custom_op_no_side_effect_.empty()) {
369     ParseCustomOpSpecs(enable_custom_op_no_side_effect_,
370                        quant::CustomOpUpdateOptions::kNoSideEffect,
371                        custom_op_map_);
372   }
373 
374   RewritePatternSet patterns(&getContext());
375   auto func = getOperation();
376   auto* ctx = func.getContext();
377   TFL::populateWithGenerated(patterns);
378   patterns.add<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
379   patterns.add<PruneUnusedOpsWithSideEffect<TFL::LSTMOp>>(ctx);
380   patterns.add<PruneUnusedOpsWithSideEffect<TFL::UnidirectionalSequenceLSTMOp>>(
381       ctx);
382   patterns.add<PruneUnusedOpsWithSideEffect<TFL::SVDFOp>>(ctx);
383   patterns.add<PruneUnusedOpsWithSideEffect<TFL::CustomOp>>(ctx,
384                                                             custom_op_map_);
385   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
386 
387   if (!emit_quant_adaptor_ops_) {
388     RemoveQuantizationAdaptorOps(getOperation());
389   }
390 
391   RewritePatternSet phase_2_patterns(&getContext());
392   TFL::populateWithGenerated(phase_2_patterns);
393   phase_2_patterns.add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
394                        RemoveVolatileOps<kPreserveInputsAndOutputs>,
395                        FoldTransposeOp, FoldReshapeOp>(ctx);
396   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
397 }
398 
runOnOperation()399 void PostQuantizeRemoveQDQPass::runOnOperation() {
400   RewritePatternSet patterns(&getContext());
401   auto func = getOperation();
402   auto* ctx = func.getContext();
403   TFL::populateWithGenerated(patterns);
404   patterns.add<RemoveVolatileOps<kPreserveNone>>(ctx);
405   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
406 }
407 
408 }  // namespace
409 
410 // Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
CreatePostQuantizePass(bool emit_quant_adaptor_ops,const quant::CustomOpMap & custom_op_map)411 std::unique_ptr<OperationPass<func::FuncOp>> CreatePostQuantizePass(
412     bool emit_quant_adaptor_ops, const quant::CustomOpMap& custom_op_map) {
413   return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops,
414                                             custom_op_map);
415 }
416 
CreatePostQuantizePass()417 std::unique_ptr<OperationPass<func::FuncOp>> CreatePostQuantizePass() {
418   return std::make_unique<PostQuantizePass>();
419 }
420 
421 // Creates an instance of the TensorFlow Lite dialect PostQuantizeRemoveQDQ
422 // pass.
CreatePostQuantizeRemoveQDQPass()423 std::unique_ptr<OperationPass<func::FuncOp>> CreatePostQuantizeRemoveQDQPass() {
424   return std::make_unique<PostQuantizeRemoveQDQPass>();
425 }
426 
427 }  // namespace TFL
428 }  // namespace mlir
429