xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cstdint>
#include <iterator>
#include <memory>

#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
#include "mlir/Dialect/Affine/LoopUtils.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
38 
39 //===----------------------------------------------------------------------===//
40 // Canonicalization patterns for the scf.for and scf.if ops. They are used to
41 // optimize the control flow in the tfr function. Technically, both patterns
42 // should be upstreamed to be part of the op definition.
43 // TODO(fengliuai): sync with the llvm upstream for both patterns.
44 //
45 namespace mlir {
46 namespace TFR {
47 
48 namespace {
49 
50 class UnrollSCFForOp : public OpRewritePattern<scf::ForOp> {
51   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
52 
53  public:
matchAndRewrite(scf::ForOp for_op,PatternRewriter & rewriter) const54   LogicalResult matchAndRewrite(scf::ForOp for_op,
55                                 PatternRewriter &rewriter) const override {
56     Location loc = for_op.getLoc();
57     APInt lower_bound, upper_bound, step;
58     if (!matchPattern(for_op.getLowerBound(), m_ConstantInt(&lower_bound)) ||
59         !matchPattern(for_op.getUpperBound(), m_ConstantInt(&upper_bound)) ||
60         !matchPattern(for_op.getStep(), m_ConstantInt(&step))) {
61       return failure();
62     }
63     uint64_t trip_count = (upper_bound - lower_bound).sdiv(step).getZExtValue();
64     if (trip_count <= 0) return failure();
65 
66     // TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported
67 
68     Block *single_block = for_op.getBody();
69     BlockAndValueMapping mapping;
70     Value iv = for_op.getInductionVar();
71     for (auto iter_op :
72          llvm::zip(for_op.getRegionIterArgs(), for_op.getInitArgs())) {
73       mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
74     }
75     mapping.map(iv, for_op.getLowerBound());
76     for (auto i = 0; i < trip_count; ++i) {
77       if (!iv.use_empty()) {
78         // iv' = iv + step * i;
79         Value iter = rewriter.create<arith::ConstantIndexOp>(loc, i);
80         Value step_cst =
81             rewriter.create<arith::ConstantIndexOp>(loc, step.getSExtValue());
82         Value stride = rewriter.create<arith::MulIOp>(loc, step_cst, iter);
83         Value iv_unroll =
84             rewriter.create<arith::AddIOp>(loc, mapping.lookup(iv), stride);
85         mapping.map(iv, iv_unroll);
86       }
87 
88       Operation *terminator_op;
89       for (auto it = single_block->begin(); it != single_block->end(); ++it) {
90         terminator_op = rewriter.clone(*it, mapping);
91       }
92       // Map the block arguments to the yield results.
93       for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(),
94                                     terminator_op->getOperands())) {
95         mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
96       }
97       rewriter.eraseOp(terminator_op);
98     }
99     SmallVector<Value, 4> returned;
100     for (Value arg : for_op.getRegionIterArgs()) {
101       returned.push_back(mapping.lookup(arg));
102     }
103     rewriter.replaceOp(for_op, returned);
104     return success();
105   }
106 };
107 
108 // TODO(fengliuai): up stream this pattern.
109 class SimplifySCFIfOp : public OpRewritePattern<scf::IfOp> {
110   using OpRewritePattern<scf::IfOp>::OpRewritePattern;
111 
112  public:
matchAndRewrite(scf::IfOp if_op,PatternRewriter & rewriter) const113   LogicalResult matchAndRewrite(scf::IfOp if_op,
114                                 PatternRewriter &rewriter) const override {
115     // Then branch
116     if (matchPattern(if_op.getCondition(), m_NonZero())) {
117       return InlineRegion(if_op.getLoc(), rewriter, if_op,
118                           &if_op.getThenRegion());
119     }
120 
121     // Else branch
122     if (matchPattern(if_op.getCondition(), m_Zero())) {
123       if (if_op.getElseRegion().empty()) {
124         // Remove the op
125         rewriter.eraseOp(if_op);
126         return success();
127       } else {
128         return InlineRegion(if_op.getLoc(), rewriter, if_op,
129                             &if_op.getElseRegion());
130       }
131     }
132 
133     // Not a constant condition
134     return failure();
135   }
136 
137  private:
138   LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter,
139                              Operation *inline_point, Region *region) const;
140 };
141 
InlineRegion(Location loc,PatternRewriter & rewriter,Operation * inline_point,Region * region) const142 LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
143                                             PatternRewriter &rewriter,
144                                             Operation *inline_point,
145                                             Region *region) const {
146   InlinerInterface interface(loc.getContext());
147   if (failed(inlineRegion(interface, region, inline_point, {},
148                           inline_point->getResults(), loc,
149                           /*shouldCloneInlinedRegion=*/true))) {
150     return failure();
151   }
152 
153   // If the inlining was successful then erase the scf.if op.
154   rewriter.eraseOp(inline_point);
155   return success();
156 }
157 
158 }  // namespace
159 
populateCanonicalizationPatterns(func::FuncOp func,RewritePatternSet & patterns)160 void populateCanonicalizationPatterns(func::FuncOp func,
161                                       RewritePatternSet &patterns) {
162   MLIRContext *context = func.getContext();
163   mlir::Dialect *tf = context->getLoadedDialect<mlir::TF::TensorFlowDialect>();
164   // Load all official canonicalization patterns. Here we skip the
165   // canonicalization of the ops in the tf dialect, because they couldn't
166   // propagate the attributes correctly. These optimization will be played by
167   // bridge.
168   func->walk([&](Operation *op) {
169     if (op->getDialect() != tf) {
170       op->getRegisteredInfo()->getCanonicalizationPatterns(patterns, context);
171     }
172   });
173   patterns.add<UnrollSCFForOp, SimplifySCFIfOp>(context);
174 }
175 
176 }  // namespace TFR
177 }  // namespace mlir
178