/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the patterns to simplify shape ops that were deemed not
// suitable for shape op canonicalization in MLIR Core.

#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {

namespace {

using shape::BroadcastOp;
using shape::ConstShapeOp;
using shape::ShapeOfOp;

// Try to remove operands from broadcasts that don't contribute to the final
// result.
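// For example, an operand that only contributes size-1 extents is subsumed by
// the remaining operands and can be dropped (value names and shapes below are
// illustrative, not taken from a real test):
// ```
//  %0 = shape.shape_of %arg0 : tensor<?x?x?xf32> -> tensor<3xindex>
//  %1 = shape.shape_of %arg1 : tensor<1x1x1xf32> -> tensor<3xindex>
//  %2 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<3xindex>
//                                -> tensor<3xindex>
// ```
// simplifies to
// ```
//  %2 = shape.broadcast %0 : tensor<3xindex> -> tensor<3xindex>
// ```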
struct BroadcastRemoveSubsumedOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // First collect the static components when joining all shapes. The
    // resulting vector contains a static dimension if any operand has a static
    // non-1 dimension in that position. The remaining dimensions are set to
    // dynamic size.
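    // For example (illustrative), joining the extents [?, 3, 1] and [2, ?, ?]
    // yields [2, 3, ?]: size-1 and dynamic entries are ignored, and any
    // remaining static extent wins.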
    SmallVector<int64_t> known_extents;
    SmallVector<SmallVector<int64_t, 4>, 4> operand_extents;
    for (Value shape : op.shapes()) {
      auto &extents = operand_extents.emplace_back();
      if (failed(shape::getShapeVec(shape, extents))) return failure();

      // Prepend dynamic dims if sizes don't match.
      if (extents.size() > known_extents.size()) {
        known_extents.insert(known_extents.begin(),
                             extents.size() - known_extents.size(),
                             ShapedType::kDynamicSize);
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        if (extent != ShapedType::kDynamicSize && extent != 1) {
          int64_t &known_extent = known_extents[known_extents.size() - i - 1];
          // A dynamic dimension is subsumed by a static one, but bail out for
          // known conflicting shapes.
          if (known_extent != extent &&
              known_extent != ShapedType::kDynamicSize)
            return failure();
          known_extent = extent;
        }
      }
    }

    // If we've figured out all shapes to be constants we're done.
    if (!llvm::is_contained(known_extents, ShapedType::kDynamicSize)) {
      rewriter.replaceOpWithNewOp<ConstShapeOp>(
          op, op->getResultTypes(), rewriter.getIndexTensorAttr(known_extents));
      return success();
    }

    // If only some dimensions are known, see if any of the operands can be
    // removed without affecting the result.
    SmallVector<Value, 4> filtered_operands;
    for (auto tuple : llvm::zip(op.shapes(), operand_extents)) {
      Value shape = std::get<0>(tuple);
      auto &extents = std::get<1>(tuple);

      // An operand can't be dead if it's the only operand of the maximum rank.
      // Removing it would reduce the rank of the output.
      if (llvm::count_if(operand_extents, [&](ArrayRef<int64_t> op) {
            return op.size() >= extents.size();
          }) <= 1) {
        filtered_operands.push_back(shape);
        continue;
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        // A dimension of an operand can be subsumed if it's
        //   - a 1 dimension. All other operands will have 1 dims or better.
        if (extent == 1) continue;

        //   - a dynamic dim but the result is known to be constant.
        int64_t known_extent = known_extents[known_extents.size() - i - 1];
        assert(known_extent != 1);
        if (known_extent != ShapedType::kDynamicSize &&
            extent == ShapedType::kDynamicSize)
          continue;

        //   - a constant non-1 dimension equal to the "known" dim.
        // In this case we also have to check whether this operand is the only
        // contributor of that constant.
        if (known_extent != ShapedType::kDynamicSize &&
            extent == known_extent &&
            llvm::count_if(
                operand_extents, [&](ArrayRef<int64_t> operand_shape) {
                  return i < operand_shape.size() &&
                         operand_shape[operand_shape.size() - i - 1] ==
                             known_extent;
                }) > 1)
          continue;

        filtered_operands.push_back(shape);
        break;
      }
    }
    if (filtered_operands.size() != op.shapes().size()) {
      rewriter.replaceOpWithNewOp<BroadcastOp>(op, op->getResultTypes(),
                                               filtered_operands);
      return success();
    }
    return failure();
  }
};
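
// Rewrite a `tensor.extract` that reads from a `shape.shape_of` result into a
// `tensor.dim` on the shaped value itself, avoiding the intermediate extent
// tensor.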
struct ExtractFromExtentTensorCanonicalizationPattern
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto shape_of_op = op.tensor().getDefiningOp<ShapeOfOp>();
    if (!shape_of_op) return failure();
    Value index = op.indices().front();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shape_of_op.arg(), index);
    return success();
  }
};

// Convert cases like:
// ```
//  %1 = shape.shape_of %arg0 : tensor<?x?x?xf64> -> tensor<3xindex>
//  %2 = shape.shape_of %arg1 : tensor<?x?x1xf64> -> tensor<3xindex>
//  %3 = shape.broadcast %1, %2 : tensor<3xindex>, tensor<3xindex>
//                                -> tensor<3xindex>
//  %result = tensor.extract %3[%c2] : tensor<3xindex>
// ```
// to
//
// ```
//  %result = tensor.dim %arg0[%c2] : tensor<?x?x?xf64>
// ```
struct ExtractFromBroadcastedTensorCanonicalizationPattern
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    // Confirm that there is a constant index. This is required, so we can
    // confirm the DimOp's input will define the resulting broadcasted shape in
    // that dimension.
    auto index = op.indices().front().getDefiningOp<ConstantIndexOp>();
    if (!index) return failure();
    auto idx = index.getValue();
    auto broadcast_op = op.tensor().getDefiningOp<BroadcastOp>();
    if (!broadcast_op) return failure();

    // Iterate through the operands with 3 considerations in this order:
    // 1. If a static, non-1 dimension is seen, we know this to be the
    //    broadcasted result.
    // 2. If a single dynamic dimension is seen, we know this to be the
    //    broadcasted result (with a possibly 1 or non-1 result).
    // 3. If no dynamic dimensions and no non-1 static dimensions are seen, we
    //    know the result to be 1.
    //
    // Iterate through all operands, keeping track of dynamic dimensions and
    // returning immediately if a non-1 static dimension is seen.
    ShapeOfOp dynamic_shape;
    int64_t num_dynamic = 0;
    for (auto shape : broadcast_op.shapes()) {
      auto shape_of_op = shape.getDefiningOp<ShapeOfOp>();
      if (!shape_of_op) return failure();
      auto shaped_type =
          shape_of_op->getOperandTypes().front().cast<ShapedType>();

      // Abort on the existence of unranked shapes as they require more logic.
      if (!shaped_type.hasRank()) return failure();
      if (shaped_type.getRank() <= idx) continue;

      // Only consider dynamic dimensions after the loop because any non-1
      // static dimension takes precedence.
      if (shaped_type.isDynamicDim(idx)) {
        dynamic_shape = shape_of_op;
        num_dynamic++;
        continue;
      }

      if (shaped_type.getDimSize(idx) == 1) continue;

      // Return as soon as we see a non-1 static dim.
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op,
                                                   shaped_type.getDimSize(idx));
      return success();
    }
    if (num_dynamic > 1) return failure();

    // Replace with the single dynamic dimension or 1.
    if (dynamic_shape) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, dynamic_shape.arg(),
                                                 index);
    } else {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, 1);
    }
    return success();
  }
};

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
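
// Combines the patterns defined above with the canonicalization patterns of
// the shape and mhlo dialects and applies them greedily to each function.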
struct ShapeSimplification
    : public ShapeSimplificationBase<ShapeSimplification> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect>();
    registry.insert<mlir::StandardOpsDialect>();
    registry.insert<shape::ShapeDialect>();
    registry.insert<tensor::TensorDialect>();
  }

  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(&getContext());

    Dialect *shape_dialect = context->getLoadedDialect<shape::ShapeDialect>();
    Dialect *mhlo_dialect = context->getLoadedDialect<mhlo::MhloDialect>();
    for (auto *op : context->getRegisteredOperations()) {
      if (op->dialect.getTypeID() == shape_dialect->getTypeID() ||
          op->dialect.getTypeID() == mhlo_dialect->getTypeID())
        op->getCanonicalizationPatterns(patterns, context);
    }

    patterns.insert<BroadcastRemoveSubsumedOperandsPattern,
                    ExtractFromBroadcastedTensorCanonicalizationPattern,
                    ExtractFromExtentTensorCanonicalizationPattern>(context);

    auto func = getFunction();
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();
  }
};

}  // namespace
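
// Sketch of how the pass can be scheduled, assuming a pass manager `pm`
// rooted at the module (illustrative, not the actual pipeline setup):
//   pm.addNestedPass<FuncOp>(CreateShapeSimplification());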
std::unique_ptr<FunctionPass> CreateShapeSimplification() {
  return std::make_unique<ShapeSimplification>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir