1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements the logic for converting `scf.parallel` loops into
17 // tiled loops.
18
19 #include <cstdint>
20 #include <tuple>
21 #include <utility>
22
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/iterator_range.h"
26 #include "mlir-hlo/Transforms/PassDetail.h"
27 #include "mlir-hlo/Transforms/passes.h"
28 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"
30 #include "mlir/Dialect/MemRef/IR/MemRef.h"
31 #include "mlir/Dialect/SCF/IR/SCF.h"
32 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
33 #include "mlir/Dialect/SCF/Utils/Utils.h"
34 #include "mlir/IR/OperationSupport.h"
35 #include "mlir/IR/Value.h"
36 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37
38 namespace mlir {
39
40 using ::mlir::scf::ParallelOp;
41
42 namespace {
43
44 // This is the implementation of the TileLoops pass declared in
45 // include/mlir-hlo/Transforms/passes.td
46 class TileLoopsPass : public TileLoopsPassBase<TileLoopsPass> {
47 public:
48 // Creates a TileLoopsPass with tiles sizes provided through `tile_sizes`
49 // and unroll factors provided through `unroll_factors`.
TileLoopsPass(ArrayRef<int64_t> tileSizes,ArrayRef<int64_t> unrollFactors)50 explicit TileLoopsPass(ArrayRef<int64_t> tileSizes,
51 ArrayRef<int64_t> unrollFactors) {
52 tile_sizes_ = tileSizes;
53 unroll_factors_ = unrollFactors;
54 }
55
56 void runOnOperation() override;
57 };
58
59 } // namespace
60
61 // Returns whether the access pattern in `ploop` is "complex". That is, whether
62 // any memref.load op in its region uses indices that don't correspond to the
63 // loop induction variables.
isComplexAccessPattern(ParallelOp ploop)64 static bool isComplexAccessPattern(ParallelOp ploop) {
65 auto isComplex = [&](memref::LoadOp loadOp) {
66 if (!loadOp.getMemRefType().getLayout().isIdentity()) return true;
67 if (loadOp.getIndices().empty()) return false;
68 return loadOp.getIndices() != ploop.getInductionVars();
69 };
70 return llvm::any_of(ploop.getBody()->getOps<memref::LoadOp>(), isComplex);
71 }
72
runOnOperation()73 void TileLoopsPass::runOnOperation() {
74 SmallVector<int64_t> unrolledTile;
75 if (tile_sizes_.size() == unroll_factors_.size()) {
76 unrolledTile.reserve(tile_sizes_.size());
77 for (int64_t i = 0; i < static_cast<int64_t>(tile_sizes_.size()); ++i)
78 unrolledTile.push_back(tile_sizes_[i] * unroll_factors_[i]);
79 }
80
81 SmallVector<ParallelOp, 2> ploops;
82 getInnermostParallelLoops(this->getOperation().getOperation(), ploops);
83 for (ParallelOp ploop : ploops) {
84 // Do not unroll if the tiling and unrolling have different rank, or if
85 // the access pattern is complex.
86 if (unrolledTile.empty() || isComplexAccessPattern(ploop)) {
87 tileParallelLoop(ploop, tile_sizes_, /*noMinMaxBounds=*/false);
88 continue;
89 }
90
91 // Collect lower/upper bounds and step size, if they are constants.
92 auto getConstDefOps = [](OperandRange operands) {
93 return llvm::to_vector(llvm::map_range(operands, [&](Value value) {
94 return value.getDefiningOp<arith::ConstantIndexOp>();
95 }));
96 };
97 auto lower = getConstDefOps(ploop.getLowerBound());
98 auto upper = getConstDefOps(ploop.getUpperBound());
99 auto step = getConstDefOps(ploop.getStep());
100
101 bool noMinMaxBounds = false;
102 ploop = tileParallelLoop(ploop, unrolledTile, noMinMaxBounds).second;
103 ploop = tileParallelLoop(ploop, unroll_factors_, noMinMaxBounds).second;
104
105 // Use static upper bound on unrolled loop if possible. That is, if the
106 // unroll factor evenly divides the iteration size of the outer ploop.
107 OpBuilder builder(ploop);
108 Location loc = ploop.getLoc();
109 for (int64_t i = 0; i < static_cast<int64_t>(unrolledTile.size()); ++i) {
110 if (!lower[i] || !upper[i] || !step[i]) continue;
111 int64_t unrollFactor = unroll_factors_[i];
112 int64_t difference = upper[i].value() - lower[i].value();
113 if (difference % (step[i].value() * unrollFactor) != 0) continue;
114 ploop.getUpperBoundMutable().slice(i, 1).assign(
115 builder.create<arith::ConstantIndexOp>(loc, unrollFactor));
116 }
117 }
118
119 // Apply arithmetic dialect canonicalizations so that
120 // ParallelToGpuLaunchLowering can derive loop-invariant upper bound for
121 // number of iterations.
122 RewritePatternSet patterns(&getContext());
123 getContext()
124 .getOrLoadDialect<arith::ArithmeticDialect>()
125 ->getCanonicalizationPatterns(patterns);
126 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
127 return signalPassFailure();
128 }
129
createTileLoopsPass(ArrayRef<int64_t> tileSizes,ArrayRef<int64_t> unrollFactors)130 std::unique_ptr<OperationPass<func::FuncOp>> createTileLoopsPass(
131 ArrayRef<int64_t> tileSizes, ArrayRef<int64_t> unrollFactors) {
132 return std::make_unique<TileLoopsPass>(tileSizes, unrollFactors);
133 }
134
135 } // namespace mlir
136