/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the patterns to simplify shape ops that were deemed not
// suitable for shape op canonicalization in MLIR Core.

#include <memory>
#include <utility>

#include "llvm/ADT/Optional.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

namespace {

using shape::BroadcastOp;
using shape::ConstShapeOp;
using shape::ShapeOfOp;

// Try to remove operands from broadcasts that don't contribute to the final
// result.
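//
// A sketch of the kind of IR this targets (the operands and types below are
// illustrative, not taken from a test):
// ```
//  %0 = shape.const_shape [1, 8] : tensor<2xindex>
//  %1 = shape.shape_of %arg0 : tensor<?x8xf32> -> tensor<2xindex>
//  %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<2xindex>
//                                -> tensor<2xindex>
// ```
// Every extent of %0 is either 1 or already provided by %1, so the broadcast
// can be rewritten to take only %1.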
struct BroadcastRemoveSubsumedOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // First collect the static components when joining all shapes. The
    // resulting vector contains a static dimension if any operand has a static
    // non-1 dimension in that position. The remaining dimensions are set to
    // dynamic size.
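    // For example (illustrative): joining the operand shapes [?, 3, ?] and
    // [1, ?] yields the known extents [?, 3, ?].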
    SmallVector<int64_t> knownExtents;
    SmallVector<SmallVector<int64_t, 4>, 4> operandExtents;
    for (Value shape : op.getShapes()) {
      auto &extents = operandExtents.emplace_back();
      if (failed(shape::getShapeVec(shape, extents))) return failure();

      // Prepend dynamic dims if sizes don't match.
      if (extents.size() > knownExtents.size()) {
        knownExtents.insert(knownExtents.begin(),
                            extents.size() - knownExtents.size(),
                            ShapedType::kDynamicSize);
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        if (extent != ShapedType::kDynamicSize && extent != 1) {
          int64_t &knownExtent = knownExtents[knownExtents.size() - i - 1];
          // A dynamic dimension is subsumed by a static one, but bail out for
          // known conflicting shapes.
          if (knownExtent != extent && knownExtent != ShapedType::kDynamicSize)
            return failure();
          knownExtent = extent;
        }
      }
    }

    // If we've figured out all shapes to be constants, we're done.
    if (!llvm::is_contained(knownExtents, ShapedType::kDynamicSize)) {
      rewriter.replaceOpWithNewOp<ConstShapeOp>(
          op, op->getResultTypes(), rewriter.getIndexTensorAttr(knownExtents));
      return success();
    }

    // If only some dimensions are known, see if any of the operands can be
    // removed without affecting the result.
    SmallVector<Value, 4> filteredOperands;
    for (auto tuple : llvm::zip(op.getShapes(), operandExtents)) {
      Value shape = std::get<0>(tuple);
      auto &extents = std::get<1>(tuple);

      // An operand can't be dead if it's the only operand with the maximum
      // rank; removing it would reduce the rank of the output.
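      // For example (illustrative): with operand shapes [1, 1] and [5], the
      // rank-2 operand [1, 1] contributes no non-1 extents, yet dropping it
      // would turn the rank-2 result into a rank-1 one, so it is kept.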
      if (llvm::count_if(operandExtents, [&](ArrayRef<int64_t> op) {
            return op.size() >= extents.size();
          }) <= 1) {
        filteredOperands.push_back(shape);
        continue;
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        // A dimension of an operand can be subsumed if it's
        //   - a 1 dimension; all other operands will have a 1 dim there or
        //     better.
        if (extent == 1) continue;

        //   - a dynamic dim but the result is known to be constant.
        int64_t knownExtent = knownExtents[knownExtents.size() - i - 1];
        assert(knownExtent != 1);
        if (knownExtent != ShapedType::kDynamicSize &&
            extent == ShapedType::kDynamicSize)
          continue;

        //   - a constant non-1 dimension equal to the "known" dim.
        // In this case we also have to check whether this operand is the only
        // contributor of that constant.
        if (knownExtent != ShapedType::kDynamicSize && extent == knownExtent &&
            llvm::count_if(operandExtents, [&](ArrayRef<int64_t> operandShape) {
              return i < operandShape.size() &&
                     operandShape[operandShape.size() - i - 1] == knownExtent;
            }) > 1)
          continue;

        filteredOperands.push_back(shape);
        break;
      }
    }
    if (filteredOperands.size() != op.getShapes().size()) {
      rewriter.replaceOpWithNewOp<BroadcastOp>(op, op->getResultTypes(),
                                               filteredOperands);
      return success();
    }
    return failure();
  }
};

// Convert cases like:
// ```
//  %1 = shape.shape_of %arg0 : tensor<?x?x?xf64> -> tensor<3xindex>
//  %2 = shape.shape_of %arg1 : tensor<?x?x1xf64> -> tensor<3xindex>
//  %3 = shape.broadcast %1, %2 : tensor<3xindex>, tensor<3xindex>
//                                -> tensor<3xindex>
//  %result = tensor.extract %3[%c2] : tensor<3xindex>
// ```
// to
//
// ```
//  %result = tensor.dim %arg0[%c2] : tensor<?x?x?xf64>
// ```
struct ExtractFromBroadcastedTensorCanonicalizationPattern
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp = op.getTensor().getDefiningOp<BroadcastOp>();
    if (!broadcastOp) return failure();

    // Confirm that there is a constant index. This is required so we can
    // confirm the DimOp's input will define the resulting broadcasted shape in
    // that dimension.
    auto index =
        op.getIndices().front().getDefiningOp<arith::ConstantIndexOp>();
    if (!index) return failure();
    auto idx = index.value();

    // Iterate through the operands with 3 considerations in this order:
    // 1. If a static, non-1 dimension is seen, we know this to be the
    //    broadcasted result.
    // 2. If a single dynamic dimension is seen, we know this to be the
    //    broadcasted result (whose runtime value may or may not be 1).
    // 3. If no dynamic dimensions and no non-1 static dimensions are seen, we
    //    know the result to be 1.
    //
    // Iterate through all operands, keeping track of dynamic dimensions and
    // returning immediately if a non-1 static dimension is seen.
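    // For example (illustrative): extracting index 1 from the broadcast of
    // the shapes of tensor<4x8xf32> and tensor<4x1xf32> folds to the constant
    // 8 (case 1), while extracting index 1 from the shapes of tensor<4x1xf32>
    // and tensor<?x1xf32> folds to the constant 1 (case 3).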
    ShapeOfOp dynamicShape;
    int64_t numDynamic = 0;
    for (auto shape : broadcastOp.getShapes()) {
      auto shapeOfOp = shape.getDefiningOp<ShapeOfOp>();
      if (!shapeOfOp) return failure();
      auto shapedType = shapeOfOp->getOperandTypes().front().cast<ShapedType>();

      // Abort on the existence of unranked shapes as they require more logic.
      if (!shapedType.hasRank()) return failure();
      if (shapedType.getRank() <= idx) continue;

      // Only consider dynamic dimensions after the loop because any non-1
      // static dimension takes precedence.
      if (shapedType.isDynamicDim(idx)) {
        dynamicShape = shapeOfOp;
        numDynamic++;
        continue;
      }

      if (shapedType.getDimSize(idx) == 1) continue;

      // Return as soon as we see a non-1 static dim.
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
          op, shapedType.getDimSize(idx));
      return success();
    }
    if (numDynamic > 1) return failure();

    // Replace with the single dynamic dimension or 1.
    if (dynamicShape) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, dynamicShape.getArg(),
                                                 index);
    } else {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 1);
    }
    return success();
  }
};

struct ShapeSimplification
    : public ShapeSimplificationBase<ShapeSimplification> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::arith::ArithmeticDialect>();
    registry.insert<mhlo::MhloDialect>();
    registry.insert<mlir::func::FuncDialect>();
    registry.insert<shape::ShapeDialect>();
    registry.insert<tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(&getContext());

    for (auto op : context->getRegisteredOperations()) {
      if (isa<shape::ShapeDialect, mhlo::MhloDialect>(op.getDialect()))
        op.getCanonicalizationPatterns(patterns, context);
    }

    patterns.add<BroadcastRemoveSubsumedOperandsPattern,
                 ExtractFromBroadcastedTensorCanonicalizationPattern>(context);

    auto func = getOperation();
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();
  }
};

}  // namespace

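// A minimal usage sketch (assumption: a standard mlir::PassManager named `pm`;
// this call is not part of this file):
//   pm.addNestedPass<func::FuncOp>(createShapeSimplification());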
std::unique_ptr<OperationPass<func::FuncOp>> createShapeSimplification() {
  return std::make_unique<ShapeSimplification>();
}

}  // namespace mlir