/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the patterns to simplify shape ops that were deemed not
// suitable for shape op canonicalization in MLIR Core.

#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {

namespace {

using shape::BroadcastOp;
using shape::ConstShapeOp;
using shape::ShapeOfOp;

// Try to remove operands from broadcasts that don't contribute to the final
// result.
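// For example (illustrative, not taken from a specific test): in
//   %0 = shape.broadcast %a, %b : tensor<2xindex>, tensor<2xindex>
//                                 -> tensor<2xindex>
// where %a is the shape of a tensor<?x3xf32> and %b the shape of a
// tensor<1x3xf32>, %b adds no information beyond %a and can be dropped; if
// all extents turn out to be static, the broadcast folds to a
// shape.const_shape.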
struct BroadcastRemoveSubsumedOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // First collect the static components when joining all shapes. The
    // resulting vector contains a static dimension if any operand has a static
    // non-1 dimension in that position. The remaining dimensions are set to
    // dynamic size.
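    // E.g. (illustrative): joining the extents [?, 3, ?] and [2, 1, ?]
    // results in the known extents [2, 3, ?].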
    SmallVector<int64_t> known_extents;
    SmallVector<SmallVector<int64_t, 4>, 4> operand_extents;
    for (Value shape : op.shapes()) {
      auto &extents = operand_extents.emplace_back();
      if (failed(shape::getShapeVec(shape, extents))) return failure();

      // Prepend dynamic dims if sizes don't match.
      if (extents.size() > known_extents.size()) {
        known_extents.insert(known_extents.begin(),
                             extents.size() - known_extents.size(),
                             ShapedType::kDynamicSize);
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        if (extent != ShapedType::kDynamicSize && extent != 1) {
          int64_t &known_extent = known_extents[known_extents.size() - i - 1];
          // A dynamic dimension is subsumed by a static one, but bail out for
          // known conflicting shapes.
          if (known_extent != extent &&
              known_extent != ShapedType::kDynamicSize)
            return failure();
          known_extent = extent;
        }
      }
    }

    // If we've figured out all shapes to be constants we're done.
    if (!llvm::is_contained(known_extents, ShapedType::kDynamicSize)) {
      rewriter.replaceOpWithNewOp<ConstShapeOp>(
          op, op->getResultTypes(), rewriter.getIndexTensorAttr(known_extents));
      return success();
    }

    // If only some dimensions are known, see if any of the operands can be
    // removed without affecting the result.
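    // E.g. (illustrative): with operand extents [2, ?] and [2, 1] and known
    // extents [2, ?], the second operand is fully subsumed and can be
    // dropped.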
    SmallVector<Value, 4> filtered_operands;
    for (auto tuple : llvm::zip(op.shapes(), operand_extents)) {
      Value shape = std::get<0>(tuple);
      auto &extents = std::get<1>(tuple);

      // An operand can't be dead if it's the only operand of the maximum rank.
      // Removing it would reduce the rank of the output.
      if (llvm::count_if(operand_extents, [&](ArrayRef<int64_t> op) {
            return op.size() >= extents.size();
          }) <= 1) {
        filtered_operands.push_back(shape);
        continue;
      }

      for (size_t i = 0, e = extents.size(); i != e; ++i) {
        int64_t extent = extents[e - i - 1];
        // A dimension of an operand can be subsumed if it's
        // - a 1 dimension. All other operands will have 1 dims or better.
        if (extent == 1) continue;

        // - a dynamic dim but the result is known to be constant.
        int64_t known_extent = known_extents[known_extents.size() - i - 1];
        assert(known_extent != 1);
        if (known_extent != ShapedType::kDynamicSize &&
            extent == ShapedType::kDynamicSize)
          continue;

        // - a constant non-1 dimension equal to the "known" dim.
        //   In this case we also have to check whether this operand is the
        //   only contributor of that constant.
        if (known_extent != ShapedType::kDynamicSize &&
            extent == known_extent &&
            llvm::count_if(
                operand_extents, [&](ArrayRef<int64_t> operand_shape) {
                  return i < operand_shape.size() &&
                         operand_shape[operand_shape.size() - i - 1] ==
                             known_extent;
                }) > 1)
          continue;

        filtered_operands.push_back(shape);
        break;
      }
    }
    if (filtered_operands.size() != op.shapes().size()) {
      rewriter.replaceOpWithNewOp<BroadcastOp>(op, op->getResultTypes(),
                                               filtered_operands);
      return success();
    }
    return failure();
  }
};

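// Replace `tensor.extract` from a `shape.shape_of` result with `tensor.dim`
// on the original tensor. E.g. (illustrative, assuming a rank-2 operand):
// ```
// %shape = shape.shape_of %arg : tensor<?x?xf32> -> tensor<2xindex>
// %dim = tensor.extract %shape[%c1] : tensor<2xindex>
// ```
// becomes
// ```
// %dim = tensor.dim %arg, %c1 : tensor<?x?xf32>
// ```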
struct ExtractFromExtentTensorCanonicalizationPattern
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto shape_of_op = op.tensor().getDefiningOp<ShapeOfOp>();
    if (!shape_of_op) return failure();
    Value index = op.indices().front();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shape_of_op.arg(), index);
    return success();
  }
};

// Convert cases like:
// ```
// %1 = shape.shape_of %arg0 : tensor<?x?x?xf64> -> tensor<3xindex>
// %2 = shape.shape_of %arg1 : tensor<?x?x1xf64> -> tensor<3xindex>
// %3 = shape.broadcast %1, %2 : tensor<3xindex>, tensor<3xindex>
//                               -> tensor<3xindex>
// %result = tensor.extract %3[%c2] : tensor<3xindex>
// ```
// to
//
// ```
// %result = tensor.dim %arg0, %c2 : tensor<?x?x?xf64>
// ```
struct ExtractFromBroadcastedTensorCanonicalizationPattern
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    // Confirm that there is a constant index. This is required, so we can
    // confirm the DimOp's input will define the resulting broadcasted shape in
    // that dimension.
    auto index = op.indices().front().getDefiningOp<ConstantIndexOp>();
    if (!index) return failure();
    auto idx = index.getValue();
    auto broadcast_op = op.tensor().getDefiningOp<BroadcastOp>();
    if (!broadcast_op) return failure();

    // Iterate through the operands with 3 considerations in this order:
    // 1. If a static, non-1 dimension is seen, we know this to be the
    //    broadcasted result.
    // 2. If a single dynamic dimension is seen, we know this to be the
    //    broadcasted result (with a possibly 1 or non-1 result).
    // 3. If no dynamic dimensions and no non-1 static dimensions are seen, we
    //    know the result to be 1.
    //
    // Iterate through all operands, keeping track of dynamic dimensions and
    // returning immediately if a non-1 static dimension is seen.
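    // E.g. (illustrative): broadcasting the shapes of a tensor<?x3xf32> and a
    // tensor<1x3xf32>, extracting index 1 folds to the constant 3, while
    // extracting index 0 becomes tensor.dim of the first operand.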
    ShapeOfOp dynamic_shape;
    int64_t num_dynamic = 0;
    for (auto shape : broadcast_op.shapes()) {
      auto shape_of_op = shape.getDefiningOp<ShapeOfOp>();
      if (!shape_of_op) return failure();
      auto shaped_type =
          shape_of_op->getOperandTypes().front().cast<ShapedType>();

      // Abort on the existence of unranked shapes as they require more logic.
      if (!shaped_type.hasRank()) return failure();
      if (shaped_type.getRank() <= idx) continue;

      // Only consider dynamic dimensions after the loop because any non-1
      // static dimension takes precedence.
      if (shaped_type.isDynamicDim(idx)) {
        dynamic_shape = shape_of_op;
        num_dynamic++;
        continue;
      }

      if (shaped_type.getDimSize(idx) == 1) continue;

      // Return as soon as we see a non-1 static dim.
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op,
                                                   shaped_type.getDimSize(idx));
      return success();
    }
    if (num_dynamic > 1) return failure();

    // Replace with the single dynamic dimension or 1.
    if (dynamic_shape) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, dynamic_shape.arg(),
                                                 index);
    } else {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, 1);
    }
    return success();
  }
};

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

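// Pass that greedily applies the patterns above together with the
// canonicalization patterns of the shape and mhlo dialects.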
struct ShapeSimplification
    : public ShapeSimplificationBase<ShapeSimplification> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect>();
    registry.insert<mlir::StandardOpsDialect>();
    registry.insert<shape::ShapeDialect>();
    registry.insert<tensor::TensorDialect>();
  }

  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(&getContext());

    Dialect *shape_dialect = context->getLoadedDialect<shape::ShapeDialect>();
    Dialect *mhlo_dialect = context->getLoadedDialect<mhlo::MhloDialect>();
    for (auto *op : context->getRegisteredOperations()) {
      if (op->dialect.getTypeID() == shape_dialect->getTypeID() ||
          op->dialect.getTypeID() == mhlo_dialect->getTypeID())
        op->getCanonicalizationPatterns(patterns, context);
    }

    patterns.insert<BroadcastRemoveSubsumedOperandsPattern,
                    ExtractFromBroadcastedTensorCanonicalizationPattern,
                    ExtractFromExtentTensorCanonicalizationPattern>(context);

    auto func = getFunction();
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<FunctionPass> CreateShapeSimplification() {
  return std::make_unique<ShapeSimplification>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir