xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This pass strips quantized types from a function, replacing every
// quant::QuantizedType with a plain integer type of the same storage width.
// This pass does:
// 1. convert each quantized element type to an integer type of its storage
// width; unsigned 8-bit storage becomes an unsigned integer, everything else
// uses the default (signless) integer type.
// 2. rebuild ranked tensor types around the converted element type.
// 3. recreate every operation, region signature and function signature that
// still mentions a quantized type so that it uses the converted types.
25 
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31 
32 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
33 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
34 #include "mlir/IR/Builders.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
37 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
38 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
39 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
40 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
42 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
44 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
45 
// Name used as the LLVM debug type for this file (e.g. with -debug-only=).
// It names the quant-type stripping pass implemented here rather than the
// unrelated uint8 conversion pass.
// NOTE(review): presumably identical to the pass argument registered in
// passes.td -- confirm.
#define PASS_NAME "tosa-strip-quant-types"
#define DEBUG_TYPE PASS_NAME
48 
49 namespace mlir {
50 namespace tosa {
51 namespace {
52 
53 class StripQuantTypes : public TosaStripQuantTypesPassBase<StripQuantTypes> {
54  public:
StripQuantTypes()55   explicit StripQuantTypes() {}
56   void runOnOperation() override;
57 };
58 
59 class QuantTypeConverter : public TypeConverter {
60  public:
convertType(Type type)61   static Type convertType(Type type) {
62     if (auto qType = type.dyn_cast<quant::QuantizedType>()) {
63       if (qType.isSigned() || qType.getStorageTypeIntegralWidth() != 8) {
64         return IntegerType::get(type.getContext(),
65                                 qType.getStorageTypeIntegralWidth());
66       }
67 
68       return IntegerType::get(type.getContext(),
69                               qType.getStorageTypeIntegralWidth(),
70                               IntegerType::SignednessSemantics::Unsigned);
71     }
72     return type;
73   }
convertTensor(RankedTensorType type)74   static Type convertTensor(RankedTensorType type) {
75     auto newType = RankedTensorType::get(type.getShape(),
76                                          convertType(type.getElementType()));
77     return newType;
78   }
QuantTypeConverter()79   explicit QuantTypeConverter() {
80     addConversion([](Type type) { return convertType(type); });
81     addConversion(convertTensor);
82   }
83 };
84 
85 // Handles the type conversion component of the TypeConversion. This updates
86 // conversion patterns that used the original Quant types to be updated to
87 // the non-quant variants.
88 class GenericTypeConvert : public ConversionPattern {
89  public:
GenericTypeConvert(MLIRContext * context,TypeConverter & converter)90   GenericTypeConvert(MLIRContext* context, TypeConverter& converter)
91       : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const92   LogicalResult matchAndRewrite(
93       Operation* op, ArrayRef<Value> operands,
94       ConversionPatternRewriter& rewriter) const override {
95     llvm::SmallVector<Type, 4> newResults;
96     if (isa<func::FuncOp>(op)) {
97       return failure();
98     }
99 
100     (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
101     OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
102                          newResults, op->getAttrs(), op->getSuccessors());
103     for (Region& r : op->getRegions()) {
104       Region* newRegion = state.addRegion();
105       rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
106       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
107       (void)getTypeConverter()->convertSignatureArgs(
108           newRegion->getArgumentTypes(), result);
109       rewriter.applySignatureConversion(newRegion, result);
110     }
111     Operation* newOp = rewriter.create(state);
112     rewriter.replaceOp(op, newOp->getResults());
113     return success();
114   }
115 };
116 
isIllegalType(Type type)117 static bool isIllegalType(Type type) {
118   if (type.isa<quant::QuantizedType>()) return true;
119   if (auto shapedType = type.dyn_cast<ShapedType>()) {
120     return isIllegalType(shapedType.getElementType());
121   }
122   return false;
123 }
124 
runOnOperation()125 void StripQuantTypes::runOnOperation() {
126   QuantTypeConverter converter;
127   ConversionTarget target(getContext());
128 
129   target.addIllegalDialect<quantfork::QuantizationForkDialect>();
130   // Operations are legal if they don't contain any illegal type.
131   target.markUnknownOpDynamicallyLegal([](Operation* op) {
132     if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
133       for (Type type : funcOp.getFunctionType().getInputs()) {
134         if (isIllegalType(type)) return false;
135       }
136       for (Type type : funcOp.getFunctionType().getResults()) {
137         if (isIllegalType(type)) return false;
138       }
139     }
140     for (Type type : op->getResultTypes()) {
141       if (type && isIllegalType(type)) return false;
142     }
143     for (Type type : op->getOperandTypes()) {
144       if (type && isIllegalType(type)) return false;
145     }
146     return true;
147   });
148 
149   auto* ctx = &getContext();
150   auto func = getOperation();
151 
152   RewritePatternSet patterns(&getContext());
153   patterns.add<GenericTypeConvert>(ctx, converter);
154   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
155                                                                  converter);
156 
157   if (failed(applyFullConversion(func, target, std::move(patterns)))) {
158     signalPassFailure();
159   }
160 }
161 
162 }  // anonymous namespace
163 
createStripQuantTypesPass()164 std::unique_ptr<OperationPass<func::FuncOp>> createStripQuantTypesPass() {
165   return std::make_unique<StripQuantTypes>();
166 }
167 }  // namespace tosa
168 }  // namespace mlir
169