1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14
15 ==============================================================================*/
16
17 #include <memory>
18 #include <utility>
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"
28 #include "mlir/Dialect/Shape/IR/Shape.h"
29 #include "mlir/Dialect/Tensor/IR/Tensor.h"
30 #include "mlir/IR/BlockAndValueMapping.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/MLIRContext.h"
34 #include "mlir/IR/Operation.h"
35 #include "mlir/IR/OperationSupport.h"
36 #include "mlir/IR/PatternMatch.h"
37 #include "mlir/Interfaces/InferTypeOpInterface.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
40
41 namespace mlir {
42 namespace mhlo {
43 namespace {
44
45 struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
ShapeReificationPatternmlir::mhlo::__anon84680ab10111::ShapeReificationPattern46 explicit ShapeReificationPattern(MLIRContext *context)
47 : OpRewritePattern<shape::ShapeOfOp>(context) {
48 // Recursively reify until we hit an op that doesn't support it.
49 setHasBoundedRewriteRecursion();
50 }
51
matchAndRewritemlir::mhlo::__anon84680ab10111::ShapeReificationPattern52 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
53 PatternRewriter &rewriter) const override {
54 // Only reify shape computation if operand allows for it.
55 auto shapeOrigin = op.getArg().getDefiningOp<InferShapedTypeOpInterface>();
56 if (!shapeOrigin) return failure();
57
58 llvm::SmallVector<Value, 1> reifications;
59 if (failed(shapeOrigin.reifyReturnTypeShapes(
60 rewriter, shapeOrigin->getOperands(), reifications)))
61 return failure();
62 assert(reifications.size() == 1);
63 Value reifiedShape = reifications.front();
64
65 // Insert cast if needed.
66 if (reifiedShape.getType() != op.getType()) {
67 reifiedShape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
68 reifiedShape);
69 }
70
71 rewriter.replaceOp(op, reifiedShape);
72 return success();
73 }
74 };
75
76 template <typename OpTy>
77 struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
78 using OpRewritePattern<OpTy>::OpRewritePattern;
79
matchAndRewritemlir::mhlo::__anon84680ab10111::InlineBroadcastedShapeOperandsPattern80 LogicalResult matchAndRewrite(OpTy op,
81 PatternRewriter &rewriter) const override {
82 // Find all the shape operands, direct and indirect.
83 SmallVector<Value, 8> inlinedOperands;
84 for (Value direct : op->getOperands()) {
85 if (auto bcastOp = direct.getDefiningOp<shape::BroadcastOp>()) {
86 for (Value indirect : bcastOp->getOperands())
87 inlinedOperands.push_back(indirect);
88 } else {
89 inlinedOperands.push_back(direct);
90 }
91 }
92
93 // Only rewrite if it makes a difference.
94 if (inlinedOperands.size() == op.getNumOperands()) return failure();
95
96 // Inline shape operands.
97 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), inlinedOperands,
98 op->getAttrs());
99 return success();
100 }
101 };
102
// Moves `op` from after a preceding `shape.assuming` op into that op's body.
// The assuming op is recreated so that it additionally yields the moved op's
// result, and `op`'s uses are redirected to that new yielded value. Only
// single-result ops are handled. Fails if there is no preceding assuming op in
// the block, or if one of `op`'s operands would not dominate the new position.
LogicalResult moveUpIntoAssumingOpMatchAndRewrite(Operation *op,
                                                  PatternRewriter &rewriter) {
  // Only implemented for single-result ops.
  if (op->getNumResults() != 1) return failure();

  // Find a preceding `assuming` op.
  auto *theBlock = op->getBlock();
  Operation *prev = op->getPrevNode();
  while (prev != nullptr && !llvm::isa<shape::AssumingOp>(prev))
    prev = prev->getPrevNode();
  auto assumingOp = llvm::dyn_cast_or_null<shape::AssumingOp>(prev);
  if (!assumingOp) return failure();
  assert(assumingOp->getBlock() == theBlock && op->getBlock() == theBlock &&
         "expect assuming op and root op to be in the same block");

  // Make sure that all operands will be available after moving. A value is
  // available if it is defined in another block, or defined in this block
  // before the assuming op; values defined between the assuming op and `op`
  // would not dominate the new location inside the region.
  auto isAvailable = [&](Value v) {
    Operation *def = v.getDefiningOp();
    return def == nullptr || def->getBlock() != theBlock ||
           !assumingOp->isBeforeInBlock(def);
  };
  if (!llvm::all_of(op->getOperands(), isAvailable)) return failure();

  Block *body = assumingOp.getBody();
  auto yieldOp = llvm::cast<shape::AssumingYieldOp>(body->getTerminator());

  // Find the operands to use if the op was within the assuming region. We
  // will later use their copies, as we copy the assuming op and its body.
  // Operands that are results of the assuming op are swapped for the values
  // yielded from inside its body.
  SmallVector<Value, 8> newOperandsUnmapped =
      llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
        for (const auto &result : llvm::enumerate(assumingOp->getResults())) {
          if (result.value() == v) return yieldOp->getOperand(result.index());
        }
        return v;
      }));

  // Insert the rewritten assuming op right before the old one.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(assumingOp);
  auto newAssumingOp = rewriter.create<shape::AssumingOp>(
      assumingOp.getLoc(), assumingOp.getWitness(),
      [&](OpBuilder &b, Location) {
        // Copy body.
        BlockAndValueMapping mapping;
        for (auto &nested : body->without_terminator())
          b.clone(nested, mapping);

        // Copy op into the new body and use the mapped operands.
        for (auto it : llvm::zip(op->getOperands(), newOperandsUnmapped)) {
          Value oldOperand, newOperandUnmapped;
          std::tie(oldOperand, newOperandUnmapped) = it;
          mapping.map(oldOperand, mapping.lookupOrDefault(newOperandUnmapped));
        }
        Operation *newOp = b.clone(*op, mapping);

        // Yield the previous results and also the new ones.
        auto mappedResults = llvm::to_vector<8>(llvm::map_range(
            yieldOp.getOperands(),
            [&](Value v) { return mapping.lookupOrDefault(v); }));
        mappedResults.append(newOp->getResults().begin(),
                             newOp->getResults().end());
        return mappedResults;
      });

  // Replace the assuming op and the root op with the corresponding result
  // values. The moved op's result is the last yielded value by construction.
  ValueRange newAssumingOpResults = newAssumingOp->getResults();
  rewriter.replaceOp(assumingOp, newAssumingOpResults.drop_back());
  rewriter.replaceOp(op, newAssumingOpResults.back());
  return success();
}
174
175 /// Move operation into a preceding assuming op. This allows to process
176 /// operations that depend on the assuming op's results. It will eventually
177 /// allow to make assuming regions' constraints independent from each other.
178 template <typename OpTy>
179 struct MoveUpIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
180 using OpRewritePattern<OpTy>::OpRewritePattern;
181
matchAndRewritemlir::mhlo::__anon84680ab10111::MoveUpIntoAssumingOpPattern182 LogicalResult matchAndRewrite(OpTy op,
183 PatternRewriter &rewriter) const override {
184 return moveUpIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
185 }
186 };
187
188 // Move elementwise operations into a preceding assuming op. This will
189 // eventually allow for more fusion opportunities.
190 struct MoveElementwiseOpsUpIntoAssumingOpPattern : public RewritePattern {
MoveElementwiseOpsUpIntoAssumingOpPatternmlir::mhlo::__anon84680ab10111::MoveElementwiseOpsUpIntoAssumingOpPattern191 explicit MoveElementwiseOpsUpIntoAssumingOpPattern(MLIRContext *ctx)
192 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
193
matchAndRewritemlir::mhlo::__anon84680ab10111::MoveElementwiseOpsUpIntoAssumingOpPattern194 LogicalResult matchAndRewrite(Operation *op,
195 PatternRewriter &rewriter) const override {
196 // Apply to all elementwise and broadcasting elementwise operations with no
197 // side effects.
198 if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
199 !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>()) {
200 return failure();
201 }
202 if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure();
203
204 return moveUpIntoAssumingOpMatchAndRewrite(op, rewriter);
205 }
206 };
207
208 // Move operation into an assuming region if all uses are within its body.
moveDownIntoAssumingOpMatchAndRewrite(Operation * op,PatternRewriter & rewriter)209 LogicalResult moveDownIntoAssumingOpMatchAndRewrite(Operation *op,
210 PatternRewriter &rewriter) {
211 auto users = op->getUsers();
212 auto it = users.begin();
213 auto end = users.end();
214 if (it == end) return failure();
215
216 // Find candidate assuming op.
217 auto assumingOp = (it++)->getParentOfType<shape::AssumingOp>();
218 if (!assumingOp || assumingOp->isProperAncestor(op)) return failure();
219
220 // Make sure all uses are within the unique assuming op's body.
221 while (it != end) {
222 auto hopefullySameAssumingOp = (it++)->getParentOfType<shape::AssumingOp>();
223 if (!hopefullySameAssumingOp || hopefullySameAssumingOp != assumingOp) {
224 return failure();
225 }
226 }
227
228 // Move op into the assuming region.
229 OpBuilder::InsertionGuard guard(rewriter);
230 rewriter.setInsertionPointToStart(assumingOp.getBody());
231 Operation *newOp = rewriter.clone(*op);
232 rewriter.replaceOp(op, newOp->getResults());
233 return success();
234 }
235
236 // Move elementwise operations into succeeding assuming regions. This will
237 // eventually allow for more fusion opportunities.
238 struct MoveElementwiseOpsDownIntoAssumingOpPattern : public RewritePattern {
MoveElementwiseOpsDownIntoAssumingOpPatternmlir::mhlo::__anon84680ab10111::MoveElementwiseOpsDownIntoAssumingOpPattern239 explicit MoveElementwiseOpsDownIntoAssumingOpPattern(MLIRContext *ctx)
240 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
241
matchAndRewritemlir::mhlo::__anon84680ab10111::MoveElementwiseOpsDownIntoAssumingOpPattern242 LogicalResult matchAndRewrite(Operation *op,
243 PatternRewriter &rewriter) const override {
244 // Apply to all elementwise and broadcasting elementwise operations with no
245 // side effects.
246 if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
247 !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>()) {
248 return failure();
249 }
250 if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure();
251
252 return moveDownIntoAssumingOpMatchAndRewrite(op, rewriter);
253 }
254 };
255
/// Move operation out of assuming op. This is only valid for
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
/// will eventually allow to make assuming regions' constraints independent from
/// each other.
template <typename OpTy>
struct MoveUpOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Must be inside of an assuming op.
    auto assumingOp = op->template getParentOfType<shape::AssumingOp>();
    if (!assumingOp) return failure();

    // Operands must not be defined within the assuming op, otherwise they
    // would not be visible from outside it.
    Block *body = assumingOp.getBody();
    auto isAvailable = [&](Value v) {
      Operation *def = v.getDefiningOp();
      return def == nullptr || def->getBlock() != body;
    };
    if (!llvm::all_of(op->getOperands(), isAvailable)) return failure();

    // Move op before the assuming region. Replacing `op` also updates any
    // yield operands that referenced it to point at the hoisted clone.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(assumingOp);
    Operation *newOp = rewriter.clone(*op);
    rewriter.replaceOp(op, newOp->getResults());

    // If the assuming region yields none of the new op's results, these values
    // are exclusively used in the assuming op's body. In these cases there is
    // no need for further rewrites.
    auto isNewOpResult = [newOp](Value v) {
      return llvm::is_contained(newOp->getResults(), v);
    };
    auto yieldOp = cast<shape::AssumingYieldOp>(body->getTerminator());
    if (llvm::none_of(yieldOp.getOperands(), isNewOpResult)) return success();

    // If the assuming region yields any of the new op's results, these values
    // can instead bypass the assuming region. There is no need to yield them
    // explicitly as they are assumed to be independent. The assuming op is
    // rewritten accordingly.
    SmallVector<Value, 2> replacementValues;
    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
        assumingOp.getLoc(), assumingOp.getWitness(),
        [&](OpBuilder &b, Location) {
          // Copy body.
          BlockAndValueMapping mapping;
          for (Operation &nested : body->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect new yield operands. Results produced by the hoisted op
          // bypass the region directly; all others stay yielded, with a
          // nullptr placeholder recorded to be patched up below.
          SmallVector<Value, 2> newYieldOperands;
          for (Value result : yieldOp.getOperands()) {
            if (isNewOpResult(result)) {
              replacementValues.push_back(result);
            } else {
              newYieldOperands.push_back(mapping.lookupOrDefault(result));
              replacementValues.push_back(nullptr);
            }
          }
          return newYieldOperands;
        });

    // Use the assuming op's results for the missing replacement values.
    // Placeholders and results are in the same relative order.
    auto src = newAssumingOp.getResults().begin();
    for (auto &dst : replacementValues) {
      if (dst) continue;
      dst = *src++;
    }

    rewriter.replaceOp(assumingOp, replacementValues);
    return success();
  }
};
331
332 /// Merge assuming regions if their constraints are independent from each other.
333 struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
334 using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
335
matchAndRewritemlir::mhlo::__anon84680ab10111::MergeAssumingOpsPattern336 LogicalResult matchAndRewrite(shape::AssumingOp op,
337 PatternRewriter &rewriter) const override {
338 // Merge assuming op with directly preceding one if both witnesses are
339 // availiable.
340 auto precedingOp =
341 llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
342 if (!precedingOp) return failure();
343 if (op.getWitness().getDefiningOp() == precedingOp) return failure();
344
345 // Merge witnesses.
346 OpBuilder::InsertionGuard guard(rewriter);
347 rewriter.setInsertionPoint(precedingOp);
348 Value newWitness = rewriter.create<shape::AssumingAllOp>(
349 op.getWitness().getDefiningOp()->getLoc(),
350 ValueRange{precedingOp.getWitness(), op.getWitness()});
351
352 // Merge assuming ops.
353 Block *body_a = precedingOp.getBody();
354 Block *body_b = op.getBody();
355 auto newAssumingOp = rewriter.create<shape::AssumingOp>(
356 precedingOp.getLoc(), newWitness, [&](OpBuilder &b, Location) {
357 // Copy preceding op's body.
358 BlockAndValueMapping mapping;
359 for (auto &nested : body_a->without_terminator()) {
360 b.clone(nested, mapping);
361 }
362
363 // Map result values of preceding assuming op.
364 auto yieldOpA =
365 llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
366 for (auto pair :
367 llvm::zip(precedingOp->getResults(), yieldOpA.getOperands())) {
368 mapping.map(std::get<0>(pair),
369 mapping.lookupOrDefault(std::get<1>(pair)));
370 }
371
372 // Copy op's body.
373 for (auto &nested : body_b->without_terminator()) {
374 b.clone(nested, mapping);
375 }
376
377 // Collect merged assuming op's results.
378 SmallVector<Value, 4> mappedResults;
379 auto yieldOpB =
380 llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
381 for (Value v : yieldOpA.getOperands()) {
382 mappedResults.push_back(mapping.lookupOrDefault(v));
383 }
384 for (Value v : yieldOpB.getOperands()) {
385 mappedResults.push_back(mapping.lookupOrDefault(v));
386 }
387 return mappedResults;
388 });
389
390 // Replace the two assuming ops with the new corresponding results.
391 ValueRange newResults = newAssumingOp->getResults();
392 size_t splitAt = precedingOp->getNumResults();
393 rewriter.replaceOp(precedingOp, newResults.take_front(splitAt));
394 rewriter.replaceOp(op, newResults.drop_front(splitAt));
395 return success();
396 }
397 };
398
399 struct EliminateDuplicateCstrBroadcastableOps
400 : public OpRewritePattern<shape::CstrBroadcastableOp> {
401 using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
402
matchAndRewritemlir::mhlo::__anon84680ab10111::EliminateDuplicateCstrBroadcastableOps403 LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
404 PatternRewriter &rewriter) const override {
405 // Search for previous occurence of the same constraint.
406 Operation *it = op->getPrevNode();
407 while (it != nullptr) {
408 if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
409 if (candidate.getShapes() == op.getShapes()) {
410 rewriter.replaceOp(op, candidate.getResult());
411 return success();
412 }
413 }
414 it = it->getPrevNode();
415 }
416
417 return failure();
418 }
419 };
420
421 struct MergeAssumingOpsPass
422 : public MergeAssumingOpsPassBase<MergeAssumingOpsPass> {
getDependentDialectsmlir::mhlo::__anon84680ab10111::MergeAssumingOpsPass423 void getDependentDialects(DialectRegistry ®istry) const override {
424 registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
425 }
426
runOnOperationmlir::mhlo::__anon84680ab10111::MergeAssumingOpsPass427 void runOnOperation() override {
428 MLIRContext *ctx = &getContext();
429 RewritePatternSet patterns(ctx);
430 mhlo::populateMergeAssumingOpsPatterns(ctx, &patterns);
431 GreedyRewriteConfig config;
432 config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
433 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
434 config))) {
435 return signalPassFailure();
436 }
437 }
438 };
439
440 } // namespace
441
/// Populates `patterns` with rewrites that merge and simplify
/// `shape.assuming` regions: duplicate-constraint elimination, moving
/// constraint-independent ops into/out of assuming regions, shape
/// reification, and related canonicalization patterns.
void populateMergeAssumingOpsPatterns(MLIRContext *context,
                                      RewritePatternSet *patterns) {
  // clang-format off
  patterns->add<
      EliminateDuplicateCstrBroadcastableOps,
      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
      MergeAssumingOpsPattern,
      MoveElementwiseOpsDownIntoAssumingOpPattern,
      MoveElementwiseOpsUpIntoAssumingOpPattern,
      MoveUpIntoAssumingOpPattern<shape::AssumingAllOp>,
      MoveUpIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveUpIntoAssumingOpPattern<shape::ShapeOfOp>,
      MoveUpOutOfAssumingOpPattern<shape::AssumingAllOp>,
      MoveUpOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveUpOutOfAssumingOpPattern<shape::ShapeOfOp>,
      ShapeReificationPattern>(context);
  // clang-format on
  // Also pull in canonicalizations that expose more merge opportunities.
  mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
                                                             context);
  mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
  shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
  shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}
468
createMergeAssumingOpsPass()469 std::unique_ptr<OperationPass<func::FuncOp>> createMergeAssumingOpsPass() {
470 return std::make_unique<MergeAssumingOpsPass>();
471 }
472
473 } // namespace mhlo
474 } // namespace mlir
475