1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iterator>
17 #include <memory>
18 #include <utility>
19
20 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
21 #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
22 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27 namespace mlir {
28 namespace gml_st {
29 namespace {
30
31 // Uncollapse materialize operations with nested tile chains t1, t2, ..., tn. A
32 // materialize op of the form ...
33 // `materialize(t1(t2(...(tn(sn)))), arg)`
34 // ... is expanded into ...
35 // `materialize(t1(s1), materialize(t2(...(tn(sn))), arg))`.
36 struct UncollapseMaterializePattern : public OpRewritePattern<MaterializeOp> {
37 using OpRewritePattern<MaterializeOp>::OpRewritePattern;
matchAndRewritemlir::gml_st::__anonc7586d3a0111::UncollapseMaterializePattern38 LogicalResult matchAndRewrite(MaterializeOp op,
39 PatternRewriter &rewriter) const override {
40 // Find head of the tile chain.
41 auto tileDef = op.set().getDefiningOp<TileOp>();
42 if (!tileDef) return failure();
43
44 // Find tail of the tile chain.
45 auto superTile = tileDef.superset();
46 auto superTileDef = superTile.getDefiningOp<TileOp>();
47 if (!superTileDef) return failure();
48
49 // Create independent head tile and tail tile chain.
50 auto loc = op.getLoc();
51 auto newTileSpace = rewriter.create<SpaceOp>(loc, superTileDef.getType(),
52 superTileDef.sizes(),
53 superTileDef.static_sizes());
54 auto newTile = rewriter.create<TileOp>(
55 loc, newTileSpace, tileDef.offsets(), tileDef.sizes(),
56 tileDef.strides(), tileDef.static_offsets(), tileDef.static_sizes(),
57 tileDef.static_strides());
58 auto newInnerMaterialize =
59 rewriter.create<MaterializeOp>(loc, op.source(), superTile);
60
61 // Create expanded materialize op.
62 rewriter.replaceOpWithNewOp<MaterializeOp>(op, newInnerMaterialize,
63 newTile);
64 return success();
65 }
66 };
67
68 // Collapse materialize operations with nested tile chains t1, t2, ..., tn, and
69 // u1, u2, ..., un. A materialize op of the form ...
70 // `materialize(t1(t2(...(tn(sn)))), materialize(u1(u2(...(un(sn')))), arg))`
71 // ... is collapsed as ...
72 // `materialize(t1(t2(...(tn(u1(u2(...(un(sn'))))))), arg)`.
73 struct CollapseMaterializePattern : public OpRewritePattern<MaterializeOp> {
74 using OpRewritePattern<MaterializeOp>::OpRewritePattern;
75
matchAndRewritemlir::gml_st::__anonc7586d3a0111::CollapseMaterializePattern76 LogicalResult matchAndRewrite(MaterializeOp op,
77 PatternRewriter &rewriter) const override {
78 // Find inner materialize op.
79 auto innerMaterialize = op.source().getDefiningOp<MaterializeOp>();
80 if (!innerMaterialize) return failure();
81
82 // Find outer tile chain to replace its root space op.
83 llvm::SmallVector<TileOp> tileChain;
84 Operation *tileDef = op.set().getDefiningOp();
85 while (tileDef && !llvm::isa<SpaceOp>(tileDef)) {
86 auto tileOp = llvm::dyn_cast<TileOp>(tileDef);
87 if (!tileOp) return failure();
88 tileChain.push_back(tileOp);
89 tileDef = tileOp.superset().getDefiningOp();
90 }
91
92 // Create new tile chain, starting with its tail.
93 auto loc = op.getLoc();
94 Value newTileChain = innerMaterialize.set();
95 while (!tileChain.empty()) {
96 TileOp tileOp = tileChain.pop_back_val();
97 newTileChain = rewriter.create<TileOp>(
98 loc, newTileChain, tileOp.offsets(), tileOp.sizes(), tileOp.strides(),
99 tileOp.static_offsets(), tileOp.static_sizes(),
100 tileOp.static_strides());
101 }
102
103 // Create collapsed materialize op.
104 rewriter.replaceOpWithNewOp<MaterializeOp>(op, innerMaterialize.source(),
105 newTileChain);
106 return success();
107 }
108 };
109
110 struct CollapseMaterializeOpsPass
111 : public CollapseMaterializeOpsPassBase<CollapseMaterializeOpsPass> {
CollapseMaterializeOpsPassmlir::gml_st::__anonc7586d3a0111::CollapseMaterializeOpsPass112 explicit CollapseMaterializeOpsPass(bool reverse)
113 : CollapseMaterializeOpsPassBase() {
114 reverse_ = reverse;
115 }
116
getDependentDialectsmlir::gml_st::__anonc7586d3a0111::CollapseMaterializeOpsPass117 void getDependentDialects(DialectRegistry ®istry) const final {}
118
runOnOperationmlir::gml_st::__anonc7586d3a0111::CollapseMaterializeOpsPass119 void runOnOperation() final {
120 MLIRContext *ctx = &getContext();
121
122 // Populate collapse or uncollapse pattern.
123 RewritePatternSet patterns(ctx);
124 if (reverse_) {
125 patterns.add<UncollapseMaterializePattern>(ctx);
126 } else {
127 patterns.add<CollapseMaterializePattern>(ctx);
128 }
129
130 if (failed(applyPatternsAndFoldGreedily(getOperation(),
131 std::move(patterns)))) {
132 return signalPassFailure();
133 }
134 }
135 };
136
137 } // namespace
138
createCollapseMaterializeOpsPass(bool reverse)139 std::unique_ptr<OperationPass<func::FuncOp>> createCollapseMaterializeOpsPass(
140 bool reverse) {
141 return std::make_unique<CollapseMaterializeOpsPass>(reverse);
142 }
143
144 } // namespace gml_st
145 } // namespace mlir
146