/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"

#include <string>
#include <utility>

#include "absl/base/casts.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/generated_transform_patterns.inc"
} // namespace

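// Returns the rewrite patterns registered for `hardware`, or an empty pattern
// set if no such hardware is registered.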
RewritePatternSet GetHardwareRewritePatterns(MLIRContext* context,
                                             const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return {context};
  return device_hardware->GetTransformations(context);
}

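// Returns whether `op` is supported on the given `hardware`; returns false
// when the hardware is unknown.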
bool IsSupported(Operation* op, const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return false;
  return device_hardware->IsOpSupported(op);
}

// ================== Convert Quantized Op ============================

// Walks through the func and converts quantized ops to their float versions.
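// Schematically (op name illustrative, types abbreviated), an op like
//   %r = "tfl.some_op"(%in) : (qi8) -> qi8
// becomes
//   %in_f = "tfl.dequantize"(%in) : (qi8) -> f32
//   %r_f = "tfl.some_op"(%in_f) : (f32) -> f32
//   %r = "tfl.quantize"(%r_f) : (f32) -> qi8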
void ConvertQuantizedOpToFloat(mlir::func::FuncOp func, OpBuilder* builder) {
  func.walk([&](Operation* op) {
    // TODO(renjieliu): Find a generic way to deal with const ops.
    if (op->hasTrait<OpTrait::IsTerminator>() ||
        llvm::isa<TFL::QConstOp, TFL::ConstOp, TF::ConstOp, ConstOp>(op))
      return;

    bool int8_type_observed = false;
    bool uint8_type_observed = false;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type)) {
        int8_type_observed = true;
      } else if (IsQUI8Type(input_type)) {
        uint8_type_observed = true;
      }
    }

    // TODO(renjieliu): To be safe we should also check that the op supports
    // float execution; most ops do, but quantized-only ops may not.
    if (!int8_type_observed && !uint8_type_observed) return;

    // Insert dequantize ops for every quantized input.
    SmallVector<Value, 4> dequantized_inputs;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type) || IsQUI8Type(input_type) ||
          IsQI32Type(input_type)) {
        auto dequantized_input_type =
            mlir::quant::QuantizedType::castToExpressedType(input_type);
        builder->setInsertionPoint(op);
        auto dequantize_op = builder->create<TFL::DequantizeOp>(
            op->getLoc(), dequantized_input_type, input.get());
        dequantized_inputs.push_back(dequantize_op);
      } else {
        dequantized_inputs.push_back(input.get());
      }
    }

    // Result types.
    SmallVector<Type, 4> result_types;
    for (auto result_type : op->getResultTypes()) {
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        auto dequantized_result_type =
            mlir::quant::QuantizedType::castToExpressedType(result_type);
        result_types.push_back(dequantized_result_type);
      } else {
        result_types.push_back(result_type);
      }
    }

    // Build the new float-versioned op.
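    // OperationState rebuilds the op generically: it keeps the original name,
    // attributes, and successors but takes the dequantized operands and float
    // result types.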
    OperationState state(op->getLoc(), op->getName());
    state.operands = dequantized_inputs;
    state.types = result_types;
    state.attributes = op->getAttrs();
    state.successors = op->getSuccessors();
    builder->setInsertionPoint(op);
    Operation* new_op = builder->create(state);

    // Insert quantize ops for every output and rewire the uses.
    for (int i = 0; i < op->getNumResults(); ++i) {
      auto result = op->getResult(i);
      auto result_type = result.getType();

      Value new_result = new_op->getResult(i);
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        builder->setInsertionPoint(op);
        TFL::QuantizeOp quant_op = builder->create<TFL::QuantizeOp>(
            op->getLoc(), result_type, new_result, TypeAttr::get(result_type));
        new_result = quant_op.getResult();
      }

      // Rewire the outputs.
      result.replaceAllUsesWith(new_result);
    }

    // Remove the old op.
    op->erase();
  });
}

// Fold quantized i32 constants (normally biases) into their float values.
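// Each element is dequantized as (value - zero_point) * scale and
// materialized as a float constant.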
struct FoldQuantizedI32ToFloat : public OpRewritePattern<TFL::DequantizeOp> {
  using OpRewritePattern<TFL::DequantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::DequantizeOp dequant_op,
                                PatternRewriter& rewriter) const override {
    // We only fold the i32 -> float pattern.
    auto input = dequant_op.input().getDefiningOp();
    if (!input) return failure();

    auto input_const = llvm::dyn_cast_or_null<TFL::QConstOp>(input);
    if (!input_const) return failure();

    if (!IsQI32Type(input_const.getType())) return failure();

    auto output_type =
        dequant_op.output().getType().dyn_cast_or_null<ShapedType>();
    if (!output_type || !output_type.getElementType().isF32()) return failure();

    auto input_type = input_const.getType().dyn_cast<ShapedType>();
    // TODO(renjieliu): support UniformQuantizedPerAxisType.
    auto q_type = input_type.getElementType()
                      .dyn_cast_or_null<quant::UniformQuantizedType>();
    if (!q_type) return failure();

    const float scale = q_type.getScale();
    const float zp = q_type.getZeroPoint();

    auto input_values = input_const.value();

    // mapValues always takes a function returning APInt, even when the output
    // is actually float.
    using DequantizeFuncType = llvm::APInt(const llvm::APInt&);
    auto dequantize_func = [&](const APInt& ap_int_value) -> APInt {
      const int64_t int_value = ap_int_value.getSExtValue();

      const float real = (int_value - zp) * scale;

      auto real_int = absl::bit_cast<int32_t>(real);
      return APInt(/*numBits=*/32, real_int);
    };

    auto dequant_values =
        input_values.cast<DenseIntOrFPElementsAttr>().mapValues(
            FloatType::getF32(rewriter.getContext()),
            llvm::function_ref<DequantizeFuncType>(dequantize_func));
    rewriter.replaceOpWithNewOp<TFL::ConstOp>(dequant_op, dequant_op.getType(),
                                              dequant_values);

    return success();
  }
};

// If a quantize op has no consumers, remove it.
struct RemoveUnusedQuant : public OpRewritePattern<TFL::QuantizeOp> {
  using OpRewritePattern<TFL::QuantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::QuantizeOp quant_op,
                                PatternRewriter& rewriter) const override {
    if (!quant_op.getResult().use_empty()) return failure();

    rewriter.eraseOp(quant_op);
    return success();
  }
};

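// Runs the quantization cleanup patterns (FoldQuantizedI32ToFloat, the
// generated FoldQuantizeDequantize, and RemoveUnusedQuant) greedily on `func`.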
void OptimizeQuantizedOpToFloat(func::FuncOp func, MLIRContext* context) {
  RewritePatternSet patterns(func.getContext());
  patterns
      .add<FoldQuantizedI32ToFloat, FoldQuantizeDequantize, RemoveUnusedQuant>(
          context);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace tac
} // namespace TFL
} // namespace mlir