/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"

#include <string>
#include <utility>

#include "absl/base/casts.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/generated_transform_patterns.inc"
}  // namespace

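// Returns the hardware-specific rewrite patterns registered for `hardware`
// (e.g. "GPU"); returns an empty pattern set if the hardware is unknown.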
RewritePatternSet GetHardwareRewritePatterns(MLIRContext* context,
                                             const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return {context};
  return device_hardware->GetTransformations(context);
}

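// Returns true if `op` is supported by the target `hardware`; returns false
// if the hardware is unknown or the op is unsupported on it.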
bool IsSupported(Operation* op, const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return false;
  return device_hardware->IsOpSupported(op);
}

// ================== Convert Quantized Op ============================

// Walks the function and converts quantized ops into their float versions.
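//
// For example (types and attributes illustrative only), a quantized binary op
// such as
//
//   %0 = "tfl.add"(%q0, %q1)
//       : (tensor<4x!quant.uniform<i8:f32, 1.0>>,
//          tensor<4x!quant.uniform<i8:f32, 1.0>>)
//       -> tensor<4x!quant.uniform<i8:f32, 1.0>>
//
// is rewritten so the computation happens in float, with dequantize ops
// inserted on the inputs and a quantize op on the output:
//
//   %d0 = "tfl.dequantize"(%q0) : (...) -> tensor<4xf32>
//   %d1 = "tfl.dequantize"(%q1) : (...) -> tensor<4xf32>
//   %f  = "tfl.add"(%d0, %d1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %0  = "tfl.quantize"(%f) {qtype = ...}
//       : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 1.0>>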
void ConvertQuantizedOpToFloat(mlir::func::FuncOp func, OpBuilder* builder) {
  func.walk([&](Operation* op) {
    // TODO(renjieliu): Find a generic way to deal with const ops.
    if (op->hasTrait<OpTrait::IsTerminator>() ||
        llvm::isa<TFL::QConstOp, TFL::ConstOp, TF::ConstOp, ConstOp>(op))
      return;

    bool int8_type_observed = false;
    bool uint8_type_observed = false;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type)) {
        int8_type_observed = true;
      } else if (IsQUI8Type(input_type)) {
        uint8_type_observed = true;
      }
    }

    // TODO(renjieliu): To be safe, we should probably check whether the op
    // supports float execution, although ops that are not quantized-only
    // normally do.
    if (!int8_type_observed && !uint8_type_observed) return;

    // Insert dequantize ops for every quantized input.
    SmallVector<Value, 4> dequantized_inputs;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type) || IsQUI8Type(input_type) ||
          IsQI32Type(input_type)) {
        auto dequantized_input_type =
            mlir::quant::QuantizedType::castToExpressedType(input_type);
        builder->setInsertionPoint(op);
        auto dequantize_op = builder->create<TFL::DequantizeOp>(
            op->getLoc(), dequantized_input_type, input.get());
        dequantized_inputs.push_back(dequantize_op);
      } else {
        dequantized_inputs.push_back(input.get());
      }
    }

    // Result types.
    SmallVector<Type, 4> result_types;
    for (auto result_type : op->getResultTypes()) {
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        auto dequantized_result_type =
            mlir::quant::QuantizedType::castToExpressedType(result_type);
        result_types.push_back(dequantized_result_type);
      } else {
        result_types.push_back(result_type);
      }
    }

    // Build the new float-versioned op.
    OperationState state(op->getLoc(), op->getName());
    state.operands = dequantized_inputs;
    state.types = result_types;
    state.attributes = op->getAttrs();
    state.successors = op->getSuccessors();
    builder->setInsertionPoint(op);
    Operation* new_op = builder->create(state);

    // Insert quantize ops for each output and rewire the uses.
    for (int i = 0; i < op->getNumResults(); ++i) {
      auto result = op->getResult(i);
      auto result_type = result.getType();

      Value new_result = new_op->getResult(i);
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        builder->setInsertionPoint(op);
        TFL::QuantizeOp quant_op = builder->create<TFL::QuantizeOp>(
            op->getLoc(), result_type, new_result, TypeAttr::get(result_type));
        new_result = quant_op.getResult();
      }

      // Rewire the outputs.
      result.replaceAllUsesWith(new_result);
    }

    // Remove the old op.
    op->erase();
  });
}

// Folds quantized i32 values (normally bias) into float constants.
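//
// For example (scale and values illustrative only), with scale = 0.5 and
// zero point = 0, the pair
//
//   %bias = "tfl.pseudo_qconst"() {...}
//       : () -> tensor<2x!quant.uniform<i32:f32, 0.5>>
//   %0 = "tfl.dequantize"(%bias) : (...) -> tensor<2xf32>
//
// folds into a single float constant whose values are
// (int_value - zero_point) * scale:
//
//   %0 = "tfl.pseudo_const"() {value = dense<...> : tensor<2xf32>}
//       : () -> tensor<2xf32>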
struct FoldQuantizedI32ToFloat : public OpRewritePattern<TFL::DequantizeOp> {
  using OpRewritePattern<TFL::DequantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::DequantizeOp dequant_op,
                                PatternRewriter& rewriter) const override {
    // We only fold the i32 -> float pattern.
    auto input = dequant_op.input().getDefiningOp();
    if (!input) return failure();

    auto input_dequant = llvm::dyn_cast_or_null<TFL::QConstOp>(input);
    if (!input_dequant) return failure();

    if (!IsQI32Type(input_dequant.getType())) return failure();

    auto output_type =
        dequant_op.output().getType().dyn_cast_or_null<ShapedType>();
    if (!output_type || !output_type.getElementType().isF32()) return failure();

    auto input_type = input_dequant.getType().dyn_cast<ShapedType>();
    // TODO(renjieliu): support UniformQuantizedPerAxisType.
    auto q_type = input_type.getElementType()
                      .dyn_cast_or_null<quant::UniformQuantizedType>();
    if (!q_type) return failure();

    const float scale = q_type.getScale();
    const float zp = q_type.getZeroPoint();

    auto input_values = input_dequant.value();

    // mapValues always takes a function returning APInt, even when the output
    // is actually float.
    using DequantizeFuncType = llvm::APInt(const llvm::APInt&);
    auto dequantize_func = [&](const APInt& ap_int_value) -> APInt {
      const int64_t int_value = ap_int_value.getSExtValue();

      const float real = (int_value - zp) * scale;

      // Bit-cast the dequantized float into a 32-bit APInt so mapValues
      // accepts it.
      auto real_int = absl::bit_cast<int32_t>(real);
      return APInt(/*numBits=*/32, real_int);
    };

    auto dequant_values =
        input_values.cast<DenseIntOrFPElementsAttr>().mapValues(
            FloatType::getF32(rewriter.getContext()),
            llvm::function_ref<DequantizeFuncType>(dequantize_func));
    rewriter.replaceOpWithNewOp<TFL::ConstOp>(dequant_op, dequant_op.getType(),
                                              dequant_values);

    return success();
  }
};

// If a quantize op has no consumers, remove it.
struct RemoveUnusedQuant : public OpRewritePattern<TFL::QuantizeOp> {
  using OpRewritePattern<TFL::QuantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::QuantizeOp quant_op,
                                PatternRewriter& rewriter) const override {
    if (!quant_op.getResult().use_empty()) return failure();

    rewriter.eraseOp(quant_op);
    return success();
  }
};

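// Applies the quantization cleanup patterns above, together with the
// generated FoldQuantizeDequantize pattern, greedily to a fixed point.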
void OptimizeQuantizedOpToFloat(func::FuncOp func, MLIRContext* context) {
  RewritePatternSet patterns(func.getContext());
  patterns
      .add<FoldQuantizedI32ToFloat, FoldQuantizeDequantize, RemoveUnusedQuant>(
          context);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace tac
}  // namespace TFL
}  // namespace mlir