1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass decomposes dense operations that assume
17 // support for hybrid quantization. These cases cover when a dense operation
18 // (e.g. matmul) has both quantized and unquantized inputs by dequantizing
19 // the quantized inputs, performing the operation in the expressed type, then
20 // requantizing if a quantized output is required.
21 //
22 // The motivation behind these changes is for Dialects that assume only float
23 // or quantized computation, and do not support a mixture of these types on
24 // dense operations. Decomposition allows TFLite to be compiled to these
25 // dialects, such as TOSA.
26
27 #include <utility>
28
29 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
30 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
36 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
37
38 namespace mlir {
39 namespace TFL {
40
41 namespace {
42
43 #define GEN_PASS_CLASSES
44 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
45
46 class DecomposeHybridQuantizationPass
47 : public DecomposeHybridQuantizationPassBase<
48 DecomposeHybridQuantizationPass> {
49 public:
DecomposeHybridQuantizationPass()50 explicit DecomposeHybridQuantizationPass() {}
51 void runOnOperation() override;
52 };
53
54 template <typename SrcOp>
55 class DequantizeConverter : public OpRewritePattern<SrcOp> {
56 public:
57 using OpRewritePattern<SrcOp>::OpRewritePattern;
58
matchAndRewrite(SrcOp srcop,PatternRewriter & rewriter) const59 LogicalResult matchAndRewrite(SrcOp srcop,
60 PatternRewriter &rewriter) const final {
61 Operation *op = srcop.getOperation();
62 bool allTypesFp = true;
63 bool allTypesQuantizedOrInt = true;
64 for (auto operand : op->getOperands()) {
65 ShapedType type = operand.getType().template dyn_cast<ShapedType>();
66 if (!type) continue;
67 allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
68 allTypesQuantizedOrInt &=
69 (type.getElementType().isa<quant::QuantizedType>() ||
70 type.getElementType().isa<IntegerType>());
71 }
72
73 for (auto result : op->getResults()) {
74 ShapedType type = result.getType().template cast<ShapedType>();
75 allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
76 allTypesQuantizedOrInt &=
77 (type.getElementType().isa<quant::QuantizedType>() ||
78 type.getElementType().isa<IntegerType>());
79 }
80
81 // If all quantized or floating point then types are consistent.
82 // Int is valid in combination with both quantized and floating point.
83 // This occurs when doing qi16 convolution, as bias is passed as a
84 // non-quantized int64
85 if (allTypesFp || allTypesQuantizedOrInt) return failure();
86
87 Location loc = op->getLoc();
88 SmallVector<Value> newOperands;
89 newOperands.reserve(op->getNumOperands());
90 for (auto operand : op->getOperands()) {
91 if (QuantizedType::getQuantizedElementType(operand.getType())) {
92 auto newTy = QuantizedType::castToExpressedType(operand.getType());
93 newOperands.push_back(
94 rewriter.create<TFL::DequantizeOp>(loc, newTy, operand));
95 continue;
96 }
97
98 newOperands.push_back(operand);
99 }
100
101 SmallVector<Type> newResultTys;
102 for (auto result : op->getResults()) {
103 Type resultTy = result.getType();
104 if (QuantizedType::getQuantizedElementType(resultTy)) {
105 resultTy = QuantizedType::castToExpressedType(resultTy);
106 }
107 newResultTys.push_back(resultTy);
108 }
109
110 auto newResults = rewriter
111 .create<SrcOp>(loc, newResultTys, newOperands,
112 op->getAttrDictionary().getValue())
113 .getOperation()
114 ->getResults();
115
116 SmallVector<Value> replaceResults;
117 for (int i = 0; i < newResults.size(); i++) {
118 Value result = newResults[i];
119 Type resultTy = op->getOpResult(i).getType();
120 if (QuantizedType::getQuantizedElementType(resultTy)) {
121 replaceResults.push_back(rewriter.create<TFL::QuantizeOp>(
122 loc, resultTy, result, TypeAttr::get(resultTy)));
123 continue;
124 }
125
126 replaceResults.push_back(result);
127 }
128
129 rewriter.replaceOp(op, replaceResults);
130
131 return success();
132 }
133 };
134
runOnOperation()135 void DecomposeHybridQuantizationPass::runOnOperation() {
136 RewritePatternSet patterns(&getContext());
137 auto *ctx = &getContext();
138 auto func = getOperation();
139 patterns.add<DequantizeConverter<TFL::Conv2DOp>,
140 DequantizeConverter<TFL::Conv3DOp>,
141 DequantizeConverter<TFL::DepthwiseConv2DOp>,
142 DequantizeConverter<TFL::FullyConnectedOp>,
143 DequantizeConverter<TFL::TransposeConvOp>>(ctx);
144 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
145 }
146
147 } // namespace
148
// Factory used by the TFLite pass pipeline to instantiate this pass.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateDecomposeHybridQuantizationPass() {
  return std::make_unique<DecomposeHybridQuantizationPass>();
}
153
154 } // namespace TFL
155 } // namespace mlir
156