/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.
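//
// The pass proceeds in three steps:
//   1. Collect the set of "result buffers": function arguments, operands of
//      the return, and anything aliasing them through casts, view-like ops,
//      or region-yielding ops.
//   2. Tile every linalg.generic that writes to one of those buffers.
//   3. Greedily fuse producers into the resulting tiled loop nests.
//
// For example (an illustrative sketch with details elided), in a function
// such as
//
//   func.func @fusion(%in: memref<10xf32>, %out: memref<10xf32>) {
//     %tmp = memref.alloc() : memref<10xf32>
//     linalg.generic ... ins(%in : ...) outs(%tmp : ...)   // producer
//     linalg.generic ... ins(%tmp : ...) outs(%out : ...)  // root consumer
//     ...
//   }
//
// the root consumer is tiled and the producer is fused into the tiled loops.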

#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/lhlo/transforms/passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace lmhlo {
namespace {

using linalg::LinalgOp;

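// Tiles linalg ops rooted at result buffers and fuses their producers into
// the tiled loops. The tile_sizes_ and use_parallel_loops_ options are
// declared on LhloFuseLinalgPassBase (from PassDetail.h).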
class LhloFuseLinalgPass : public LhloFuseLinalgPassBase<LhloFuseLinalgPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
  }

 public:
  LhloFuseLinalgPass() = default;
  LhloFuseLinalgPass(const LhloFuseLinalgPass&) = default;
  LhloFuseLinalgPass(bool useParallelLoops,
                     llvm::ArrayRef<unsigned> tileSizes) {
    tile_sizes_ = tileSizes;
    use_parallel_loops_.setValue(useParallelLoops);
  }

  void runOnOperation() override {
    auto func = getOperation();

    // TODO(pifon): Remove assumption that the function has a single block.
    if (!llvm::hasSingleElement(func)) {
      emitError(func.getLoc(), "The function needs to have a single block.");
      signalPassFailure();
      return;
    }

    // Fusion in Linalg is currently only possible when the consumer op is
    // tiled. To fuse greedily, we start from the tiled root linalg ops, i.e.
    // linalg ops that write to the output buffers of the function or that are
    // returned in case of escaping allocations.
    llvm::SmallDenseSet<Value> resultBuffers;
    for (auto funcArg : func.getArguments()) {
      resultBuffers.insert(funcArg);
    }
    for (auto& block : func) {
      auto returnOp =
          mlir::dyn_cast<mlir::func::ReturnOp>(block.getTerminator());
      if (!returnOp) continue;
      for (auto operand : returnOp.getOperands()) {
        resultBuffers.insert(operand);
      }
    }
    // Resolve aliasing operations (like casts) on the results to identify the
    // underlying buffers. This only handles escaping results.
    // TODO(herhut): Use BufferizeAliasAnalysis for this.
    llvm::SmallVector<Value, 4> worklist(resultBuffers.begin(),
                                         resultBuffers.end());
    while (!worklist.empty()) {
      Value result = worklist.pop_back_val();
      auto* definingOp = result.getDefiningOp();
      if (!definingOp) {
        continue;
      }

      if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
        auto alias = viewLike.getViewSource();
        if (resultBuffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto toTensor = dyn_cast<bufferization::ToTensorOp>(definingOp)) {
        auto alias = toTensor.getMemref();
        if (resultBuffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto toMemref = dyn_cast<bufferization::ToMemrefOp>(definingOp)) {
        auto alias = toMemref.getTensor();
        if (resultBuffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensorCast = dyn_cast<tensor::CastOp>(definingOp)) {
        auto alias = tensorCast.getSource();
        if (resultBuffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

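      // For ops with regions that can yield values to the parent (e.g.
      // scf.if), chase each result back to the corresponding operand of every
      // region terminator, so buffers yielded from inside a region are also
      // treated as results.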
      if (auto regionInterface =
              dyn_cast<RegionBranchOpInterface>(definingOp)) {
        for (Region& region : regionInterface.getOperation()->getRegions()) {
          // Only consider regions that can return to the parent region.
          SmallVector<RegionSuccessor, 2> successorRegions;
          regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                              successorRegions);
          if (llvm::none_of(successorRegions, [&](auto successorRegion) {
                return successorRegion.isParent();
              }))
            continue;

          // Iterate over all immediate terminators and record the values
          // corresponding to the resultBuffers of interest.
          for (Block& block : region) {
            if (block.empty()) continue;
            Operation& operation = block.back();
            if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
            auto idx = result.dyn_cast<OpResult>().getResultNumber();
            if (resultBuffers.insert(operation.getOperand(idx)).second) {
              worklist.push_back(operation.getOperand(idx));
            }
          }
        }
      }
    }

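    // Tile every tileable root op. Tiling rewrites the op into a loop nest
    // (scf.for or scf.parallel, depending on use_parallel_loops_) over
    // subviews of its operands, which is what enables producer fusion below.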
    MLIRContext* ctx = func.getContext();
    OpBuilder b(func);
    func.walk([&](linalg::GenericOp genericOp) {
      SmallVector<int64_t, 2> tileSizes(tile_sizes_.begin(), tile_sizes_.end());
      if (tileSizes.empty()) {
        tileSizes = SmallVector<int64_t, 2>(genericOp.getNumLoops(), 1);
      }
      auto op = cast<LinalgOp>(genericOp.getOperation());
      for (OpOperand* opOperand : op.getOutputBufferOperands()) {
        if (!resultBuffers.count(opOperand->get())) continue;
        if (tileGenericOp(op, tileSizes, &b)) {
          genericOp.erase();
          return;
        }
      }
    });
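    // Clean up the tiled code with the canonicalization patterns that are
    // meant to follow linalg tiling.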
    auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();

    // Fuse producers of tiled linalg ops.
    llvm::SmallDenseSet<Operation*> eraseSet;
    SmallVector<LinalgOp, 8> linalgOps;
    func.walk([&](LinalgOp op) { linalgOps.push_back(op); });
    for (LinalgOp op : llvm::reverse(linalgOps)) {
      for (OpOperand* inputOperand : op.getInputOperands()) {
        // Rebuild the dependence graph for every operand: a successful fusion
        // replaces an entry of linalgOps below, invalidating prior analyses.
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
        auto info = fuseProducerOfBuffer(b, *inputOperand, graph);
        if (failed(info)) continue;
        auto* originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        auto* originalOpInLinalgOpsVector =
            std::find_if(linalgOps.begin(), linalgOps.end(),
                         [&](const Operation* op) { return op == originalOp; });
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }

      auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
      if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
        return signalPassFailure();
    }
    for (auto* e : eraseSet) e->erase();
  }

 private:
  bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tileSizes, OpBuilder* b) {
    auto loopType = use_parallel_loops_
                        ? linalg::LinalgTilingLoopType::ParallelLoops
                        : linalg::LinalgTilingLoopType::Loops;
    IRRewriter rewriter(*b);
    return succeeded(linalg::tileLinalgOp(
        rewriter, op,
        linalg::LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
            loopType)));
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createLhloFuseLinalgPass(
    bool useParallelLoops, ArrayRef<unsigned> tileSizes) {
  return std::make_unique<LhloFuseLinalgPass>(useParallelLoops, tileSizes);
}
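
// Illustrative usage (a sketch, not part of this file); assumes a standard
// pass pipeline over a module containing lmhlo-lowered functions:
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::lmhlo::createLhloFuseLinalgPass(/*useParallelLoops=*/true,
//                                             /*tileSizes=*/{}));
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }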

}  // namespace lmhlo
}  // namespace mlir