/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the patterns to simplify shape ops that were deemed not
// suitable for shape op canonicalization in MLIR Core.

#include <memory>
#include <utility>

#include "llvm/ADT/Optional.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

namespace {

using shape::BroadcastOp;
using shape::ConstShapeOp;
using shape::ShapeOfOp;

// Try to remove broadcast operands that do not contribute to the final
// broadcasted result.
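// For example (an illustrative sketch with hypothetical values): in
//   %0 = shape.broadcast %a, %b : tensor<2xindex>, tensor<2xindex>
//        -> tensor<2xindex>
// where %a is known to be [1, 1], the %a operand can be dropped, since a
// shape of all ones never changes the broadcasted result.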
struct BroadcastRemoveSubsumedOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // First collect the static components when joining all shapes. The
    // resulting vector contains a static dimension if any operand has a static
    // non-1 dimension in that position. The remaining dimensions are set to
    // dynamic size.
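    // Illustrative example (hypothetical extents): joining [?, 3, ?] with
    // [?, 4], right-aligned, yields knownExtents = [?, 3, 4]; conflicting
    // static extents in the same position make the pattern bail out.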
    SmallVector<int64_t> knownExtents;
    SmallVector<SmallVector<int64_t, 4>, 4> operandExtents;
    for (Value shape : op.getShapes()) {
      auto &extents = operandExtents.emplace_back();
      if (failed(shape::getShapeVec(shape, extents))) return failure();

      // Prepend dynamic dims if sizes don't match.
      if (extents.size() > knownExtents.size()) {
        knownExtents.insert(knownExtents.begin(),
                            extents.size() - knownExtents.size(),
                            ShapedType::kDynamicSize);
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        if (extent != ShapedType::kDynamicSize && extent != 1) {
          int64_t &knownExtent = knownExtents[knownExtents.size() - i - 1];
          // A dynamic dimension is subsumed by a static one, but bail out for
          // known conflicting shapes.
          if (knownExtent != extent && knownExtent != ShapedType::kDynamicSize)
            return failure();
          knownExtent = extent;
        }
      }
    }

    // If we've figured out all shapes to be constants we're done.
    if (!llvm::is_contained(knownExtents, ShapedType::kDynamicSize)) {
      rewriter.replaceOpWithNewOp<ConstShapeOp>(
          op, op->getResultTypes(), rewriter.getIndexTensorAttr(knownExtents));
      return success();
    }

    // If only some dimensions are known, see if any of the operands can be
    // removed without affecting the result.
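    // For example (illustrative extents only): when broadcasting [1, 4] with
    // [?, 4], the [1, 4] operand is subsumed because its 1 never contributes
    // and its 4 is also supplied by the other operand, so it can be dropped.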
    SmallVector<Value, 4> filteredOperands;
    for (auto tuple : llvm::zip(op.getShapes(), operandExtents)) {
      Value shape = std::get<0>(tuple);
      auto &extents = std::get<1>(tuple);

      // An operand can't be dead if it's the only operand of the maximum rank.
      // Removing it would reduce the rank of the output.
      if (llvm::count_if(operandExtents, [&](ArrayRef<int64_t> op) {
            return op.size() >= extents.size();
          }) <= 1) {
        filteredOperands.push_back(shape);
        continue;
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        // A dimension of an operand can be subsumed if it's
        // - a 1 dimension. All other operands will have 1 dims or better.
        if (extent == 1) continue;

        // - a dynamic dim but the result is known to be constant.
        int64_t knownExtent = knownExtents[knownExtents.size() - i - 1];
        assert(knownExtent != 1);
        if (knownExtent != ShapedType::kDynamicSize &&
            extent == ShapedType::kDynamicSize)
          continue;

        // - a constant non-1 dimension equal to the "known" dim.
        // In this case we also have to check whether this operand is the only
        // contributor of that constant.
        if (knownExtent != ShapedType::kDynamicSize && extent == knownExtent &&
            llvm::count_if(operandExtents, [&](ArrayRef<int64_t> operandShape) {
              return i < operandShape.size() &&
                     operandShape[operandShape.size() - i - 1] == knownExtent;
            }) > 1)
          continue;

        filteredOperands.push_back(shape);
        break;
      }
    }
    if (filteredOperands.size() != op.getShapes().size()) {
      rewriter.replaceOpWithNewOp<BroadcastOp>(op, op->getResultTypes(),
                                               filteredOperands);
      return success();
    }
    return failure();
  }
};

// Convert cases like:
// ```
// %1 = shape.shape_of %arg0 : tensor<?x?x?xf64> -> tensor<3xindex>
// %2 = shape.shape_of %arg1 : tensor<?x?x1xf64> -> tensor<3xindex>
// %3 = shape.broadcast %1, %2 : tensor<3xindex>, tensor<3xindex>
//                               -> tensor<3xindex>
// %result = tensor.extract %3[%c2] : tensor<3xindex>
// ```
// to
//
// ```
// %result = tensor.dim %arg0, %c2 : tensor<?x?x?xf64>
// ```
struct ExtractFromBroadcastedTensorCanonicalizationPattern
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp = op.getTensor().getDefiningOp<BroadcastOp>();
    if (!broadcastOp) return failure();

    // Confirm that there is a constant index. This is required, so we can
    // confirm the DimOp's input will define the resulting broadcasted shape in
    // that dimension.
    auto index =
        op.getIndices().front().getDefiningOp<arith::ConstantIndexOp>();
    if (!index) return failure();
    auto idx = index.value();

    // Iterate through the operands with 3 considerations in this order:
    // 1. If a static, non-1 dimension is seen, we know this to be the
    //    broadcasted result
    // 2. If a single dynamic dimension is seen, we know this to be the
    //    broadcasted result (with a possibly 1 or non-1 result)
    // 3. If no dynamic dimensions and no non-1 static dimensions are seen, we
    //    know the result to be 1
    //
    // Iterate through all operands, keeping track of dynamic dimensions and
    // returning immediately if a non-1 static dimension is seen.
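    // For example (hypothetical operand types): for shape_of ops on
    // tensor<?x1xf32> and tensor<1x7xf32>, extracting index 1 folds to the
    // constant 7, while extracting index 0 becomes a tensor.dim on the first
    // operand, which holds the single dynamic dimension.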
    ShapeOfOp dynamicShape;
    int64_t numDynamic = 0;
    for (auto shape : broadcastOp.getShapes()) {
      auto shapeOfOp = shape.getDefiningOp<ShapeOfOp>();
      if (!shapeOfOp) return failure();
      auto shapedType =
          shapeOfOp->getOperandTypes().front().cast<ShapedType>();

      // Abort on the existence of unranked shapes as they require more logic.
      if (!shapedType.hasRank()) return failure();
      if (shapedType.getRank() <= idx) continue;

      // Only consider dynamic dimensions after the loop because any non-1
      // static dimension takes precedence.
      if (shapedType.isDynamicDim(idx)) {
        dynamicShape = shapeOfOp;
        numDynamic++;
        continue;
      }

      if (shapedType.getDimSize(idx) == 1) continue;

      // Return as soon as we see a non-1 static dim.
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
          op, shapedType.getDimSize(idx));
      return success();
    }
    if (numDynamic > 1) return failure();

    // Replace with the single dynamic dimension or 1.
    if (dynamicShape) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, dynamicShape.getArg(),
                                                 index);
    } else {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 1);
    }
    return success();
  }
};

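// Pass that runs the canonicalization patterns of the shape and mhlo dialects
// together with the two extra simplification patterns above.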
struct ShapeSimplification
    : public ShapeSimplificationBase<ShapeSimplification> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::arith::ArithmeticDialect>();
    registry.insert<mhlo::MhloDialect>();
    registry.insert<mlir::func::FuncDialect>();
    registry.insert<shape::ShapeDialect>();
    registry.insert<tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(&getContext());

    for (auto op : context->getRegisteredOperations()) {
      if (isa<shape::ShapeDialect, mhlo::MhloDialect>(op.getDialect()))
        op.getCanonicalizationPatterns(patterns, context);
    }

    patterns.add<BroadcastRemoveSubsumedOperandsPattern,
                 ExtractFromBroadcastedTensorCanonicalizationPattern>(context);

    auto func = getOperation();
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createShapeSimplification() {
  return std::make_unique<ShapeSimplification>();
}
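
// A minimal usage sketch (assuming an already configured pass manager `pm`
// exists in the caller; the variable name is hypothetical):
//   pm.addNestedPass<func::FuncOp>(createShapeSimplification());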

} // namespace mlir