1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This pass strips quantized types from a function. Every quant::QuantizedType,
// whether used directly or as the element type of a tensor, is replaced by its
// integer storage type: signless for signed storage or storage widths other
// than 8 bits, and unsigned for 8-bit unsigned storage. Function signatures,
// operation operands/results and region arguments are all rewritten. The pass
// fails if any quantized type remains afterwards; ops from the
// quantization-fork dialect are treated as illegal.
25
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31
32 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
33 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
34 #include "mlir/IR/Builders.h" // from @llvm-project
35 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
37 #include "mlir/IR/PatternMatch.h" // from @llvm-project
38 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
39 #include "mlir/Support/LogicalResult.h" // from @llvm-project
40 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
42 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
44 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
45
// Debug-type tag for this pass (used for LLVM_DEBUG / -debug-only filtering).
// It names this pass, not the uint8 conversion pass.
// NOTE(review): chosen to match the pass argument "tosa-strip-quant-types";
// confirm against the pass definition in passes.td.
#define PASS_NAME "tosa-strip-quant-types"
#define DEBUG_TYPE PASS_NAME
48
49 namespace mlir {
50 namespace tosa {
51 namespace {
52
53 class StripQuantTypes : public TosaStripQuantTypesPassBase<StripQuantTypes> {
54 public:
StripQuantTypes()55 explicit StripQuantTypes() {}
56 void runOnOperation() override;
57 };
58
59 class QuantTypeConverter : public TypeConverter {
60 public:
convertType(Type type)61 static Type convertType(Type type) {
62 if (auto qType = type.dyn_cast<quant::QuantizedType>()) {
63 if (qType.isSigned() || qType.getStorageTypeIntegralWidth() != 8) {
64 return IntegerType::get(type.getContext(),
65 qType.getStorageTypeIntegralWidth());
66 }
67
68 return IntegerType::get(type.getContext(),
69 qType.getStorageTypeIntegralWidth(),
70 IntegerType::SignednessSemantics::Unsigned);
71 }
72 return type;
73 }
convertTensor(RankedTensorType type)74 static Type convertTensor(RankedTensorType type) {
75 auto newType = RankedTensorType::get(type.getShape(),
76 convertType(type.getElementType()));
77 return newType;
78 }
QuantTypeConverter()79 explicit QuantTypeConverter() {
80 addConversion([](Type type) { return convertType(type); });
81 addConversion(convertTensor);
82 }
83 };
84
85 // Handles the type conversion component of the TypeConversion. This updates
86 // conversion patterns that used the original Quant types to be updated to
87 // the non-quant variants.
88 class GenericTypeConvert : public ConversionPattern {
89 public:
GenericTypeConvert(MLIRContext * context,TypeConverter & converter)90 GenericTypeConvert(MLIRContext* context, TypeConverter& converter)
91 : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const92 LogicalResult matchAndRewrite(
93 Operation* op, ArrayRef<Value> operands,
94 ConversionPatternRewriter& rewriter) const override {
95 llvm::SmallVector<Type, 4> newResults;
96 if (isa<func::FuncOp>(op)) {
97 return failure();
98 }
99
100 (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
101 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
102 newResults, op->getAttrs(), op->getSuccessors());
103 for (Region& r : op->getRegions()) {
104 Region* newRegion = state.addRegion();
105 rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
106 TypeConverter::SignatureConversion result(newRegion->getNumArguments());
107 (void)getTypeConverter()->convertSignatureArgs(
108 newRegion->getArgumentTypes(), result);
109 rewriter.applySignatureConversion(newRegion, result);
110 }
111 Operation* newOp = rewriter.create(state);
112 rewriter.replaceOp(op, newOp->getResults());
113 return success();
114 }
115 };
116
isIllegalType(Type type)117 static bool isIllegalType(Type type) {
118 if (type.isa<quant::QuantizedType>()) return true;
119 if (auto shapedType = type.dyn_cast<ShapedType>()) {
120 return isIllegalType(shapedType.getElementType());
121 }
122 return false;
123 }
124
runOnOperation()125 void StripQuantTypes::runOnOperation() {
126 QuantTypeConverter converter;
127 ConversionTarget target(getContext());
128
129 target.addIllegalDialect<quantfork::QuantizationForkDialect>();
130 // Operations are legal if they don't contain any illegal type.
131 target.markUnknownOpDynamicallyLegal([](Operation* op) {
132 if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
133 for (Type type : funcOp.getFunctionType().getInputs()) {
134 if (isIllegalType(type)) return false;
135 }
136 for (Type type : funcOp.getFunctionType().getResults()) {
137 if (isIllegalType(type)) return false;
138 }
139 }
140 for (Type type : op->getResultTypes()) {
141 if (type && isIllegalType(type)) return false;
142 }
143 for (Type type : op->getOperandTypes()) {
144 if (type && isIllegalType(type)) return false;
145 }
146 return true;
147 });
148
149 auto* ctx = &getContext();
150 auto func = getOperation();
151
152 RewritePatternSet patterns(&getContext());
153 patterns.add<GenericTypeConvert>(ctx, converter);
154 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
155 converter);
156
157 if (failed(applyFullConversion(func, target, std::move(patterns)))) {
158 signalPassFailure();
159 }
160 }
161
162 } // anonymous namespace
163
createStripQuantTypesPass()164 std::unique_ptr<OperationPass<func::FuncOp>> createStripQuantTypesPass() {
165 return std::make_unique<StripQuantTypes>();
166 }
167 } // namespace tosa
168 } // namespace mlir
169