/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

namespace mlir {
namespace TFL {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
// Dequantize ops produce 3x larger tensors, so we want to move them after
// some passthrough ops to reduce memory consumption.
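// Illustrative sketch (hypothetical shapes and quantized types, not taken
// from a real model): pushing the dequantize below a shrinking passthrough
// op such as tfl.slice rewrites
//   %f = "tfl.dequantize"(%q)
//       : (tensor<1x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<1x128xf32>
//   %s = "tfl.slice"(%f, %begin, %size)
//       : (tensor<1x128xf32>, ...) -> tensor<1x16xf32>
// into
//   %s = "tfl.slice"(%q, %begin, %size)
//       : (tensor<1x128x!quant.uniform<i8:f32, 1.0>>, ...)
//       -> tensor<1x16x!quant.uniform<i8:f32, 1.0>>
//   %f = "tfl.dequantize"(%s)
//       : (tensor<1x16x!quant.uniform<i8:f32, 1.0>>) -> tensor<1x16xf32>
// so the float tensor that gets materialized is 1x16 instead of 1x128.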
struct PushDownDequantize : public OpRewritePattern<DequantizeOp> {
  explicit PushDownDequantize(MLIRContext* context)
      : OpRewritePattern<DequantizeOp>(context) {}

  LogicalResult matchAndRewrite(DequantizeOp dequantize_op,
                                PatternRewriter& rewriter) const override {
    if (!dequantize_op->hasOneUse()) return failure();

    auto use = dequantize_op->use_begin();
    Operation* passthrough_op = use->getOwner();
    unsigned operand_index = use->getOperandNumber();
    if (passthrough_op->hasTrait<OpTrait::IsTerminator>()) return failure();

    auto get_num_elements = [](RankedTensorType tensor) {
      int num_elements = 1;
      for (int i = 0; i < tensor.getRank(); ++i) {
        // Treat a dynamic dimension as if it had size one, i.e. skip it.
        if (!tensor.isDynamicDim(i)) {
          num_elements *= tensor.getDimSize(i);
        }
      }
      return num_elements;
    };
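    // For example (hypothetical type, for illustration only): a
    // tensor<2x?x3xf32> yields 2 * 3 = 6, since the dynamic dimension is
    // skipped.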

    // The op must be a passthrough (same-scales) op with a single result;
    // the element-count check below ensures its output is actually smaller
    // than its input, so the dequantize op can be pushed down to that result.
    if (!llvm::dyn_cast<mlir::SameScalesOpInterface>(passthrough_op) ||
        passthrough_op->getNumResults() != 1) {
      return failure();
    }
    // Only push down the dequantize op when the output is smaller, so that
    // memory usage is reduced.
    auto input_type =
        dequantize_op.output().getType().dyn_cast<RankedTensorType>();
    auto output_type =
        passthrough_op->getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!input_type || !output_type ||
        get_num_elements(input_type) <= get_num_elements(output_type)) {
      return failure();
    }
    Type input_element_type = getElementTypeOrSelf(dequantize_op.input());
    // Most passthrough ops do not support F16.
    if (input_element_type.isF16()) {
      return failure();
    }

    // Set the output type of the dequantize op and push it down past the
    // passthrough op.
    dequantize_op.output().setType(output_type);
    passthrough_op->replaceAllUsesWith(dequantize_op);

    // Compute the new (quantized) output type of the passthrough op so it
    // can be pulled up above the dequantize op.
    Type new_output_type;
    if (input_element_type.isa<quant::QuantizedType>()) {
      new_output_type = QuantizedType::getQuantizedElementType(
                            dequantize_op.input().getType())
                            .castFromExpressedType(output_type);
    } else {
      llvm_unreachable("unhandled element type");
    }
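    // At this point new_output_type is the quantized counterpart of
    // output_type; e.g. (illustrative types only) a tensor<1x16xf32> result
    // becomes tensor<1x16x!quant.uniform<i8:f32, 1.0>>.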

    passthrough_op->getResult(0).setType(new_output_type);
    passthrough_op->setOperand(operand_index, dequantize_op.input());

    // Rewire the dequantize op to consume the passthrough op's result, then
    // move it after the passthrough op to finalize the swap.
    dequantize_op->setOperand(0, passthrough_op->getResult(0));
    dequantize_op->moveAfter(passthrough_op);
    return success();
  }
};

struct OptimizeOpOrderPass
    : public OptimizeOpOrderPassBase<OptimizeOpOrderPass> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeOpOrderPass)

  void runOnOperation() override;
};
void OptimizeOpOrderPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();
  auto* ctx = func.getContext();
  patterns.add<PushDownDequantize>(ctx);
  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
    signalPassFailure();
  }
}
}  // namespace

// Creates an instance of the TensorFlow Lite optimize op order pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreateOptimizeOpOrderPass() {
  return std::make_unique<OptimizeOpOrderPass>();
}

static PassRegistration<OptimizeOpOrderPass> pass;
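// With the registration above, the pass can be exercised from the command
// line; assuming the pass flag declared for it in passes.td, an invocation
// might look like:
//   tf-opt -tfl-optimize-op-order input.mlir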

}  // namespace TFL
}  // namespace mlir