/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Analysis/shape_component_analysis.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo;
using Symbol = ShapeComponentAnalysis::Symbol;
using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr;

namespace {

// Temporary data structure to hold a single dimension of the symbolic result
// of `shape.broadcast`.
struct SymbolicBroadcastDimension {
  size_t operandIndex;
  size_t operandDim;
  SymbolicExpr expr;
};

// Replace shape.broadcast with a shape if it's statically known.
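//
// Illustrative example (a sketch, not taken from an actual test case): if
// symbolic analysis proves that both operand shapes are [s0, 4], the
// broadcast
//   %0 = shape.broadcast %a, %b : tensor<2xindex>, tensor<2xindex>
//        -> tensor<2xindex>
// is materialized directly from its known components, e.g.
//   %c0 = arith.constant 0 : index
//   %d0 = tensor.extract %a[%c0] : tensor<2xindex>
//   %c4 = arith.constant 4 : index
//   %0  = tensor.from_elements %d0, %c4 : tensor<2xindex>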
struct SimplifyBroadcasts : public mlir::OpRewritePattern<shape::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      shape::BroadcastOp op, mlir::PatternRewriter &rewriter) const override {
    // Require successful shape analysis.
    ShapeComponentAnalysis shapeAnalysis;
    llvm::SmallVector<ArrayRef<SymbolicExpr>> shapesInfo;
    auto shapes = op.getShapes();
    shapesInfo.reserve(shapes.size());
    for (Value s : shapes) {
      auto sInfo = shapeAnalysis.GetValueInfo(s);
      if (!sInfo) return failure();
      shapesInfo.push_back(*sInfo);
    }

    // Find the result rank.
    size_t rank = 0;
    for (const auto &sInfo : shapesInfo) rank = std::max(rank, sInfo.size());

    // Compute broadcast symbolically.
    SmallVector<Optional<SymbolicBroadcastDimension>> symResult(rank,
                                                                llvm::None);
    for (const auto &sInfo : llvm::enumerate(shapesInfo)) {
      size_t dimOffset = rank - sInfo.value().size();
      for (const auto &symExpr : llvm::enumerate(sInfo.value())) {
        // Unit dimensions are neutral to the final result.
        if (symExpr.value().isConstant(1)) continue;

        // Use unique expression.
        size_t i = dimOffset + symExpr.index();
        if (!symResult[i]) {
          symResult[i] = {sInfo.index(), symExpr.index(), symExpr.value()};
          continue;
        }

        // Bail if the dimensions are neither equal nor 1.
        if (symResult[i]->expr != symExpr.value()) return failure();
      }
    }

    // Materialize broadcast result.
    auto loc = op.getLoc();
    DenseMap<int64_t, Value> constants;
    auto findOrCreateConstant = [&](int64_t c) {
      auto it = constants.find(c);
      if (it != constants.end()) return it->second;
      Value newlyCreated = rewriter.create<arith::ConstantIndexOp>(loc, c);
      constants[c] = newlyCreated;
      return newlyCreated;
    };
    auto elements = llvm::to_vector<8>(
        llvm::map_range(symResult, [&](const auto &symResultDim) {
          // If we know the dimension statically, use a constant.
          if (!symResultDim) return findOrCreateConstant(1);
          if (auto cexpr = symResultDim->expr.expr
                               .template dyn_cast<AffineConstantExpr>()) {
            return findOrCreateConstant(cexpr.getValue());
          }

          // Otherwise, extract the dimension from the unique operand.
          Value operand = shapes[symResultDim->operandIndex];
          Value operandDim = findOrCreateConstant(symResultDim->operandDim);
          return rewriter.create<tensor::ExtractOp>(loc, operand, operandDim)
              .getResult();
        }));
    Type indexTy = rewriter.getIndexType();
    Type concreteResultTy =
        RankedTensorType::get({static_cast<int64_t>(elements.size())}, indexTy);
    Value result = rewriter.create<tensor::FromElementsOp>(
        loc, concreteResultTy, elements);

    // Insert cast, if needed.
    Type expectedTy = op.getResult().getType();
    if (result.getType() != expectedTy) {
      result = rewriter.create<tensor::CastOp>(loc, expectedTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

LogicalResult analyzeDynamicBroadcastInDimExpandingBehavior(
    ShapeComponentAnalysis &analysis, Value value, Value shape,
    llvm::SmallSetVector<int64_t, 4> *knownExpandingDims,
    llvm::SmallSetVector<int64_t, 4> *knownNonexpandingDims) {
  // Require successful analysis of shapes.
  auto shapeIn = analysis.GetShapeInfo(value);
  auto shapeOut = analysis.GetValueInfo(shape);
  if (!shapeIn || !shapeOut) return failure();

  // Analyze per argument dimension.
  size_t rankIn = shapeIn->size();
  size_t rankOut = shapeOut->size();
  assert(rankIn <= rankOut);
  size_t dimOutOffset = rankOut - rankIn;
  for (size_t i = 0; i < rankIn; ++i) {
    SymbolicExpr dimIn = (*shapeIn)[i];
    SymbolicExpr dimOut = (*shapeOut)[dimOutOffset + i];
    if (dimIn.isConstant(1) && dimOut.isKnownNotOne())
      knownExpandingDims->insert(i);
    if (dimIn == dimOut || dimOut.isConstant(1))
      knownNonexpandingDims->insert(i);
  }
  return success();
}

// Analyze `mhlo.dynamic_broadcast_in_dim` op and populate attributes for
// statically known expanding and non-expanding dimensions.
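//
// For example (an illustrative sketch): broadcasting an operand of symbolic
// shape [1, s0] to an output shape [s1, s0], where s1 is known not to be 1,
// allows annotating dimension 0 as known-expanding and dimension 1 as
// known-nonexpanding:
//   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) {
//     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>,
//     known_expanding_dimensions = dense<[0]> : tensor<1xi64>,
//     known_nonexpanding_dimensions = dense<[1]> : tensor<1xi64>
//   } : (tensor<1x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>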
struct AnnotateExpandingDimensionsInDynamicBroadcastInDim
    : public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op,
      mlir::PatternRewriter &rewriter) const override {
    // Analyze shapes and identify expanding and non-expanding dims.
    ShapeComponentAnalysis analysis;
    llvm::SmallSetVector<int64_t, 4> knownExpandingDims, knownNonexpandingDims;
    if (failed(analyzeDynamicBroadcastInDimExpandingBehavior(
            analysis, op.operand(), op.output_dimensions(), &knownExpandingDims,
            &knownNonexpandingDims))) {
      return failure();
    }

    // Collect possibly already annotated info.
    auto insertAll = [](llvm::SmallSetVector<int64_t, 4> &dst,
                        Optional<DenseIntElementsAttr> src) {
      if (!src) return;
      for (auto it : *src) dst.insert(it.getLimitedValue());
    };
    insertAll(knownExpandingDims, op.known_expanding_dimensions());
    insertAll(knownNonexpandingDims, op.known_nonexpanding_dimensions());

    // Fail pattern application if there is nothing new to annotate.
    auto isEqual = [](llvm::SmallSetVector<int64_t, 4> &set,
                      DenseIntElementsAttr attr) {
      return static_cast<int64_t>(set.size()) == attr.size() &&
             llvm::all_of(attr, [&](auto it) {
               return set.count(it.getLimitedValue());
             });
    };
    if (op.known_expanding_dimensions() && op.known_nonexpanding_dimensions() &&
        isEqual(knownExpandingDims, *op.known_expanding_dimensions()) &&
        isEqual(knownNonexpandingDims, *op.known_nonexpanding_dimensions())) {
      return failure();
    }

    // Annotate op in place.
    rewriter.startRootUpdate(op);
    op.known_expanding_dimensionsAttr(
        rewriter.getI64TensorAttr(knownExpandingDims.takeVector()));
    op.known_nonexpanding_dimensionsAttr(
        rewriter.getI64TensorAttr(knownNonexpandingDims.takeVector()));
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

// Remove compute_reshape_shape if we can prove that the dynamic shape does not
// contain a `-1` dimension.
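//
// A minimal sketch of the intent: `mhlo.compute_reshape_shape` resolves a
// possible -1 wildcard in a requested shape. If every dimension of the
// dynamic shape is known not to be -1, the op is the identity, so
//   %0 = mhlo.compute_reshape_shape %num_elements, %dynamic_shape
//        : (index, tensor<2xindex>) -> tensor<2xindex>
// can be replaced with %dynamic_shape directly.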
struct RemoveComputeReshapeShape final
    : public OpRewritePattern<mhlo::ComputeReshapeShapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::ComputeReshapeShapeOp op,
                                PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis shapeComponentAnalysis;
    auto dynamicShape = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
    if (!dynamicShape) return failure();

    if (llvm::any_of(*dynamicShape, [](const auto &dim) {
          return !dim.isKnownNotNegativeOne();
        })) {
      return failure();
    }
    rewriter.replaceOp(op, op.dynamic_shape());
    return success();
  }
};

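// Recursively decompose an affine expression into its multiplicative factors,
// reporting constant factors and symbolic factors through the given
// callbacks. Returns false if the expression is not a pure product of
// constants and symbols. For instance (an illustrative sketch), `s0 * 4 * s1`
// decomposes into the constant factor 4 and the symbolic factors s0 and s1,
// whereas `s0 + 4` is rejected.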
bool isProduct(AffineExpr expr,
               llvm::function_ref<void(AffineConstantExpr)> cbkConstantFactor,
               llvm::function_ref<void(AffineSymbolExpr)> cbkSymbolicFactor) {
  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
  if (binExpr && binExpr.getKind() == AffineExprKind::Mul) {
    return isProduct(binExpr.getLHS(), cbkConstantFactor, cbkSymbolicFactor) &&
           isProduct(binExpr.getRHS(), cbkConstantFactor, cbkSymbolicFactor);
  }
  if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
    cbkSymbolicFactor(symExpr);
    return true;
  }
  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
    cbkConstantFactor(constExpr);
    return true;
  }
  return false;
}

bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
                       llvm::function_ref<void(int64_t)> cbkConstantFactor,
                       llvm::function_ref<void(Symbol)> cbkSymbolicFactor) {
  return isProduct(
      symbolicExpr.expr,
      [&](AffineConstantExpr cexpr) { cbkConstantFactor(cexpr.getValue()); },
      [&](AffineSymbolExpr sexpr) {
        cbkSymbolicFactor(symbolicExpr.symbols[sexpr.getPosition()]);
      });
}

// Represents a product of symbolic and concrete factors. This will allow us to
// prove product equalities symbolically.
struct SymbolicProduct {
  // Product of all concrete factors.
  int64_t concrete = 1;
  // List all symbolic factors as they cannot be aggregated.
  llvm::SmallVector<Symbol> symbolic;
  bool empty() { return concrete == 1 && symbolic.empty(); }
};

bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
                       SymbolicProduct *product) {
  return isSymbolicProduct(
      symbolicExpr, [&](int64_t c) { product->concrete *= c; },
      [&](Symbol s) { product->symbolic.push_back(s); });
}

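// Replace `mhlo.cstr_reshapable` with a constant witness if the constraint
// can be decided symbolically.
//
// A worked example (an illustrative sketch): for num_elements = s0 * s1 * 6
// and a requested dynamic shape [s0, 2, -1], the symbolic factor s0 is
// matched and removed, the concrete factor 2 divides the remaining concrete
// factor 6, and the leftover s1 * 3 is absorbed by the unique -1 wildcard,
// so the op is replaced by a true witness.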
struct RemoveRedundantCstrReshapable final
    : public OpRewritePattern<mhlo::CstrReshapableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op,
                                PatternRewriter &rewriter) const override {
    // Get shape analysis info for the number of elements.
    ShapeComponentAnalysis shapeComponentAnalysis;
    auto numElementsInfo =
        shapeComponentAnalysis.GetValueInfo(op.num_elements());
    if (!numElementsInfo) return failure();
    assert(numElementsInfo->size() == 1 && "expect one value for a scalar");
    auto numElements = numElementsInfo->front();

    // Get shape analysis info for the dynamic shape.
    auto dynShapeDims = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
    if (!dynShapeDims) return failure();

    // We can handle two cases:
    //   - there is exactly one -1 in the dynamic shape, i.e. a unique wildcard
    //     dimension, or
    //   - there is no -1 in the dynamic shape, i.e. no wildcard dimension.
    bool uniqueWildcardDimension = false;
    for (const auto &d : *dynShapeDims) {
      if (d.isConstant(-1)) {
        if (uniqueWildcardDimension) return failure();
        uniqueWildcardDimension = true;
      } else if (!d.isKnownNotNegativeOne()) {
        return failure();
      }
    }

    // We can only handle simple products with constants and symbols. Find all
    // the factors based on the number of elements.
    SymbolicProduct numElementsRemainingFactors;
    if (!isSymbolicProduct(numElements, &numElementsRemainingFactors)) {
      return failure();
    }
    assert(numElementsRemainingFactors.concrete >= 1 &&
           "number of elements cannot entail negative or zero factors");

    // Find all factors based on the dynamic shape.
    //   - Accumulate the concrete product to later compare it against its
    //     equivalent based on the number of elements.
    //   - Remove symbolic factors from the list and fail if we find an unknown
    //     factor, i.e. if the symbolic factors based on the dynamic shape are
    //     not a subset of the factors based on the number of elements.
    int64_t concreteProductDynShape = 1;
    for (const auto &dim : *dynShapeDims) {
      SmallVector<Symbol> partialSymbolicFactorsDynShape;
      if (!isSymbolicProduct(
              dim,
              [&](int64_t c) {
                if (c != ShapedType::kDynamicSize) concreteProductDynShape *= c;
              },
              [&](Symbol s) { partialSymbolicFactorsDynShape.push_back(s); })) {
        return failure();
      }
      for (const Symbol &symDynShape : partialSymbolicFactorsDynShape) {
        auto *it =
            llvm::find(numElementsRemainingFactors.symbolic, symDynShape);
        if (it == numElementsRemainingFactors.symbolic.end()) return failure();
        numElementsRemainingFactors.symbolic.erase(it);
      }
    }
    assert(concreteProductDynShape >= 1 &&
           "concrete product must not aggregate negative or zero factors");

    // A wildcard dimension can subsume the remaining symbolic factors and
    // potentially also a concrete factor.
    if (uniqueWildcardDimension) {
      if (numElementsRemainingFactors.concrete % concreteProductDynShape != 0)
        return failure();
      rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
      return success();
    }

    // Without a wildcard, the symbolic and concrete products must be equal.
    bool isReshapable =
        numElementsRemainingFactors.symbolic.empty() &&
        numElementsRemainingFactors.concrete == concreteProductDynShape;
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, isReshapable);
    return success();
  }
};

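// Materialize a reshape of a rank-0 operand as a pure expansion into unit
// dimensions, e.g. (an illustrative sketch) tensor<f32> -> tensor<1x1xf32>
// via `tensor.expand_shape` with an empty reassociation, followed by a
// `tensor.cast` if the result type has dynamic dimensions.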
LogicalResult materializeReshapeAsScalarExpand(RankedTensorType operandTy,
                                               RankedTensorType resultTy,
                                               mhlo::DynamicReshapeOp op,
                                               PatternRewriter &rewriter) {
  assert(operandTy.getRank() == 0 && "expect scalar operand");
  auto loc = op.getLoc();
  SmallVector<int64_t> unitDims(resultTy.getRank(), 1);
  auto expandedTy = RankedTensorType::get(unitDims, resultTy.getElementType());
  Value expandedScalar = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandedTy, op.operand(), ArrayRef<ReassociationIndices>{});
  if (expandedScalar.getType() != resultTy) {
    expandedScalar =
        rewriter.create<tensor::CastOp>(loc, resultTy, expandedScalar);
  }
  rewriter.replaceOp(op, expandedScalar);
  return success();
}

LogicalResult materializeReshapeAsScalarCollapse(RankedTensorType operandTy,
                                                 RankedTensorType resultTy,
                                                 mhlo::DynamicReshapeOp op,
                                                 PatternRewriter &rewriter) {
  assert(resultTy.getRank() == 0 && "expect scalar result");
  auto loc = op.getLoc();
  Value operand = op.operand();
  SmallVector<int64_t> unitDims(operandTy.getRank(), 1);
  auto castedOperandTy =
      RankedTensorType::get(unitDims, operandTy.getElementType());
  if (operand.getType() != castedOperandTy) {
    operand = rewriter.create<tensor::CastOp>(loc, castedOperandTy, operand);
  }
  Value collapsedScalar = rewriter.create<tensor::CollapseShapeOp>(
      loc, operand, ArrayRef<ReassociationIndices>{});
  rewriter.replaceOp(op, collapsedScalar);
  return success();
}

enum class DimensionGroupKind {
  kNone,
  kExpanding,
  kCollapsing,
};

struct DimensionGroup {
  int64_t size = 0;
  DimensionGroupKind kind = DimensionGroupKind::kNone;
};

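// Divide both products by their common factors and return those factors as
// the symbolic greatest common divisor. For example (an illustrative sketch),
// for a = 4 * s0 * s1 and b = 6 * s1, the gcd is 2 * s1, leaving a = 2 * s0
// and b = 3.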
SymbolicProduct eliminateCommonFactors(SymbolicProduct &a, SymbolicProduct &b) {
  SymbolicProduct gcd;

  // Eliminate common concrete factors.
  gcd.concrete = llvm::GreatestCommonDivisor64(a.concrete, b.concrete);
  a.concrete /= gcd.concrete;
  b.concrete /= gcd.concrete;

  // Eliminate common symbolic factors.
  int64_t i = 0;
  while (i < static_cast<int64_t>(a.symbolic.size())) {
    auto *it = llvm::find(b.symbolic, a.symbolic[i]);
    if (it != b.symbolic.end()) {
      gcd.symbolic.push_back(*it);
      std::swap(a.symbolic[i], a.symbolic.back());
      a.symbolic.pop_back();
      b.symbolic.erase(it);
    } else {
      i++;
    }
  }

  return gcd;
}

bool isUnpairedUnitDimension(
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator it,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator end,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherIt,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherEnd) {
  return it != end && it->isConstant(1) &&
         (otherIt == otherEnd || !otherIt->isConstant(1));
}

int64_t getShapedTypeDimSize(const SymbolicProduct &symProduct) {
  return symProduct.symbolic.empty() ? symProduct.concrete
                                     : ShapedType::kDynamicSize;
}

// Iterate over the operand's and the result's shape dimensions and find
// dimension groups that are collapsing, expanding, or untouched:
//   - Collapsing: Multiple dimensions of the operand shape can be collapsed
//     into a single dimension of the result shape. We must prove that the
//     product of the operand shape's dimensions is equal to the corresponding
//     result dimension.
//   - Expanding: A single dimension of the operand shape can be expanded into
//     multiple dimensions of the result shape. We must prove that the product
//     of the result shape's dimensions is equal to the corresponding operand
//     dimension. This case is limited to at most one dynamic dimension per
//     expansion group as otherwise not supported by the `expand_shape` op.
//   - Untouched: There is a 1:1 correspondence between an operand and a result
//     shape dimension.
//
// We can determine the optimal dimension groups greedily by consuming operand
// and result dimensions from left to right. If the leading operand dimension is
// a strict divisor of the leading result dimension, collapsing is required. In
// this case, we keep consuming the operand dimensions until the products are
// equal. If the leading result dimension is a strict divisor of the leading
// operand dimension, expanding is required. In this case, we keep consuming the
// result dimensions until the products are equal. Trailing unit dimensions may
// be included in the dimension group. This is useful iff they are "unpaired",
// in which case they would only limit us in the subsequent iteration.
//
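// A worked example (an illustrative sketch): for an operand shape
// [s0, 6, s1] and a result shape [s0, 2, 3, s1], the first group is untouched
// (s0 <-> s0), the second group expands 6 into 2 x 3, and the third group is
// untouched (s1 <-> s1), yielding the intermediate expanded shape
// [s0, 2, 3, s1].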
LogicalResult findExpandingAndCollapsingDimensionGroups(
    ArrayRef<SymbolicExpr> operandShapeInfo,
    ArrayRef<SymbolicExpr> resultShapeInfo,
    SmallVector<DimensionGroup> *dimensionGroups,
    SmallVector<int64_t> *expandedIntermShape) {
  const auto *operandShapeIt = operandShapeInfo.begin();
  const auto *operandShapeEnd = operandShapeInfo.end();
  const auto *resultShapeIt = resultShapeInfo.begin();
  const auto *resultShapeEnd = resultShapeInfo.end();

  // Crucial iteration state.
  SymbolicProduct remainingOperandShapeFactors;
  SymbolicProduct remainingResultShapeFactors;
  auto anyRemainingFactors = [&]() {
    return !remainingOperandShapeFactors.empty() ||
           !remainingResultShapeFactors.empty();
  };

  while (operandShapeIt != operandShapeEnd && resultShapeIt != resultShapeEnd) {
    assert(!anyRemainingFactors() &&
           "expect no remaining factors from previous iteration");
    DimensionGroup &dimGroup = dimensionGroups->emplace_back();

    // Consume at least one operand and result dimension.
    {
      if (!isSymbolicProduct(*operandShapeIt++,
                             &remainingOperandShapeFactors) ||
          !isSymbolicProduct(*resultShapeIt++, &remainingResultShapeFactors)) {
        return failure();
      }
      dimGroup.size++;
      SymbolicProduct gcd = eliminateCommonFactors(remainingOperandShapeFactors,
                                                   remainingResultShapeFactors);
      expandedIntermShape->push_back(getShapedTypeDimSize(gcd));
    }

    // Fail if there are unresolvable, contradicting factors remaining.
    if (!remainingOperandShapeFactors.empty() &&
        !remainingResultShapeFactors.empty()) {
      return failure();
    }

    // Collapsing: Create a collapsing dimension group.
    bool requiresCollapsing =
        remainingOperandShapeFactors.empty() &&
        (!remainingResultShapeFactors.empty() ||
         isUnpairedUnitDimension(operandShapeIt, operandShapeEnd, resultShapeIt,
                                 resultShapeEnd));
    if (requiresCollapsing) {
      dimGroup.kind = DimensionGroupKind::kCollapsing;

      // Consume operand shape dimensions until their product matches the
      // corresponding result dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (operandShapeIt != operandShapeEnd &&
             remainingOperandShapeFactors.empty() &&
             !remainingResultShapeFactors.empty()) {
        if (!isSymbolicProduct(*operandShapeIt++,
                               &remainingOperandShapeFactors)) {
          return failure();
        }
        dimGroup.size++;
        SymbolicProduct gcd = eliminateCommonFactors(
            remainingOperandShapeFactors, remainingResultShapeFactors);
        expandedIntermShape->push_back(getShapedTypeDimSize(gcd));
      }
      if (anyRemainingFactors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (isUnpairedUnitDimension(operandShapeIt, operandShapeEnd,
                                     resultShapeIt, resultShapeEnd)) {
        operandShapeIt++;
        dimGroup.size++;
        expandedIntermShape->push_back(1);
      }

      continue;
    }

    // Expanding: Create an expanding dimension group.
    bool requiresExpanding =
        remainingResultShapeFactors.empty() &&
        (!remainingOperandShapeFactors.empty() ||
         isUnpairedUnitDimension(resultShapeIt, resultShapeEnd, operandShapeIt,
                                 operandShapeEnd));
    if (requiresExpanding) {
      dimGroup.kind = DimensionGroupKind::kExpanding;
      int64_t numDynamicDims = 0;

      // Consume result shape dimensions until their product matches the
      // corresponding operand dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (resultShapeIt != resultShapeEnd &&
             remainingResultShapeFactors.empty() &&
             !remainingOperandShapeFactors.empty()) {
        if (!isSymbolicProduct(*resultShapeIt++,
                               &remainingResultShapeFactors)) {
          return failure();
        }
        dimGroup.size++;
        SymbolicProduct gcd = eliminateCommonFactors(
            remainingOperandShapeFactors, remainingResultShapeFactors);
        int64_t tyDimSize = getShapedTypeDimSize(gcd);

        // Allow no more than one dynamic dimension per expansion group.
        if (tyDimSize == ShapedType::kDynamicSize) {
          numDynamicDims++;
          if (numDynamicDims > 1) return failure();
        }
        expandedIntermShape->push_back(tyDimSize);
      }
      if (anyRemainingFactors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (isUnpairedUnitDimension(resultShapeIt, resultShapeEnd,
                                     operandShapeIt, operandShapeEnd)) {
        resultShapeIt++;
        dimGroup.size++;
        expandedIntermShape->push_back(1);
      }

      continue;
    }

    // Untouched: 1:1 mapping between operand and result shape dimension. This
    // is neither expanding nor collapsing.
    assert(!requiresCollapsing && !requiresExpanding && "expect id case");
    assert(dimGroup.size == 1 && dimGroup.kind == DimensionGroupKind::kNone &&
           "expect simple dimension group");
  }

  // Fail if there are remaining dimensions that could not be consumed.
  assert(!anyRemainingFactors() && "expect no remaining factors");
  if (operandShapeIt != operandShapeEnd || resultShapeIt != resultShapeEnd) {
    return failure();
  }

  return success();
}

SmallVector<int64_t> concretizeOperandShape(
    ArrayRef<int64_t> operandShape, ArrayRef<SymbolicExpr> operandShapeInfo) {
  SmallVector<int64_t> result;
  for (auto it : llvm::zip(operandShape, operandShapeInfo)) {
    auto dimSize = std::get<0>(it);
    auto sExpr = std::get<1>(it);
    if (auto cexpr = sExpr.expr.dyn_cast<AffineConstantExpr>()) {
      int64_t alsoDimSize = cexpr.getValue();
      assert((ShapedType::isDynamic(dimSize) || dimSize == alsoDimSize) &&
             "expect shape analysis result to be compatible with type");
      result.push_back(alsoDimSize);
      continue;
    }
    result.push_back(dimSize);
  }
  return result;
}

llvm::Optional<SmallVector<ReassociationIndices>> requiresReassociationOfKind(
    DimensionGroupKind kind, const SmallVector<DimensionGroup> &dimGroups) {
  SmallVector<ReassociationIndices> reassociation;
  reassociation.reserve(dimGroups.size());
  bool isStrictlyReassociating = false;
  int64_t i = 0;
  for (const DimensionGroup &g : dimGroups) {
    if (g.kind == kind) {
      isStrictlyReassociating = true;
      reassociation.push_back(
          llvm::to_vector(llvm::seq<int64_t>(i, i + g.size)));
      i += g.size;
      continue;
    }
    for (int64_t j = 0; j < g.size; j++) reassociation.push_back({i++});
  }

  // Return the reassociation if reassociation of the given kind is required.
  if (isStrictlyReassociating) return reassociation;
  return llvm::None;
}

LogicalResult materializeReshapeAsExpandAndCollapse(
    ShapeComponentAnalysis &shapeAnalysis, RankedTensorType operandTy,
    RankedTensorType resultTy, mhlo::DynamicReshapeOp op,
    PatternRewriter &rewriter) {
  // Require successful shape analysis for operand and result shape.
  auto operandShapeInfo = shapeAnalysis.GetShapeInfo(op.operand());
  if (!operandShapeInfo) return failure();
  auto resultShapeInfo = shapeAnalysis.GetValueInfo(op.output_shape());
  if (!resultShapeInfo) return failure();

  // Identify dimension groups and the intermediate expanded type.
  SmallVector<DimensionGroup> dimensionGroups;
  SmallVector<int64_t> expandedIntermShape;
  if (failed(findExpandingAndCollapsingDimensionGroups(
          *operandShapeInfo, *resultShapeInfo, &dimensionGroups,
          &expandedIntermShape))) {
    return failure();
  }

  // Materialize cast, expand, collapse, and cast, as needed.
  auto loc = op.getLoc();
  Value interm = op.operand();
  auto castedOperandTy = RankedTensorType::get(
      concretizeOperandShape(operandTy.getShape(), *operandShapeInfo),
      operandTy.getElementType());
  if (operandTy != castedOperandTy) {
    interm = rewriter.create<tensor::CastOp>(loc, castedOperandTy, interm);
  }
  if (auto reassociation = requiresReassociationOfKind(
          DimensionGroupKind::kExpanding, dimensionGroups)) {
    interm = rewriter.create<tensor::ExpandShapeOp>(
        loc,
        RankedTensorType::get(expandedIntermShape, operandTy.getElementType()),
        interm, *reassociation);
  }
  if (auto reassociation = requiresReassociationOfKind(
          DimensionGroupKind::kCollapsing, dimensionGroups)) {
    interm =
        rewriter.create<tensor::CollapseShapeOp>(loc, interm, *reassociation);
  }
  if (interm.getType() != resultTy) {
    interm = rewriter.create<tensor::CastOp>(loc, resultTy, interm);
  }
  rewriter.replaceOp(op, interm);
  return success();
}

// Tries to express `dynamic_reshape` ops through `expand_shape` and
// `collapse_shape` ops.
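//
// Illustrative example (a sketch, not from an actual test): if symbolic
// analysis proves that a reshape tensor<?x6xf32> -> tensor<?x2x3xf32> only
// expands the trailing dimension, it is materialized as
//   %0 = tensor.expand_shape %operand [[0], [1, 2]]
//        : tensor<?x6xf32> into tensor<?x2x3xf32>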
struct DynamicReshapeToExpandAndCollapseShape final
    : public OpRewritePattern<mhlo::DynamicReshapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    auto operandTy = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!operandTy) return failure();
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy) return failure();

    // Handle degenerate scalar expand case.
    if (operandTy.getRank() == 0) {
      return materializeReshapeAsScalarExpand(operandTy, resultTy, op,
                                              rewriter);
    }

    // Handle degenerate scalar collapse case.
    if (resultTy.getRank() == 0) {
      return materializeReshapeAsScalarCollapse(operandTy, resultTy, op,
                                                rewriter);
    }

    ShapeComponentAnalysis shapeAnalysis;
    return materializeReshapeAsExpandAndCollapse(shapeAnalysis, operandTy,
                                                 resultTy, op, rewriter);
  }
};

// Returns true if all shapes in `bcastedShapes` are known to broadcast with
// `outputShape`.
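// For example (an illustrative sketch), a shape [s0, 1] is known to broadcast
// with an output shape [s0, s1], whereas [s0, 2] and an output shape [s1, 2]
// are not provably broadcastable since s0 and s1 may disagree at runtime.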
bool isKnownBroadcastable(ShapeComponentAnalysis &analysis,
                          ValueRange bcastedShapes, Value outputShape) {
  auto outputShapeDims = analysis.GetValueInfo(outputShape);
  if (!outputShapeDims) return false;
  for (Value shape : bcastedShapes) {
    auto shapeDims = analysis.GetValueInfo(shape);
    if (!shapeDims) return false;
    // Iterate backwards over the dimensions; zip stops at the end of the
    // smaller shape.
    for (auto zip : llvm::zip(llvm::reverse(*outputShapeDims),
                              llvm::reverse(*shapeDims))) {
      const auto &first = std::get<0>(zip);
      const auto &second = std::get<1>(zip);
      // TODO(ezhulenev): What to do with dimensions statically known to be
      // zero?
      // Numpy can only broadcast [0] with [1], however Tensorflow can
      // broadcast [0] with any dimension size, and produces a dimension of
      // size [0]. Currently, we conservatively return failure and do not
      // proceed with the rewrite.
      if (first.isConstant(0) || second.isConstant(0)) return false;
      // If either shape has a static unit dimension, the broadcast will always
      // succeed.
      if (first.isConstant(1) || second.isConstant(1)) continue;
      // Otherwise, the dimensions have to be equal.
      if (first != second) return false;
    }
  }
  return true;
}

// Rewrite `shape.cstr_broadcastable` with a constant witness if we can prove
// that the shapes are broadcastable from symbolic analysis.
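//
// A sketch of the rewrite (illustrative):
//   %w = shape.cstr_broadcastable %a, %b : tensor<2xindex>, tensor<2xindex>
// becomes, once broadcastability is proven symbolically,
//   %w = shape.const_witness true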
struct CstrBroadcastableOpLowering
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis shapeComponentAnalysis;
    if (!isKnownBroadcastable(shapeComponentAnalysis, op.getShapes(),
                              op.getShapes().front())) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};

class SymbolicShapeOptimizationPass final
    : public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);

    // clang-format off
    patterns.insert<
        AnnotateExpandingDimensionsInDynamicBroadcastInDim,
        CstrBroadcastableOpLowering,
        DynamicReshapeToExpandAndCollapseShape,
        RemoveComputeReshapeShape,
        RemoveRedundantCstrReshapable,
        SimplifyBroadcasts>(ctx);
    // clang-format on
    shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);

    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                  std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // end namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createSymbolicShapeOptimizationPass() {
  return std::make_unique<SymbolicShapeOptimizationPass>();
}

}  // end namespace mlir