/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass decomposes dense operations that assume
// support for hybrid quantization. It handles the case where a dense
// operation (e.g. matmul) has both quantized and unquantized inputs by
// dequantizing the quantized inputs, performing the operation in the
// expressed type, and then requantizing if a quantized output is required.
//
// The motivation behind these changes is to support dialects that assume
// only float or quantized computation and do not allow a mixture of these
// types on dense operations. Decomposition allows TFLite to be compiled to
// such dialects, e.g. TOSA.
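//
// As an illustrative sketch (schematic IR, not exact TFLite syntax), a
// convolution with a float input, a quantized filter, and a float result:
//
//   %r = "tfl.conv_2d"(%input_f32, %filter_quant, %bias_f32)
//
// is rewritten to dequantize the filter and compute entirely in float:
//
//   %filter_f32 = "tfl.dequantize"(%filter_quant)
//   %r = "tfl.conv_2d"(%input_f32, %filter_f32, %bias_f32)
//
// If a result type is quantized, a "tfl.quantize" is inserted afterwards.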

#include <utility>

#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

namespace mlir {
namespace TFL {

namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

class DecomposeHybridQuantizationPass
    : public DecomposeHybridQuantizationPassBase<
          DecomposeHybridQuantizationPass> {
 public:
  explicit DecomposeHybridQuantizationPass() {}
  void runOnOperation() override;
};

template <typename SrcOp>
class DequantizeConverter : public OpRewritePattern<SrcOp> {
 public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp srcop,
                                PatternRewriter &rewriter) const final {
    Operation *op = srcop.getOperation();
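    // Scan operand and result element types to determine whether this op
    // mixes quantized tensors with floating-point tensors.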
    bool allTypesFp = true;
    bool allTypesQuantizedOrInt = true;
    for (auto operand : op->getOperands()) {
      ShapedType type = operand.getType().template dyn_cast<ShapedType>();
      if (!type) continue;
      allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
      allTypesQuantizedOrInt &=
          (type.getElementType().isa<quant::QuantizedType>() ||
           type.getElementType().isa<IntegerType>());
    }

    for (auto result : op->getResults()) {
      ShapedType type = result.getType().template cast<ShapedType>();
      allTypesFp &= !type.getElementType().isa<quant::QuantizedType>();
      allTypesQuantizedOrInt &=
          (type.getElementType().isa<quant::QuantizedType>() ||
           type.getElementType().isa<IntegerType>());
    }

    // If all types are floating point, or all are quantized or integer, the
    // types are already consistent. Integer is valid in combination with both
    // quantized and floating-point types; this occurs when doing qi16
    // convolution, as the bias is passed as a non-quantized int64.
    if (allTypesFp || allTypesQuantizedOrInt) return failure();

    Location loc = op->getLoc();
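    // Dequantize any quantized operands so the op can be computed in the
    // expressed (floating-point) type.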
    SmallVector<Value> newOperands;
    newOperands.reserve(op->getNumOperands());
    for (auto operand : op->getOperands()) {
      if (QuantizedType::getQuantizedElementType(operand.getType())) {
        auto newTy = QuantizedType::castToExpressedType(operand.getType());
        newOperands.push_back(
            rewriter.create<TFL::DequantizeOp>(loc, newTy, operand));
        continue;
      }

      newOperands.push_back(operand);
    }

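    // Compute the new result types, replacing quantized element types with
    // their expressed (floating-point) equivalents.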
    SmallVector<Type> newResultTys;
    for (auto result : op->getResults()) {
      Type resultTy = result.getType();
      if (QuantizedType::getQuantizedElementType(resultTy)) {
        resultTy = QuantizedType::castToExpressedType(resultTy);
      }
      newResultTys.push_back(resultTy);
    }

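    // Recreate the op with the float operand and result types, preserving the
    // original attributes.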
    auto newResults = rewriter
                          .create<SrcOp>(loc, newResultTys, newOperands,
                                         op->getAttrDictionary().getValue())
                          .getOperation()
                          ->getResults();

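    // Requantize any results whose original type was quantized; pass other
    // results through unchanged.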
    SmallVector<Value> replaceResults;
    for (int i = 0; i < newResults.size(); i++) {
      Value result = newResults[i];
      Type resultTy = op->getOpResult(i).getType();
      if (QuantizedType::getQuantizedElementType(resultTy)) {
        replaceResults.push_back(rewriter.create<TFL::QuantizeOp>(
            loc, resultTy, result, TypeAttr::get(resultTy)));
        continue;
      }

      replaceResults.push_back(result);
    }

    rewriter.replaceOp(op, replaceResults);

    return success();
  }
};

void DecomposeHybridQuantizationPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto *ctx = &getContext();
  auto func = getOperation();
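  // Register the decomposition pattern for each dense op that may appear in a
  // hybrid-quantized form.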
  patterns.add<DequantizeConverter<TFL::Conv2DOp>,
               DequantizeConverter<TFL::Conv3DOp>,
               DequantizeConverter<TFL::DepthwiseConv2DOp>,
               DequantizeConverter<TFL::FullyConnectedOp>,
               DequantizeConverter<TFL::TransposeConvOp>>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateDecomposeHybridQuantizationPass() {
  return std::make_unique<DecomposeHybridQuantizationPass>();
}

}  // namespace TFL
}  // namespace mlir