1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 #include <memory>
18 #include <utility>
19 
20 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
21 #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
22 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 namespace mlir {
28 namespace gml_st {
29 namespace {
30 
31 // Uncollapse materialize operations with nested tile chains t1, t2, ..., tn. A
32 // materialize op of the form ...
33 //   `materialize(t1(t2(...(tn(sn)))), arg)`
34 // ... is expanded into ...
35 //   `materialize(t1(s1), materialize(t2(...(tn(sn))), arg))`.
36 struct UncollapseMaterializePattern : public OpRewritePattern<MaterializeOp> {
37   using OpRewritePattern<MaterializeOp>::OpRewritePattern;
matchAndRewritemlir::gml_st::__anonc7586d3a0111::UncollapseMaterializePattern38   LogicalResult matchAndRewrite(MaterializeOp op,
39                                 PatternRewriter &rewriter) const override {
40     // Find head of the tile chain.
41     auto tileDef = op.set().getDefiningOp<TileOp>();
42     if (!tileDef) return failure();
43 
44     // Find tail of the tile chain.
45     auto superTile = tileDef.superset();
46     auto superTileDef = superTile.getDefiningOp<TileOp>();
47     if (!superTileDef) return failure();
48 
49     // Create independent head tile and tail tile chain.
50     auto loc = op.getLoc();
51     auto newTileSpace = rewriter.create<SpaceOp>(loc, superTileDef.getType(),
52                                                  superTileDef.sizes(),
53                                                  superTileDef.static_sizes());
54     auto newTile = rewriter.create<TileOp>(
55         loc, newTileSpace, tileDef.offsets(), tileDef.sizes(),
56         tileDef.strides(), tileDef.static_offsets(), tileDef.static_sizes(),
57         tileDef.static_strides());
58     auto newInnerMaterialize =
59         rewriter.create<MaterializeOp>(loc, op.source(), superTile);
60 
61     // Create expanded materialize op.
62     rewriter.replaceOpWithNewOp<MaterializeOp>(op, newInnerMaterialize,
63                                                newTile);
64     return success();
65   }
66 };
67 
68 // Collapse materialize operations with nested tile chains t1, t2, ..., tn, and
69 // u1, u2, ..., un. A materialize op of the form ...
70 //   `materialize(t1(t2(...(tn(sn)))), materialize(u1(u2(...(un(sn')))), arg))`
71 // ... is collapsed as ...
72 //   `materialize(t1(t2(...(tn(u1(u2(...(un(sn'))))))), arg)`.
73 struct CollapseMaterializePattern : public OpRewritePattern<MaterializeOp> {
74   using OpRewritePattern<MaterializeOp>::OpRewritePattern;
75 
matchAndRewritemlir::gml_st::__anonc7586d3a0111::CollapseMaterializePattern76   LogicalResult matchAndRewrite(MaterializeOp op,
77                                 PatternRewriter &rewriter) const override {
78     // Find inner materialize op.
79     auto innerMaterialize = op.source().getDefiningOp<MaterializeOp>();
80     if (!innerMaterialize) return failure();
81 
82     // Find outer tile chain to replace its root space op.
83     llvm::SmallVector<TileOp> tileChain;
84     Operation *tileDef = op.set().getDefiningOp();
85     while (tileDef && !llvm::isa<SpaceOp>(tileDef)) {
86       auto tileOp = llvm::dyn_cast<TileOp>(tileDef);
87       if (!tileOp) return failure();
88       tileChain.push_back(tileOp);
89       tileDef = tileOp.superset().getDefiningOp();
90     }
91 
92     // Create new tile chain, starting with its tail.
93     auto loc = op.getLoc();
94     Value newTileChain = innerMaterialize.set();
95     while (!tileChain.empty()) {
96       TileOp tileOp = tileChain.pop_back_val();
97       newTileChain = rewriter.create<TileOp>(
98           loc, newTileChain, tileOp.offsets(), tileOp.sizes(), tileOp.strides(),
99           tileOp.static_offsets(), tileOp.static_sizes(),
100           tileOp.static_strides());
101     }
102 
103     // Create collapsed materialize op.
104     rewriter.replaceOpWithNewOp<MaterializeOp>(op, innerMaterialize.source(),
105                                                newTileChain);
106     return success();
107   }
108 };
109 
110 struct CollapseMaterializeOpsPass
111     : public CollapseMaterializeOpsPassBase<CollapseMaterializeOpsPass> {
CollapseMaterializeOpsPassmlir::gml_st::__anonc7586d3a0111::CollapseMaterializeOpsPass112   explicit CollapseMaterializeOpsPass(bool reverse)
113       : CollapseMaterializeOpsPassBase() {
114     reverse_ = reverse;
115   }
116 
getDependentDialectsmlir::gml_st::__anonc7586d3a0111::CollapseMaterializeOpsPass117   void getDependentDialects(DialectRegistry &registry) const final {}
118 
runOnOperationmlir::gml_st::__anonc7586d3a0111::CollapseMaterializeOpsPass119   void runOnOperation() final {
120     MLIRContext *ctx = &getContext();
121 
122     // Populate collapse or uncollapse pattern.
123     RewritePatternSet patterns(ctx);
124     if (reverse_) {
125       patterns.add<UncollapseMaterializePattern>(ctx);
126     } else {
127       patterns.add<CollapseMaterializePattern>(ctx);
128     }
129 
130     if (failed(applyPatternsAndFoldGreedily(getOperation(),
131                                             std::move(patterns)))) {
132       return signalPassFailure();
133     }
134   }
135 };
136 
137 }  // namespace
138 
createCollapseMaterializeOpsPass(bool reverse)139 std::unique_ptr<OperationPass<func::FuncOp>> createCollapseMaterializeOpsPass(
140     bool reverse) {
141   return std::make_unique<CollapseMaterializeOpsPass>(reverse);
142 }
143 
144 }  // namespace gml_st
145 }  // namespace mlir
146