1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for fusing linalg ops obtained after LHLO
17 // lowering.
18
19 #include <utility>
20
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
24 #include "mlir-hlo/Dialect/lhlo/transforms/passes.h"
25 #include "mlir/Dialect/Affine/IR/AffineOps.h"
26 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"
28 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
29 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
30 #include "mlir/Dialect/MemRef/IR/MemRef.h"
31 #include "mlir/Dialect/SCF/IR/SCF.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/Interfaces/ViewLikeInterface.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
36
37 namespace mlir {
38 namespace lmhlo {
39 namespace {
40
41 using linalg::LinalgOp;
42
43 class LhloFuseLinalgPass : public LhloFuseLinalgPassBase<LhloFuseLinalgPass> {
getDependentDialects(DialectRegistry & registry) const44 void getDependentDialects(DialectRegistry& registry) const override {
45 registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
46 }
47
48 public:
49 LhloFuseLinalgPass() = default;
50 LhloFuseLinalgPass(const LhloFuseLinalgPass&) = default;
LhloFuseLinalgPass(bool useParallelLoops,llvm::ArrayRef<unsigned> tileSizes)51 LhloFuseLinalgPass(bool useParallelLoops,
52 llvm::ArrayRef<unsigned> tileSizes) {
53 tile_sizes_ = tileSizes;
54 use_parallel_loops_.setValue(useParallelLoops);
55 }
56
runOnOperation()57 void runOnOperation() override {
58 auto func = getOperation();
59
60 // TODO(pifon): Remove assumption that the function has a single block.
61 if (!llvm::hasSingleElement(func)) {
62 emitError(func.getLoc(), "The function needs to have a single block.");
63 signalPassFailure();
64 return;
65 }
66
67 // The fusion in Linalg is currently possible only when the consumer op is
68 // tiled. In order to greedily fuse the ops, we have to start from the tiled
69 // root linalg ops, i.e. linalg ops that write to output buffers of the
70 // function or are returned in case of escaping allocations.
71 llvm::SmallDenseSet<Value> resultBuffers;
72 for (auto funcArg : func.getArguments()) {
73 resultBuffers.insert(funcArg);
74 }
75 for (auto& block : func) {
76 auto returnOp =
77 mlir::dyn_cast<mlir::func::ReturnOp>(block.getTerminator());
78 if (!returnOp) continue;
79 for (auto operand : returnOp.getOperands()) {
80 resultBuffers.insert(operand);
81 }
82 }
83 // Resolve aliasing operations (like casts) on the result to identify
84 // results. This only handles escaping results.
85 // TODO(herhut): Use BufferizeAliasAnalysis for this.
86 llvm::SmallVector<Value, 4> worklist(resultBuffers.begin(),
87 resultBuffers.end());
88 while (!worklist.empty()) {
89 Value result = worklist.pop_back_val();
90 auto* definingOp = result.getDefiningOp();
91 if (!definingOp) {
92 continue;
93 }
94
95 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
96 auto alias = viewLike.getViewSource();
97 if (resultBuffers.insert(alias).second) {
98 worklist.push_back(alias);
99 }
100 continue;
101 }
102
103 if (auto toTensor = dyn_cast<bufferization::ToTensorOp>(definingOp)) {
104 auto alias = toTensor.getMemref();
105 if (resultBuffers.insert(alias).second) {
106 worklist.push_back(alias);
107 }
108 continue;
109 }
110
111 if (auto toMemref = dyn_cast<bufferization::ToMemrefOp>(definingOp)) {
112 auto alias = toMemref.getTensor();
113 if (resultBuffers.insert(alias).second) {
114 worklist.push_back(alias);
115 }
116 continue;
117 }
118
119 if (auto tensorCast = dyn_cast<tensor::CastOp>(definingOp)) {
120 auto alias = tensorCast.getSource();
121 if (resultBuffers.insert(alias).second) {
122 worklist.push_back(alias);
123 }
124 continue;
125 }
126
127 if (auto regionInterface =
128 dyn_cast<RegionBranchOpInterface>(definingOp)) {
129 for (Region& region : regionInterface.getOperation()->getRegions()) {
130 // Only consider regions that can return to the parent region.
131 SmallVector<RegionSuccessor, 2> successorRegions;
132 regionInterface.getSuccessorRegions(region.getRegionNumber(),
133 successorRegions);
134 if (llvm::none_of(successorRegions, [&](auto successorRegion) {
135 return successorRegion.isParent();
136 }))
137 continue;
138
139 // Iterate over all immediate terminators and record the values
140 // corresponding to result_buffers of interest.
141 for (Block& block : region) {
142 if (block.empty()) continue;
143 Operation& operation = block.back();
144 if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
145 auto idx = result.dyn_cast<OpResult>().getResultNumber();
146 if (resultBuffers.insert(operation.getOperand(idx)).second) {
147 worklist.push_back(operation.getOperand(idx));
148 }
149 }
150 }
151 }
152 }
153
154 MLIRContext* ctx = func.getContext();
155 OpBuilder b(func);
156 func.walk([&](linalg::GenericOp genericOp) {
157 SmallVector<int64_t, 2> tileSizes(tile_sizes_.begin(), tile_sizes_.end());
158 if (tileSizes.empty()) {
159 tileSizes = SmallVector<int64_t, 2>(genericOp.getNumLoops(), 1);
160 }
161 auto op = cast<LinalgOp>(genericOp.getOperation());
162 for (OpOperand* opOperand : op.getOutputBufferOperands()) {
163 if (!resultBuffers.count(opOperand->get())) continue;
164 if (tileGenericOp(op, tileSizes, &b)) {
165 genericOp.erase();
166 return;
167 }
168 }
169 });
170 auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
171 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
172 return signalPassFailure();
173
174 // Fuse producers of tiled linalg ops.
175 llvm::SmallDenseSet<Operation*> eraseSet;
176 SmallVector<LinalgOp, 8> linalgOps;
177 func.walk([&](LinalgOp op) { linalgOps.push_back(op); });
178 for (LinalgOp op : llvm::reverse(linalgOps)) {
179 for (OpOperand* inputOperand : op.getInputOperands()) {
180 linalg::Aliases aliases;
181 linalg::LinalgDependenceGraph graph(aliases, linalgOps);
182 auto info = fuseProducerOfBuffer(b, *inputOperand, graph);
183 if (failed(info)) continue;
184 auto* originalOp = info->originalProducer.getOperation();
185 eraseSet.insert(originalOp);
186 auto* originalOpInLinalgOpsVector =
187 std::find_if(linalgOps.begin(), linalgOps.end(),
188 [&](const Operation* op) { return op == originalOp; });
189 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
190 }
191
192 auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
193 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
194 return signalPassFailure();
195 }
196 for (auto* e : eraseSet) e->erase();
197 }
198
199 private:
tileGenericOp(LinalgOp op,ArrayRef<int64_t> tileSizes,OpBuilder * b)200 bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tileSizes, OpBuilder* b) {
201 auto loopType = use_parallel_loops_
202 ? linalg::LinalgTilingLoopType::ParallelLoops
203 : linalg::LinalgTilingLoopType::Loops;
204 IRRewriter rewriter(*b);
205 return succeeded(linalg::tileLinalgOp(
206 rewriter, op,
207 linalg::LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
208 loopType)));
209 }
210 };
211
212 } // namespace
213
createLhloFuseLinalgPass(bool useParallelLoops,ArrayRef<unsigned> tileSizes)214 std::unique_ptr<OperationPass<func::FuncOp>> createLhloFuseLinalgPass(
215 bool useParallelLoops, ArrayRef<unsigned> tileSizes) {
216 return std::make_unique<LhloFuseLinalgPass>(useParallelLoops, tileSizes);
217 }
218
219 } // namespace lmhlo
220 } // namespace mlir
221