/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <iterator>
#include <memory>

#include "llvm/Support/raw_ostream.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
#include "mlir/Dialect/Affine/LoopUtils.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"

//===----------------------------------------------------------------------===//
// Canonicalization patterns for the scf.for and scf.if ops. They are used to
// optimize the control flow in the TFR functions. Ideally, both patterns
// should be upstreamed to be part of the op definitions.
// TODO(fengliuai): sync with the LLVM upstream for both patterns.
//
namespace mlir {
namespace TFR {

namespace {

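// Fully unrolls an scf.for loop whose lower bound, upper bound and step are
// compile-time constants, by cloning the loop body once per iteration and
// threading the iter_args through the clones. Illustrative sketch (op and
// value names are made up for the example; the actual IR depends on the
// surrounding function):
//
//   %r = scf.for %iv = %c0 to %c2 step %c1
//       iter_args(%acc = %init) -> (i32) {
//     %next = "some.op"(%acc, %iv) : (i32, index) -> i32
//     scf.yield %next : i32
//   }
//
// becomes two copies of the body, with %iv replaced by lower_bound + i * step
// for i = 0, 1, and with %r replaced by the value yielded by the last copy.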
class UnrollSCFForOp : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

 public:
  LogicalResult matchAndRewrite(scf::ForOp for_op,
                                PatternRewriter &rewriter) const override {
    Location loc = for_op.getLoc();
    APInt lower_bound, upper_bound, step;
    if (!matchPattern(for_op.getLowerBound(), m_ConstantInt(&lower_bound)) ||
        !matchPattern(for_op.getUpperBound(), m_ConstantInt(&upper_bound)) ||
        !matchPattern(for_op.getStep(), m_ConstantInt(&step))) {
      return failure();
    }
    uint64_t trip_count = (upper_bound - lower_bound).sdiv(step).getZExtValue();
    // Nothing to unroll when the constant trip count is zero.
    if (trip_count == 0) return failure();

    // TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported

    Block *single_block = for_op.getBody();
    BlockAndValueMapping mapping;
    Value iv = for_op.getInductionVar();
    for (auto iter_op :
         llvm::zip(for_op.getRegionIterArgs(), for_op.getInitArgs())) {
      mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
    }
    mapping.map(iv, for_op.getLowerBound());
    for (uint64_t i = 0; i < trip_count; ++i) {
      if (!iv.use_empty()) {
        // iv' = lower_bound + step * i
        Value iter = rewriter.create<arith::ConstantIndexOp>(loc, i);
        Value step_cst =
            rewriter.create<arith::ConstantIndexOp>(loc, step.getSExtValue());
        Value stride = rewriter.create<arith::MulIOp>(loc, step_cst, iter);
        Value iv_unroll = rewriter.create<arith::AddIOp>(
            loc, for_op.getLowerBound(), stride);
        mapping.map(iv, iv_unroll);
      }

      // Clone the loop body; the last op cloned is the scf.yield terminator.
      Operation *terminator_op = nullptr;
      for (auto it = single_block->begin(); it != single_block->end(); ++it) {
        terminator_op = rewriter.clone(*it, mapping);
      }
      // Map the region iter args to the values yielded by this iteration so
      // the next unrolled copy sees the updated loop-carried values.
      for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(),
                                    terminator_op->getOperands())) {
        mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
      }
      rewriter.eraseOp(terminator_op);
    }
    SmallVector<Value, 4> returned;
    for (Value arg : for_op.getRegionIterArgs()) {
      returned.push_back(mapping.lookup(arg));
    }
    rewriter.replaceOp(for_op, returned);
    return success();
  }
};

// TODO(fengliuai): upstream this pattern.
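// Replaces an scf.if whose condition is a compile-time constant by the
// contents of the branch that is taken. Illustrative sketch (value names are
// made up for the example):
//
//   %r = scf.if %true -> (f32) {
//     scf.yield %a : f32
//   } else {
//     scf.yield %b : f32
//   }
//
// has its then-region inlined at the scf.if and %r replaced by %a. With a
// constant false condition the else-region is inlined instead, or the op is
// simply erased when it has no else-region (and therefore no results).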
class SimplifySCFIfOp : public OpRewritePattern<scf::IfOp> {
  using OpRewritePattern<scf::IfOp>::OpRewritePattern;

 public:
  LogicalResult matchAndRewrite(scf::IfOp if_op,
                                PatternRewriter &rewriter) const override {
    // The condition is a constant true: inline the then branch.
    if (matchPattern(if_op.getCondition(), m_NonZero())) {
      return InlineRegion(if_op.getLoc(), rewriter, if_op,
                          &if_op.getThenRegion());
    }

    // The condition is a constant false: inline the else branch, or erase the
    // op if there is no else branch (in that case it has no results either).
    if (matchPattern(if_op.getCondition(), m_Zero())) {
      if (if_op.getElseRegion().empty()) {
        // Remove the op.
        rewriter.eraseOp(if_op);
        return success();
      } else {
        return InlineRegion(if_op.getLoc(), rewriter, if_op,
                            &if_op.getElseRegion());
      }
    }

    // The condition is not a constant; nothing to simplify.
    return failure();
  }

 private:
  LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter,
                             Operation *inline_point, Region *region) const;
};

LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
                                            PatternRewriter &rewriter,
                                            Operation *inline_point,
                                            Region *region) const {
  InlinerInterface interface(loc.getContext());
  if (failed(inlineRegion(interface, region, inline_point, {},
                          inline_point->getResults(), loc,
                          /*shouldCloneInlinedRegion=*/true))) {
    return failure();
  }

  // If the inlining was successful then erase the scf.if op.
  rewriter.eraseOp(inline_point);
  return success();
}

}  // namespace

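// Typical driver for the entry point below (an illustrative sketch; the
// actual pass in the TFR pipeline may set this up differently):
//
//   RewritePatternSet patterns(func.getContext());
//   populateCanonicalizationPatterns(func, patterns);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));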
void populateCanonicalizationPatterns(func::FuncOp func,
                                      RewritePatternSet &patterns) {
  MLIRContext *context = func.getContext();
  mlir::Dialect *tf = context->getLoadedDialect<mlir::TF::TensorFlowDialect>();
  // Load all official canonicalization patterns. Here we skip the
  // canonicalization of the ops in the tf dialect, because they don't
  // propagate the attributes correctly. These optimizations are performed by
  // the bridge instead.
  func->walk([&](Operation *op) {
    if (op->getDialect() != tf) {
      // Unregistered ops have no canonicalization patterns to collect.
      if (auto registered = op->getRegisteredInfo()) {
        registered->getCanonicalizationPatterns(patterns, context);
      }
    }
  });
  patterns.add<UnrollSCFForOp, SimplifySCFIfOp>(context);
}

}  // namespace TFR
}  // namespace mlir