/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Analysis/shape_component_analysis.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo;
using Symbol = ShapeComponentAnalysis::Symbol;
using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr;

namespace {

// Temporary data structure to hold a single dimension of the symbolic result
// of `shape.broadcast`.
struct SymbolicBroadcastDimension {
  size_t operandIndex;
  size_t operandDim;
  SymbolicExpr expr;
};

// Replace `shape.broadcast` with a materialized result shape if the broadcast
// can be resolved symbolically.
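//
// For example (an illustrative sketch, assuming the analysis proves %a to
// hold [s0, 1] and %b to hold [1, s1]):
//
//   %0 = shape.broadcast %a, %b
//       : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
//
// can be rewritten to extract every result dimension from the unique operand
// that defines it:
//
//   %d0 = tensor.extract %a[%c0] : tensor<2xindex>
//   %d1 = tensor.extract %b[%c1] : tensor<2xindex>
//   %0 = tensor.from_elements %d0, %d1 : tensor<2xindex>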
struct SimplifyBroadcasts : public mlir::OpRewritePattern<shape::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      shape::BroadcastOp op, mlir::PatternRewriter &rewriter) const override {
    // Require successful shape analysis.
    ShapeComponentAnalysis shapeAnalysis;
    llvm::SmallVector<ArrayRef<SymbolicExpr>> shapesInfo;
    auto shapes = op.getShapes();
    shapesInfo.reserve(shapes.size());
    for (Value s : shapes) {
      auto sInfo = shapeAnalysis.GetValueInfo(s);
      if (!sInfo) return failure();
      shapesInfo.push_back(*sInfo);
    }

    // Find the result rank.
    size_t rank = 0;
    for (const auto &sInfo : shapesInfo) rank = std::max(rank, sInfo.size());

    // Compute the broadcast symbolically.
    SmallVector<Optional<SymbolicBroadcastDimension>> symResult(rank,
                                                                llvm::None);
    for (const auto &sInfo : llvm::enumerate(shapesInfo)) {
      size_t dimOffset = rank - sInfo.value().size();
      for (const auto &symExpr : llvm::enumerate(sInfo.value())) {
        // Unit dimensions are neutral to the final result.
        if (symExpr.value().isConstant(1)) continue;

        // Use the unique expression.
        size_t i = dimOffset + symExpr.index();
        if (!symResult[i]) {
          symResult[i] = {sInfo.index(), symExpr.index(), symExpr.value()};
          continue;
        }

        // Bail if the dimensions are not known to be equal (unit dimensions
        // were already skipped above).
        if (symResult[i]->expr != symExpr.value()) return failure();
      }
    }

    // Materialize the broadcast result.
    auto loc = op.getLoc();
    DenseMap<int64_t, Value> constants;
    auto findOrCreateConstant = [&](int64_t c) {
      auto it = constants.find(c);
      if (it != constants.end()) return it->second;
      Value newlyCreated = rewriter.create<arith::ConstantIndexOp>(loc, c);
      constants[c] = newlyCreated;
      return newlyCreated;
    };
    auto elements = llvm::to_vector<8>(
        llvm::map_range(symResult, [&](const auto &symResultDim) {
          // If we know the dimension statically, use a constant.
          if (!symResultDim) return findOrCreateConstant(1);
          if (auto cexpr = symResultDim->expr.expr
                               .template dyn_cast<AffineConstantExpr>()) {
            return findOrCreateConstant(cexpr.getValue());
          }

          // Otherwise, extract the dimension from the unique operand.
          Value operand = shapes[symResultDim->operandIndex];
          Value operandDim = findOrCreateConstant(symResultDim->operandDim);
          return rewriter.create<tensor::ExtractOp>(loc, operand, operandDim)
              .getResult();
        }));
    Type indexTy = rewriter.getIndexType();
    Type concreteResultTy = RankedTensorType::get(
        {static_cast<int64_t>(elements.size())}, indexTy);
    Value result = rewriter.create<tensor::FromElementsOp>(
        loc, concreteResultTy, elements);

    // Insert a cast, if needed.
    Type expectedTy = op.getResult().getType();
    if (result.getType() != expectedTy) {
      result = rewriter.create<tensor::CastOp>(loc, expectedTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

LogicalResult analyzeDynamicBroadcastInDimExpandingBehavior(
    ShapeComponentAnalysis &analysis, Value value, Value shape,
    llvm::SmallSetVector<int64_t, 4> *knownExpandingDims,
    llvm::SmallSetVector<int64_t, 4> *knownNonexpandingDims) {
  // Require successful analysis of shapes.
  auto shapeIn = analysis.GetShapeInfo(value);
  auto shapeOut = analysis.GetValueInfo(shape);
  if (!shapeIn || !shapeOut) return failure();

  // Analyze per argument dimension.
  size_t rankIn = shapeIn->size();
  size_t rankOut = shapeOut->size();
  assert(rankIn <= rankOut);
  size_t dimOutOffset = rankOut - rankIn;
  for (size_t i = 0; i < rankIn; ++i) {
    SymbolicExpr dimIn = (*shapeIn)[i];
    SymbolicExpr dimOut = (*shapeOut)[dimOutOffset + i];
    if (dimIn.isConstant(1) && dimOut.isKnownNotOne())
      knownExpandingDims->insert(i);
    if (dimIn == dimOut || dimOut.isConstant(1))
      knownNonexpandingDims->insert(i);
  }
  return success();
}

// Analyze `mhlo.dynamic_broadcast_in_dim` ops and populate the attributes for
// statically known expanding and non-expanding dimensions.
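//
// A sketch of the annotation (illustrative, assuming the analysis proves the
// operand dimension to equal the corresponding output dimension, making it
// known to be non-expanding):
//
//   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) {
//     broadcast_dimensions = dense<1> : tensor<1xi64>,
//     known_nonexpanding_dimensions = dense<0> : tensor<1xi64>
//   } : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>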
struct AnnotateExpandingDimensionsInDynamicBroadcastInDim
    : public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op,
      mlir::PatternRewriter &rewriter) const override {
    // Analyze shapes and identify expanding and non-expanding dims.
    ShapeComponentAnalysis analysis;
    llvm::SmallSetVector<int64_t, 4> knownExpandingDims, knownNonexpandingDims;
    if (failed(analyzeDynamicBroadcastInDimExpandingBehavior(
            analysis, op.operand(), op.output_dimensions(),
            &knownExpandingDims, &knownNonexpandingDims))) {
      return failure();
    }

    // Collect possibly already annotated info.
    auto insertAll = [](llvm::SmallSetVector<int64_t, 4> &dst,
                        Optional<DenseIntElementsAttr> src) {
      if (!src) return;
      for (auto it : *src) dst.insert(it.getLimitedValue());
    };
    insertAll(knownExpandingDims, op.known_expanding_dimensions());
    insertAll(knownNonexpandingDims, op.known_nonexpanding_dimensions());

    // Fail pattern application if there is nothing new to annotate.
    auto isEqual = [](llvm::SmallSetVector<int64_t, 4> &set,
                      DenseIntElementsAttr attr) {
      return static_cast<int64_t>(set.size()) == attr.size() &&
             llvm::all_of(attr, [&](auto it) {
               return set.count(it.getLimitedValue());
             });
    };
    if (op.known_expanding_dimensions() && op.known_nonexpanding_dimensions() &&
        isEqual(knownExpandingDims, *op.known_expanding_dimensions()) &&
        isEqual(knownNonexpandingDims, *op.known_nonexpanding_dimensions())) {
      return failure();
    }

    // Annotate op in place.
    rewriter.startRootUpdate(op);
    op.known_expanding_dimensionsAttr(
        rewriter.getI64TensorAttr(knownExpandingDims.takeVector()));
    op.known_nonexpanding_dimensionsAttr(
        rewriter.getI64TensorAttr(knownNonexpandingDims.takeVector()));
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

// Remove `mhlo.compute_reshape_shape` if we can prove that the dynamic shape
// does not contain a `-1` dimension.
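//
// In that case the op is the identity on its dynamic shape operand. A sketch
// (illustrative, assuming the analysis proves that no dimension of %shape can
// be -1):
//
//   %0 = mhlo.compute_reshape_shape %n, %shape
//       : (index, tensor<2xindex>) -> tensor<2xindex>
//
// folds to %shape.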
struct RemoveComputeReshapeShape final
    : public OpRewritePattern<mhlo::ComputeReshapeShapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::ComputeReshapeShapeOp op,
                                PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis shapeComponentAnalysis;
    auto dynamicShape = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
    if (!dynamicShape) return failure();

    if (llvm::any_of(*dynamicShape, [](const auto &dim) {
          return !dim.isKnownNotNegativeOne();
        })) {
      return failure();
    }
    rewriter.replaceOp(op, op.dynamic_shape());
    return success();
  }
};

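// Returns true if `expr` is a product of constant and symbolic factors,
// invoking `cbkConstantFactor` and `cbkSymbolicFactor` on each factor,
// respectively. For example, the affine expression s0 * 2 * s1 decomposes
// into the constant factor 2 and the symbolic factors s0 and s1.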
bool isProduct(AffineExpr expr,
               llvm::function_ref<void(AffineConstantExpr)> cbkConstantFactor,
               llvm::function_ref<void(AffineSymbolExpr)> cbkSymbolicFactor) {
  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
  if (binExpr && binExpr.getKind() == AffineExprKind::Mul) {
    return isProduct(binExpr.getLHS(), cbkConstantFactor, cbkSymbolicFactor) &&
           isProduct(binExpr.getRHS(), cbkConstantFactor, cbkSymbolicFactor);
  }
  if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
    cbkSymbolicFactor(symExpr);
    return true;
  }
  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
    cbkConstantFactor(constExpr);
    return true;
  }
  return false;
}

bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
                       llvm::function_ref<void(int64_t)> cbkConstantFactor,
                       llvm::function_ref<void(Symbol)> cbkSymbolicFactor) {
  return isProduct(
      symbolicExpr.expr,
      [&](AffineConstantExpr cexpr) { cbkConstantFactor(cexpr.getValue()); },
      [&](AffineSymbolExpr sexpr) {
        cbkSymbolicFactor(symbolicExpr.symbols[sexpr.getPosition()]);
      });
}

// Represents a product of symbolic and concrete factors. This allows us to
// prove product equalities symbolically.
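// For example, the symbolic expression s0 * 4 * s1 * 2 is represented as the
// concrete factor 8 together with the symbolic factors [s0, s1].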
struct SymbolicProduct {
  // Product of all concrete factors.
  int64_t concrete = 1;
  // List of all symbolic factors, as they cannot be aggregated.
  llvm::SmallVector<Symbol> symbolic;
  bool empty() { return concrete == 1 && symbolic.empty(); }
};

bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
                       SymbolicProduct *product) {
  return isSymbolicProduct(
      symbolicExpr, [&](int64_t c) { product->concrete *= c; },
      [&](Symbol s) { product->symbolic.push_back(s); });
}

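// Replace `mhlo.cstr_reshapable` with a constant witness if we can decide
// symbolically whether the dynamic shape is compatible with the given number
// of elements.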
struct RemoveRedundantCstrReshapable final
    : public OpRewritePattern<mhlo::CstrReshapableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op,
                                PatternRewriter &rewriter) const override {
    // Get shape analysis info for the number of elements.
    ShapeComponentAnalysis shapeComponentAnalysis;
    auto numElementsInfo =
        shapeComponentAnalysis.GetValueInfo(op.num_elements());
    if (!numElementsInfo) return failure();
    assert(numElementsInfo->size() == 1 && "expect one value for a scalar");
    auto numElements = numElementsInfo->front();

    // Get shape analysis info for the dynamic shape.
    auto dynShapeDims = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
    if (!dynShapeDims) return failure();

    // We can handle two cases:
    // - there is exactly one -1 in the dynamic shape, i.e. a unique wildcard
    //   dimension, or
    // - there is no -1 in the dynamic shape, i.e. no wildcard dimension.
    bool uniqueWildcardDimension = false;
    for (const auto &d : *dynShapeDims) {
      if (d.isConstant(-1)) {
        if (uniqueWildcardDimension) return failure();
        uniqueWildcardDimension = true;
      } else if (!d.isKnownNotNegativeOne()) {
        return failure();
      }
    }

    // We can only handle simple products with constants and symbols. Find all
    // the factors based on the number of elements.
    SymbolicProduct numElementsRemainingFactors;
    if (!isSymbolicProduct(numElements, &numElementsRemainingFactors)) {
      return failure();
    }
    assert(numElementsRemainingFactors.concrete >= 1 &&
           "number of elements cannot entail negative or zero factors");

    // Find all factors based on the dynamic shape.
    // - Accumulate the concrete product to later compare it against its
    //   equivalent based on the number of elements.
    // - Remove symbolic factors from the list and fail if we find an unknown
    //   factor, i.e. if the symbolic factors based on the dynamic shape are
    //   not a subset of the factors based on the number of elements.
    int64_t concreteProductDynShape = 1;
    for (const auto &dim : *dynShapeDims) {
      SmallVector<Symbol> partialSymbolicFactorsDynShape;
      if (!isSymbolicProduct(
              dim,
              [&](int64_t c) {
                if (c != ShapedType::kDynamicSize) concreteProductDynShape *= c;
              },
              [&](Symbol s) { partialSymbolicFactorsDynShape.push_back(s); })) {
        return failure();
      }
      for (const Symbol &symDynShape : partialSymbolicFactorsDynShape) {
        auto *it =
            llvm::find(numElementsRemainingFactors.symbolic, symDynShape);
        if (it == numElementsRemainingFactors.symbolic.end()) return failure();
        numElementsRemainingFactors.symbolic.erase(it);
      }
    }
    assert(concreteProductDynShape >= 1 &&
           "concrete product must not aggregate negative or zero factors");

    // A wildcard dimension can subsume the remaining symbolic factors and
    // potentially also a concrete factor.
    if (uniqueWildcardDimension) {
      if (numElementsRemainingFactors.concrete % concreteProductDynShape != 0)
        return failure();
      rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
      return success();
    }

    // Without a wildcard, the symbolic and concrete products must be equal.
    bool isReshapable =
        numElementsRemainingFactors.symbolic.empty() &&
        numElementsRemainingFactors.concrete == concreteProductDynShape;
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, isReshapable);
    return success();
  }
};

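// Materialize a reshape of a scalar operand as a `tensor.expand_shape` op
// into all-unit dimensions, followed by a `tensor.cast` to the expected
// result type where needed.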
LogicalResult materializeReshapeAsScalarExpand(RankedTensorType operandTy,
                                               RankedTensorType resultTy,
                                               mhlo::DynamicReshapeOp op,
                                               PatternRewriter &rewriter) {
  assert(operandTy.getRank() == 0 && "expect scalar operand");
  auto loc = op.getLoc();
  SmallVector<int64_t> unitDims(resultTy.getRank(), 1);
  auto expandedTy = RankedTensorType::get(unitDims, resultTy.getElementType());
  Value expandedScalar = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandedTy, op.operand(), ArrayRef<ReassociationIndices>{});
  if (expandedScalar.getType() != resultTy) {
    expandedScalar =
        rewriter.create<tensor::CastOp>(loc, resultTy, expandedScalar);
  }
  rewriter.replaceOp(op, expandedScalar);
  return success();
}

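// Materialize a reshape to a scalar result as a `tensor.collapse_shape` op of
// all-unit dimensions, preceded by a `tensor.cast` of the operand where
// needed.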
LogicalResult materializeReshapeAsScalarCollapse(RankedTensorType operandTy,
                                                 RankedTensorType resultTy,
                                                 mhlo::DynamicReshapeOp op,
                                                 PatternRewriter &rewriter) {
  assert(resultTy.getRank() == 0 && "expect scalar result");
  auto loc = op.getLoc();
  Value operand = op.operand();
  SmallVector<int64_t> unitDims(operandTy.getRank(), 1);
  auto castedOperandTy =
      RankedTensorType::get(unitDims, operandTy.getElementType());
  if (operand.getType() != castedOperandTy) {
    operand = rewriter.create<tensor::CastOp>(loc, castedOperandTy, operand);
  }
  Value collapsedScalar = rewriter.create<tensor::CollapseShapeOp>(
      loc, operand, ArrayRef<ReassociationIndices>{});
  rewriter.replaceOp(op, collapsedScalar);
  return success();
}

enum class DimensionGroupKind {
  kNone,
  kExpanding,
  kCollapsing,
};

struct DimensionGroup {
  int64_t size = 0;
  DimensionGroupKind kind = DimensionGroupKind::kNone;
};

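// Eliminate the common factors of the products `a` and `b` in place and
// return them as their greatest common divisor. For example, for
// a = 6 * s0 * s1 and b = 4 * s1, the gcd is 2 * s1, leaving behind
// a = 3 * s0 and b = 2.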
SymbolicProduct eliminateCommonFactors(SymbolicProduct &a, SymbolicProduct &b) {
  SymbolicProduct gcd;

  // Eliminate common concrete factors.
  gcd.concrete = llvm::GreatestCommonDivisor64(a.concrete, b.concrete);
  a.concrete /= gcd.concrete;
  b.concrete /= gcd.concrete;

  // Eliminate common symbolic factors.
  int64_t i = 0;
  while (i < static_cast<int64_t>(a.symbolic.size())) {
    auto *it = llvm::find(b.symbolic, a.symbolic[i]);
    if (it != b.symbolic.end()) {
      gcd.symbolic.push_back(*it);
      std::swap(a.symbolic[i], a.symbolic.back());
      a.symbolic.pop_back();
      b.symbolic.erase(it);
    } else {
      i++;
    }
  }

  return gcd;
}

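// Returns true if the dimension at `it` is a unit dimension that has no
// corresponding unit dimension at `otherIt`, i.e. it can be absorbed into the
// current dimension group without consuming a dimension on the other side.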
bool isUnpairedUnitDimension(
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator it,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator end,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherIt,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherEnd) {
  return it != end && it->isConstant(1) &&
         (otherIt == otherEnd || !otherIt->isConstant(1));
}

// Returns the dimension size for a shaped type: the concrete size if the
// product has no symbolic factors, and ShapedType::kDynamicSize otherwise.
int64_t getShapedTypeDimSize(const SymbolicProduct &symProduct) {
  return symProduct.symbolic.empty() ? symProduct.concrete
                                     : ShapedType::kDynamicSize;
}

// Iterate over the operand's and the result's shape dimensions and find
// dimension groups that are collapsing, expanding, or untouched:
// - Collapsing: Multiple dimensions of the operand shape can be collapsed
//   into a single dimension of the result shape. We must prove that the
//   product of the operand shape's dimensions is equal to the corresponding
//   result dimension.
// - Expanding: A single dimension of the operand shape can be expanded into
//   multiple dimensions of the result shape. We must prove that the product
//   of the result shape's dimensions is equal to the corresponding operand
//   dimension. This case is limited to at most one dynamic dimension per
//   expansion group, as more are not supported by the `expand_shape` op.
// - Untouched: There is a 1:1 correspondence between an operand and a result
//   shape dimension.
//
// We can determine the optimal dimension groups greedily by consuming operand
// and result dimensions from left to right. If the leading operand dimension
// is a strict divisor of the leading result dimension, collapsing is required.
// In this case, we keep consuming the operand dimensions until the products
// are equal. If the leading result dimension is a strict divisor of the
// leading operand dimension, expanding is required. In this case, we keep
// consuming the result dimensions until the products are equal. Trailing unit
// dimensions may be included in the dimension group. This is useful iff they
// are "unpaired", in which case they would only limit us in the subsequent
// iteration.
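//
// For example (illustrative), for the operand shape [s0, 4, 2] and the result
// shape [s0, 8], the algorithm finds an untouched group for the leading
// dimension (s0 maps 1:1) and a collapsing group of size 2 for the remaining
// dimensions, proving 4 * 2 == 8.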
//
LogicalResult findExpandingAndCollapsingDimensionGroups(
    ArrayRef<SymbolicExpr> operandShapeInfo,
    ArrayRef<SymbolicExpr> resultShapeInfo,
    SmallVector<DimensionGroup> *dimensionGroups,
    SmallVector<int64_t> *expandedIntermShape) {
  const auto *operandShapeIt = operandShapeInfo.begin();
  const auto *operandShapeEnd = operandShapeInfo.end();
  const auto *resultShapeIt = resultShapeInfo.begin();
  const auto *resultShapeEnd = resultShapeInfo.end();

  // Crucial iteration state.
  SymbolicProduct remainingOperandShapeFactors;
  SymbolicProduct remainingResultShapeFactors;
  auto anyRemainingFactors = [&]() {
    return !remainingOperandShapeFactors.empty() ||
           !remainingResultShapeFactors.empty();
  };

  while (operandShapeIt != operandShapeEnd && resultShapeIt != resultShapeEnd) {
    assert(!anyRemainingFactors() &&
           "expect no remaining factors from previous iteration");
    DimensionGroup &dimGroup = dimensionGroups->emplace_back();

    // Consume at least one operand and result dimension.
    {
      if (!isSymbolicProduct(*operandShapeIt++,
                             &remainingOperandShapeFactors) ||
          !isSymbolicProduct(*resultShapeIt++, &remainingResultShapeFactors)) {
        return failure();
      }
      dimGroup.size++;
      SymbolicProduct gcd = eliminateCommonFactors(
          remainingOperandShapeFactors, remainingResultShapeFactors);
      expandedIntermShape->push_back(getShapedTypeDimSize(gcd));
    }

    // Fail if there are unresolvable, contradicting factors remaining.
    if (!remainingOperandShapeFactors.empty() &&
        !remainingResultShapeFactors.empty()) {
      return failure();
    }

    // Collapsing: Create a collapsing dimension group.
    bool requiresCollapsing =
        remainingOperandShapeFactors.empty() &&
        (!remainingResultShapeFactors.empty() ||
         isUnpairedUnitDimension(operandShapeIt, operandShapeEnd, resultShapeIt,
                                 resultShapeEnd));
    if (requiresCollapsing) {
      dimGroup.kind = DimensionGroupKind::kCollapsing;

      // Consume operand shape dimensions until their product matches the
      // corresponding result dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (operandShapeIt != operandShapeEnd &&
             remainingOperandShapeFactors.empty() &&
             !remainingResultShapeFactors.empty()) {
        if (!isSymbolicProduct(*operandShapeIt++,
                               &remainingOperandShapeFactors)) {
          return failure();
        }
        dimGroup.size++;
        SymbolicProduct gcd = eliminateCommonFactors(
            remainingOperandShapeFactors, remainingResultShapeFactors);
        expandedIntermShape->push_back(getShapedTypeDimSize(gcd));
      }
      if (anyRemainingFactors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (isUnpairedUnitDimension(operandShapeIt, operandShapeEnd,
                                     resultShapeIt, resultShapeEnd)) {
        operandShapeIt++;
        dimGroup.size++;
        expandedIntermShape->push_back(1);
      }

      continue;
    }

    // Expanding: Create an expanding dimension group.
    bool requiresExpanding =
        remainingResultShapeFactors.empty() &&
        (!remainingOperandShapeFactors.empty() ||
         isUnpairedUnitDimension(resultShapeIt, resultShapeEnd, operandShapeIt,
                                 operandShapeEnd));
    if (requiresExpanding) {
      dimGroup.kind = DimensionGroupKind::kExpanding;
      int64_t numDynamicDims = 0;

      // Consume result shape dimensions until their product matches the
      // corresponding operand dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (resultShapeIt != resultShapeEnd &&
             remainingResultShapeFactors.empty() &&
             !remainingOperandShapeFactors.empty()) {
        if (!isSymbolicProduct(*resultShapeIt++,
                               &remainingResultShapeFactors)) {
          return failure();
        }
        dimGroup.size++;
        SymbolicProduct gcd = eliminateCommonFactors(
            remainingOperandShapeFactors, remainingResultShapeFactors);
        int64_t tyDimSize = getShapedTypeDimSize(gcd);

        // Allow no more than one dynamic dimension per expansion group.
        if (tyDimSize == ShapedType::kDynamicSize) {
          numDynamicDims++;
          if (numDynamicDims > 1) return failure();
        }
        expandedIntermShape->push_back(tyDimSize);
      }
      if (anyRemainingFactors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (isUnpairedUnitDimension(resultShapeIt, resultShapeEnd,
                                     operandShapeIt, operandShapeEnd)) {
        resultShapeIt++;
        dimGroup.size++;
        expandedIntermShape->push_back(1);
      }

      continue;
    }

    // Untouched: 1:1 mapping between operand and result shape dimension. This
    // is neither expanding nor collapsing.
    assert(!requiresCollapsing && !requiresExpanding && "expect id case");
    assert(dimGroup.size == 1 && dimGroup.kind == DimensionGroupKind::kNone &&
           "expect simple dimension group");
  }

  // Fail if there are remaining dimensions that could not be consumed.
  assert(!anyRemainingFactors() && "expect no remaining factors");
  if (operandShapeIt != operandShapeEnd || resultShapeIt != resultShapeEnd) {
    return failure();
  }

  return success();
}

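// Concretize the operand shape with dimension sizes that the symbolic shape
// analysis proved to be constant.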
SmallVector<int64_t> concretizeOperandShape(
    ArrayRef<int64_t> operandShape, ArrayRef<SymbolicExpr> operandShapeInfo) {
  SmallVector<int64_t> result;
  for (auto it : llvm::zip(operandShape, operandShapeInfo)) {
    auto dimSize = std::get<0>(it);
    auto sExpr = std::get<1>(it);
    if (auto cexpr = sExpr.expr.dyn_cast<AffineConstantExpr>()) {
      int64_t alsoDimSize = cexpr.getValue();
      assert((ShapedType::isDynamic(dimSize) || dimSize == alsoDimSize) &&
             "expect shape analysis result to be compatible with type");
      result.push_back(alsoDimSize);
      continue;
    }
    result.push_back(dimSize);
  }
  return result;
}

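// Returns the reassociation indices for an `expand_shape` or `collapse_shape`
// op if any dimension group of the given kind is present, and llvm::None
// otherwise.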
llvm::Optional<SmallVector<ReassociationIndices>> requiresReassociationOfKind(
    DimensionGroupKind kind, const SmallVector<DimensionGroup> &dimGroups) {
  SmallVector<ReassociationIndices> reassociation;
  reassociation.reserve(dimGroups.size());
  bool isStrictlyReassociating = false;
  int64_t i = 0;
  for (const DimensionGroup &g : dimGroups) {
    if (g.kind == kind) {
      isStrictlyReassociating = true;
      reassociation.push_back(
          llvm::to_vector(llvm::seq<int64_t>(i, i + g.size)));
      i += g.size;
      continue;
    }
    for (int64_t j = 0; j < g.size; j++) reassociation.push_back({i++});
  }

  // Return the reassociation only if reassociation of the given kind is
  // strictly required.
  if (isStrictlyReassociating) return reassociation;
  return llvm::None;
}

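// Materialize a `mhlo.dynamic_reshape` op as a sequence of `tensor.cast`,
// `tensor.expand_shape`, and `tensor.collapse_shape` ops, based on the
// dimension groups derived from the symbolic shape analysis.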
LogicalResult materializeReshapeAsExpandAndCollapse(
    ShapeComponentAnalysis &shapeAnalysis, RankedTensorType operandTy,
    RankedTensorType resultTy, mhlo::DynamicReshapeOp op,
    PatternRewriter &rewriter) {
  // Require successful shape analysis for the operand and result shape.
  auto operandShapeInfo = shapeAnalysis.GetShapeInfo(op.operand());
  if (!operandShapeInfo) return failure();
  auto resultShapeInfo = shapeAnalysis.GetValueInfo(op.output_shape());
  if (!resultShapeInfo) return failure();

  // Identify dimension groups and the intermediate expanded type.
  SmallVector<DimensionGroup> dimensionGroups;
  SmallVector<int64_t> expandedIntermShape;
  if (failed(findExpandingAndCollapsingDimensionGroups(
          *operandShapeInfo, *resultShapeInfo, &dimensionGroups,
          &expandedIntermShape))) {
    return failure();
  }

  // Materialize cast, expand, collapse, and cast, as needed.
  auto loc = op.getLoc();
  Value interm = op.operand();
  auto castedOperandTy = RankedTensorType::get(
      concretizeOperandShape(operandTy.getShape(), *operandShapeInfo),
      operandTy.getElementType());
  if (operandTy != castedOperandTy) {
    interm = rewriter.create<tensor::CastOp>(loc, castedOperandTy, interm);
  }
  if (auto reassociation = requiresReassociationOfKind(
          DimensionGroupKind::kExpanding, dimensionGroups)) {
    interm = rewriter.create<tensor::ExpandShapeOp>(
        loc,
        RankedTensorType::get(expandedIntermShape, operandTy.getElementType()),
        interm, *reassociation);
  }
  if (auto reassociation = requiresReassociationOfKind(
          DimensionGroupKind::kCollapsing, dimensionGroups)) {
    interm =
        rewriter.create<tensor::CollapseShapeOp>(loc, interm, *reassociation);
  }
  if (interm.getType() != resultTy) {
    interm = rewriter.create<tensor::CastOp>(loc, resultTy, interm);
  }
  rewriter.replaceOp(op, interm);
  return success();
}

// Tries to express `mhlo.dynamic_reshape` ops through `tensor.expand_shape`
// and `tensor.collapse_shape` ops.
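//
// For example (illustrative), a reshape from tensor<?x4x2xf32> to
// tensor<?x8xf32> whose shapes the symbolic analysis proves compatible
// becomes a single `tensor.collapse_shape` with the reassociation
// [[0], [1, 2]].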
struct DynamicReshapeToExpandAndCollapseShape final
    : public OpRewritePattern<mhlo::DynamicReshapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    auto operandTy = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!operandTy) return failure();
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy) return failure();

    // Handle the degenerate scalar expand case.
    if (operandTy.getRank() == 0) {
      return materializeReshapeAsScalarExpand(operandTy, resultTy, op,
                                              rewriter);
    }

    // Handle the degenerate scalar collapse case.
    if (resultTy.getRank() == 0) {
      return materializeReshapeAsScalarCollapse(operandTy, resultTy, op,
                                                rewriter);
    }

    ShapeComponentAnalysis shapeAnalysis;
    return materializeReshapeAsExpandAndCollapse(shapeAnalysis, operandTy,
                                                 resultTy, op, rewriter);
  }
};

// Returns true if all of `bcastedShapes` can be broadcasted with
// `outputShape`.
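// For example (illustrative), the shapes [s0, 1] and [1, s1] are both known
// to be broadcastable with the output shape [s0, s1].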
bool isKnownBroadcastable(ShapeComponentAnalysis &analysis,
                          ValueRange bcastedShapes, Value outputShape) {
  auto outputShapeDims = analysis.GetValueInfo(outputShape);
  if (!outputShapeDims) return false;
  for (Value shape : bcastedShapes) {
    auto shapeDims = analysis.GetValueInfo(shape);
    if (!shapeDims) return false;
    // Iterate backwards over the smallest input shape.
    for (auto zip : llvm::zip(llvm::reverse(*outputShapeDims),
                              llvm::reverse(*shapeDims))) {
      const auto &first = std::get<0>(zip);
      const auto &second = std::get<1>(zip);
      // TODO(ezhulenev): What to do with dimensions statically known to be
      // zero?
      // NumPy can only broadcast [0] with [1]; TensorFlow, however, can
      // broadcast [0] with any dimension size and produces a dimension of
      // size [0]. Currently we conservatively return failure and do not
      // proceed with a rewrite.
      if (first.isConstant(0) || second.isConstant(0)) return false;
      // If either shape has a static unit dimension, the broadcast will
      // always succeed.
      if (first.isConstant(1) || second.isConstant(1)) continue;
      // Otherwise, the dimensions have to be equal.
      if (first != second) return false;
    }
  }
  return true;
}

// Rewrite `shape.cstr_broadcastable` with a constant witness if we can prove
// from the symbolic analysis that the shapes are broadcastable.
struct CstrBroadcastableOpLowering
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis shapeComponentAnalysis;
    if (!isKnownBroadcastable(shapeComponentAnalysis, op.getShapes(),
                              op.getShapes().front())) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};

class SymbolicShapeOptimizationPass final
    : public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);

    // clang-format off
    patterns.insert<
        AnnotateExpandingDimensionsInDynamicBroadcastInDim,
        CstrBroadcastableOpLowering,
        DynamicReshapeToExpandAndCollapseShape,
        RemoveComputeReshapeShape,
        RemoveRedundantCstrReshapable,
        SimplifyBroadcasts>(ctx);
    // clang-format on
    shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);

    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                  std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // end namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createSymbolicShapeOptimizationPass() {
  return std::make_unique<SymbolicShapeOptimizationPass>();
}

}  // end namespace mlir