1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 
15 ==============================================================================*/
16 
17 #include <memory>
18 #include <utility>
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"
28 #include "mlir/Dialect/Shape/IR/Shape.h"
29 #include "mlir/Dialect/Tensor/IR/Tensor.h"
30 #include "mlir/IR/BlockAndValueMapping.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/MLIRContext.h"
34 #include "mlir/IR/Operation.h"
35 #include "mlir/IR/OperationSupport.h"
36 #include "mlir/IR/PatternMatch.h"
37 #include "mlir/Interfaces/InferTypeOpInterface.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
40 
41 namespace mlir {
42 namespace mhlo {
43 namespace {
44 
45 struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
ShapeReificationPatternmlir::mhlo::__anon84680ab10111::ShapeReificationPattern46   explicit ShapeReificationPattern(MLIRContext *context)
47       : OpRewritePattern<shape::ShapeOfOp>(context) {
48     // Recursively reify until we hit an op that doesn't support it.
49     setHasBoundedRewriteRecursion();
50   }
51 
matchAndRewritemlir::mhlo::__anon84680ab10111::ShapeReificationPattern52   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
53                                 PatternRewriter &rewriter) const override {
54     // Only reify shape computation if operand allows for it.
55     auto shapeOrigin = op.getArg().getDefiningOp<InferShapedTypeOpInterface>();
56     if (!shapeOrigin) return failure();
57 
58     llvm::SmallVector<Value, 1> reifications;
59     if (failed(shapeOrigin.reifyReturnTypeShapes(
60             rewriter, shapeOrigin->getOperands(), reifications)))
61       return failure();
62     assert(reifications.size() == 1);
63     Value reifiedShape = reifications.front();
64 
65     // Insert cast if needed.
66     if (reifiedShape.getType() != op.getType()) {
67       reifiedShape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
68                                                      reifiedShape);
69     }
70 
71     rewriter.replaceOp(op, reifiedShape);
72     return success();
73   }
74 };
75 
76 template <typename OpTy>
77 struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
78   using OpRewritePattern<OpTy>::OpRewritePattern;
79 
matchAndRewritemlir::mhlo::__anon84680ab10111::InlineBroadcastedShapeOperandsPattern80   LogicalResult matchAndRewrite(OpTy op,
81                                 PatternRewriter &rewriter) const override {
82     // Find all the shape operands, direct and indirect.
83     SmallVector<Value, 8> inlinedOperands;
84     for (Value direct : op->getOperands()) {
85       if (auto bcastOp = direct.getDefiningOp<shape::BroadcastOp>()) {
86         for (Value indirect : bcastOp->getOperands())
87           inlinedOperands.push_back(indirect);
88       } else {
89         inlinedOperands.push_back(direct);
90       }
91     }
92 
93     // Only rewrite if it makes a difference.
94     if (inlinedOperands.size() == op.getNumOperands()) return failure();
95 
96     // Inline shape operands.
97     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), inlinedOperands,
98                                       op->getAttrs());
99     return success();
100   }
101 };
102 
moveUpIntoAssumingOpMatchAndRewrite(Operation * op,PatternRewriter & rewriter)103 LogicalResult moveUpIntoAssumingOpMatchAndRewrite(Operation *op,
104                                                   PatternRewriter &rewriter) {
105   // Only implemented for single-result ops.
106   if (op->getNumResults() != 1) return failure();
107 
108   // Find a preceding `assuming` op.
109   auto *theBlock = op->getBlock();
110   Operation *prev = op->getPrevNode();
111   while (prev != nullptr && !llvm::isa<shape::AssumingOp>(prev))
112     prev = prev->getPrevNode();
113   auto assumingOp = llvm::dyn_cast_or_null<shape::AssumingOp>(prev);
114   if (!assumingOp) return failure();
115   assert(assumingOp->getBlock() == theBlock && op->getBlock() == theBlock &&
116          "expect assuming op and root op to be in the same block");
117 
118   // Make sure that all operands will be available after moving.
119   auto isAvailable = [&](Value v) {
120     Operation *def = v.getDefiningOp();
121     return def == nullptr || def->getBlock() != theBlock ||
122            !assumingOp->isBeforeInBlock(def);
123   };
124   if (!llvm::all_of(op->getOperands(), isAvailable)) return failure();
125 
126   Block *body = assumingOp.getBody();
127   auto yieldOp = llvm::cast<shape::AssumingYieldOp>(body->getTerminator());
128 
129   // Find the operands to use if the op was within the assuming region. We
130   // will later use their copies, as we copy the assuming op and its body.
131   SmallVector<Value, 8> newOperandsUnmapped =
132       llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
133         for (const auto &result : llvm::enumerate(assumingOp->getResults())) {
134           if (result.value() == v) return yieldOp->getOperand(result.index());
135         }
136         return v;
137       }));
138 
139   // Insert the rewritten assuming op right before the old one.
140   OpBuilder::InsertionGuard guard(rewriter);
141   rewriter.setInsertionPoint(assumingOp);
142   auto newAssumingOp = rewriter.create<shape::AssumingOp>(
143       assumingOp.getLoc(), assumingOp.getWitness(),
144       [&](OpBuilder &b, Location) {
145         // Copy body.
146         BlockAndValueMapping mapping;
147         for (auto &nested : body->without_terminator())
148           b.clone(nested, mapping);
149 
150         // Copy op into the new body and use the mapped operands.
151         for (auto it : llvm::zip(op->getOperands(), newOperandsUnmapped)) {
152           Value oldOperand, newOperandUnmapped;
153           std::tie(oldOperand, newOperandUnmapped) = it;
154           mapping.map(oldOperand, mapping.lookupOrDefault(newOperandUnmapped));
155         }
156         Operation *newOp = b.clone(*op, mapping);
157 
158         // Yield the previous results and also the new ones.
159         auto mappedResults = llvm::to_vector<8>(llvm::map_range(
160             yieldOp.getOperands(),
161             [&](Value v) { return mapping.lookupOrDefault(v); }));
162         mappedResults.append(newOp->getResults().begin(),
163                              newOp->getResults().end());
164         return mappedResults;
165       });
166 
167   // Replace the assuming op and the root op with the corresponding result
168   // values.
169   ValueRange newAssumingOpResults = newAssumingOp->getResults();
170   rewriter.replaceOp(assumingOp, newAssumingOpResults.drop_back());
171   rewriter.replaceOp(op, newAssumingOpResults.back());
172   return success();
173 }
174 
175 /// Move operation into a preceding assuming op. This allows to process
176 /// operations that depend on the assuming op's results. It will eventually
177 /// allow to make assuming regions' constraints independent from each other.
178 template <typename OpTy>
179 struct MoveUpIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
180   using OpRewritePattern<OpTy>::OpRewritePattern;
181 
matchAndRewritemlir::mhlo::__anon84680ab10111::MoveUpIntoAssumingOpPattern182   LogicalResult matchAndRewrite(OpTy op,
183                                 PatternRewriter &rewriter) const override {
184     return moveUpIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
185   }
186 };
187 
188 // Move elementwise operations into a preceding assuming op. This will
189 // eventually allow for more fusion opportunities.
190 struct MoveElementwiseOpsUpIntoAssumingOpPattern : public RewritePattern {
MoveElementwiseOpsUpIntoAssumingOpPatternmlir::mhlo::__anon84680ab10111::MoveElementwiseOpsUpIntoAssumingOpPattern191   explicit MoveElementwiseOpsUpIntoAssumingOpPattern(MLIRContext *ctx)
192       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
193 
matchAndRewritemlir::mhlo::__anon84680ab10111::MoveElementwiseOpsUpIntoAssumingOpPattern194   LogicalResult matchAndRewrite(Operation *op,
195                                 PatternRewriter &rewriter) const override {
196     // Apply to all elementwise and broadcasting elementwise operations with no
197     // side effects.
198     if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
199         !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>()) {
200       return failure();
201     }
202     if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure();
203 
204     return moveUpIntoAssumingOpMatchAndRewrite(op, rewriter);
205   }
206 };
207 
208 // Move operation into an assuming region if all uses are within its body.
moveDownIntoAssumingOpMatchAndRewrite(Operation * op,PatternRewriter & rewriter)209 LogicalResult moveDownIntoAssumingOpMatchAndRewrite(Operation *op,
210                                                     PatternRewriter &rewriter) {
211   auto users = op->getUsers();
212   auto it = users.begin();
213   auto end = users.end();
214   if (it == end) return failure();
215 
216   // Find candidate assuming op.
217   auto assumingOp = (it++)->getParentOfType<shape::AssumingOp>();
218   if (!assumingOp || assumingOp->isProperAncestor(op)) return failure();
219 
220   // Make sure all uses are within the unique assuming op's body.
221   while (it != end) {
222     auto hopefullySameAssumingOp = (it++)->getParentOfType<shape::AssumingOp>();
223     if (!hopefullySameAssumingOp || hopefullySameAssumingOp != assumingOp) {
224       return failure();
225     }
226   }
227 
228   // Move op into the assuming region.
229   OpBuilder::InsertionGuard guard(rewriter);
230   rewriter.setInsertionPointToStart(assumingOp.getBody());
231   Operation *newOp = rewriter.clone(*op);
232   rewriter.replaceOp(op, newOp->getResults());
233   return success();
234 }
235 
236 // Move elementwise operations into succeeding assuming regions. This will
237 // eventually allow for more fusion opportunities.
238 struct MoveElementwiseOpsDownIntoAssumingOpPattern : public RewritePattern {
MoveElementwiseOpsDownIntoAssumingOpPatternmlir::mhlo::__anon84680ab10111::MoveElementwiseOpsDownIntoAssumingOpPattern239   explicit MoveElementwiseOpsDownIntoAssumingOpPattern(MLIRContext *ctx)
240       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
241 
matchAndRewritemlir::mhlo::__anon84680ab10111::MoveElementwiseOpsDownIntoAssumingOpPattern242   LogicalResult matchAndRewrite(Operation *op,
243                                 PatternRewriter &rewriter) const override {
244     // Apply to all elementwise and broadcasting elementwise operations with no
245     // side effects.
246     if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
247         !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>()) {
248       return failure();
249     }
250     if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure();
251 
252     return moveDownIntoAssumingOpMatchAndRewrite(op, rewriter);
253   }
254 };
255 
256 /// Move operation out of assuming op. This is only valid for
257 /// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
258 /// will eventually allow to make assuming regions' constraints independent from
259 /// each other.
template <typename OpTy>
struct MoveUpOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Must be inside of an assuming op.
    auto assumingOp = op->template getParentOfType<shape::AssumingOp>();
    if (!assumingOp) return failure();

    // Operands must not be defined within the assuming op.
    Block *body = assumingOp.getBody();
    auto isAvailable = [&](Value v) {
      Operation *def = v.getDefiningOp();
      return def == nullptr || def->getBlock() != body;
    };
    if (!llvm::all_of(op->getOperands(), isAvailable)) return failure();

    // Move op before the assuming region.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(assumingOp);
    Operation *newOp = rewriter.clone(*op);
    rewriter.replaceOp(op, newOp->getResults());

    // If the assuming region yields none of the new op's results, these values
    // are exclusively used in the assuming op's body. In these cases there is
    // no need for further rewrites.
    auto isNewOpResult = [newOp](Value v) {
      return llvm::is_contained(newOp->getResults(), v);
    };
    auto yieldOp = cast<shape::AssumingYieldOp>(body->getTerminator());
    if (llvm::none_of(yieldOp.getOperands(), isNewOpResult)) return success();

    // If the assuming region yields any of the new op's results, these values
    // can instead bypass the assuming region. There is no need to yield them
    // explicitly as they are assumed to be independent. The assuming op is
    // rewritten accordingly.
    // `replacementValues` collects, per original result, either the bypassing
    // value (for hoisted results) or a placeholder filled in afterwards.
    SmallVector<Value, 2> replacementValues;
    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
        assumingOp.getLoc(), assumingOp.getWitness(),
        [&](OpBuilder &b, Location) {
          // Copy body.
          BlockAndValueMapping mapping;
          for (Operation &nested : body->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect new yield operands: results of the hoisted op bypass the
          // region entirely, everything else is still yielded (remapped to
          // the cloned body's values).
          SmallVector<Value, 2> newYieldOperands;
          for (Value result : yieldOp.getOperands()) {
            if (isNewOpResult(result)) {
              replacementValues.push_back(result);
            } else {
              newYieldOperands.push_back(mapping.lookupOrDefault(result));
              // Placeholder; filled with the new assuming op's result below.
              replacementValues.push_back(nullptr);
            }
          }
          return newYieldOperands;
        });

    // Use the assuming op's results for the missing replacement values. The
    // new op yields fewer values, so its results map to the placeholders in
    // order.
    auto src = newAssumingOp.getResults().begin();
    for (auto &dst : replacementValues) {
      if (dst) continue;
      dst = *src++;
    }

    rewriter.replaceOp(assumingOp, replacementValues);
    return success();
  }
};
331 
332 /// Merge assuming regions if their constraints are independent from each other.
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
  using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Merge assuming op with directly preceding one if both witnesses are
    // availiable. If `op`'s witness is produced by the preceding assuming op,
    // the witnesses cannot be combined ahead of it, so bail out.
    auto precedingOp =
        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
    if (!precedingOp) return failure();
    if (op.getWitness().getDefiningOp() == precedingOp) return failure();

    // Merge witnesses into a single `assuming_all` before the preceding op.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(precedingOp);
    Value newWitness = rewriter.create<shape::AssumingAllOp>(
        op.getWitness().getDefiningOp()->getLoc(),
        ValueRange{precedingOp.getWitness(), op.getWitness()});

    // Merge assuming ops: clone both bodies into one new region guarded by
    // the combined witness.
    Block *body_a = precedingOp.getBody();
    Block *body_b = op.getBody();
    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
        precedingOp.getLoc(), newWitness, [&](OpBuilder &b, Location) {
          // Copy preceding op's body.
          BlockAndValueMapping mapping;
          for (auto &nested : body_a->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Map result values of preceding assuming op, so that uses of them
          // in `op`'s body resolve to the cloned yield operands.
          auto yieldOpA =
              llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
          for (auto pair :
               llvm::zip(precedingOp->getResults(), yieldOpA.getOperands())) {
            mapping.map(std::get<0>(pair),
                        mapping.lookupOrDefault(std::get<1>(pair)));
          }

          // Copy op's body.
          for (auto &nested : body_b->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect merged assuming op's results: first the preceding op's
          // yields, then `op`'s, all remapped into the merged body.
          SmallVector<Value, 4> mappedResults;
          auto yieldOpB =
              llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
          for (Value v : yieldOpA.getOperands()) {
            mappedResults.push_back(mapping.lookupOrDefault(v));
          }
          for (Value v : yieldOpB.getOperands()) {
            mappedResults.push_back(mapping.lookupOrDefault(v));
          }
          return mappedResults;
        });

    // Replace the two assuming ops with the new corresponding results. The
    // merged result list is the concatenation, split at the preceding op's
    // result count.
    ValueRange newResults = newAssumingOp->getResults();
    size_t splitAt = precedingOp->getNumResults();
    rewriter.replaceOp(precedingOp, newResults.take_front(splitAt));
    rewriter.replaceOp(op, newResults.drop_front(splitAt));
    return success();
  }
};
398 
399 struct EliminateDuplicateCstrBroadcastableOps
400     : public OpRewritePattern<shape::CstrBroadcastableOp> {
401   using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
402 
matchAndRewritemlir::mhlo::__anon84680ab10111::EliminateDuplicateCstrBroadcastableOps403   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
404                                 PatternRewriter &rewriter) const override {
405     // Search for previous occurence of the same constraint.
406     Operation *it = op->getPrevNode();
407     while (it != nullptr) {
408       if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
409         if (candidate.getShapes() == op.getShapes()) {
410           rewriter.replaceOp(op, candidate.getResult());
411           return success();
412         }
413       }
414       it = it->getPrevNode();
415     }
416 
417     return failure();
418   }
419 };
420 
421 struct MergeAssumingOpsPass
422     : public MergeAssumingOpsPassBase<MergeAssumingOpsPass> {
getDependentDialectsmlir::mhlo::__anon84680ab10111::MergeAssumingOpsPass423   void getDependentDialects(DialectRegistry &registry) const override {
424     registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
425   }
426 
runOnOperationmlir::mhlo::__anon84680ab10111::MergeAssumingOpsPass427   void runOnOperation() override {
428     MLIRContext *ctx = &getContext();
429     RewritePatternSet patterns(ctx);
430     mhlo::populateMergeAssumingOpsPatterns(ctx, &patterns);
431     GreedyRewriteConfig config;
432     config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
433     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
434                                             config))) {
435       return signalPassFailure();
436     }
437   }
438 };
439 
440 }  // namespace
441 
populateMergeAssumingOpsPatterns(MLIRContext * context,RewritePatternSet * patterns)442 void populateMergeAssumingOpsPatterns(MLIRContext *context,
443                                       RewritePatternSet *patterns) {
444   // clang-format off
445   patterns->add<
446       EliminateDuplicateCstrBroadcastableOps,
447       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
448       MergeAssumingOpsPattern,
449       MoveElementwiseOpsDownIntoAssumingOpPattern,
450       MoveElementwiseOpsUpIntoAssumingOpPattern,
451       MoveUpIntoAssumingOpPattern<shape::AssumingAllOp>,
452       MoveUpIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
453       MoveUpIntoAssumingOpPattern<shape::ShapeOfOp>,
454       MoveUpOutOfAssumingOpPattern<shape::AssumingAllOp>,
455       MoveUpOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
456       MoveUpOutOfAssumingOpPattern<shape::ShapeOfOp>,
457       ShapeReificationPattern>(context);
458   // clang-format on
459   mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
460                                                              context);
461   mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
462   shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
463   shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
464   shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
465   shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
466   tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
467 }
468 
createMergeAssumingOpsPass()469 std::unique_ptr<OperationPass<func::FuncOp>> createMergeAssumingOpsPass() {
470   return std::make_unique<MergeAssumingOpsPass>();
471 }
472 
473 }  // namespace mhlo
474 }  // namespace mlir
475