/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Analysis/shape_component_analysis.h"

#include <algorithm>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;

using SymbolicShapeConstraintsMap =
    ShapeComponentAnalysis::SymbolicShapeConstraintsMap;
using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo;
using Symbol = ShapeComponentAnalysis::Symbol;
using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr;
using SymbolicExprsMap = ShapeComponentAnalysis::SymbolicExprsMap;

namespace {
// Shape visitor. This implements a symbolic interpreter for MHLO with some
// shape and tensor dialect ops mixed in. We are interested in shapes (e.g.,
// the dimensions of a tensor) and values (e.g., the elements of a shape
// tensor). The goal is to assign every component of a shape or value either a
// symbol, a constant, or a symbolic expression. We propagate these symbolic
// expressions through the various operations. Later optimization passes can
// use this information, e.g., to exploit the equality of dimensions.
//
// The visitation happens in two phases:
//   1. Find the sources of a value's shape or value. This climbs up the
//      operations from a given value until an unknown op or a function
//      argument is found. These sources are assigned the initial symbols for
//      each of their components.
//   2. Propagate the initial symbols downwards. This builds symbolic
//      expressions so users of the analysis can pattern match things like
//      "two dimensions are multiplied".
//
// Conceptually, this is defined recursively. For each op, we compute the
// required shape or value information for the operands and then derive the
// resulting symbolic expression.
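//
// As an illustration (hypothetical IR, not taken from a real test), consider:
//
//   %shape = shape.shape_of %arg0 : tensor<?x8xf32> -> tensor<2xindex>
//   %dim = tensor.dim %arg0, %c0 : tensor<?x8xf32>
//
// Requesting the value info of %shape first climbs back to %arg0, whose
// dynamic dimension receives the initial symbol s0. The forward phase then
// records the expressions [s0, 8] for %shape and s0 for %dim, which makes it
// visible to later passes that %dim equals the first element of %shape.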
struct ShapeVisitor {
  ShapeVisitor(SymbolicExprsMap *symbolicExprsMap,
               SymbolicShapeConstraintsMap *symbolicShapeConstraintsMap)
      : symbolicExprsMap(symbolicExprsMap),
        symbolicShapeConstraintsMap(symbolicShapeConstraintsMap) {}

  void visit(ShapeOrValueInfo requestedInfo) {
    backwardsWorklist.push_back(requestedInfo);

    // First, we climb up the operations so we get the set of all ops taking
    // part in this shape or value computation. An alternative would be
    // analyzing everything eagerly. This backwards pass allows us to be lazy.
    while (!backwardsWorklist.empty()) {
      // Skip if already processed.
      ShapeOrValueInfo transitivelyRequestedInfo =
          backwardsWorklist.pop_back_val();
      if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue;

      // Skip irrelevant cases early.
      Value value = transitivelyRequestedInfo.value();
      Type ty = value.getType();
      if (!ty.isIntOrIndexOrFloat() && !ty.isa<RankedTensorType>()) continue;

      // Handle shapes.
      if (transitivelyRequestedInfo.isShapeInfo()) {
        if (value.getDefiningOp<shape::AssumingOp>()) {
          backwardAssumingShape(value);
        } else if (auto bcast =
                       value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
          backwardDynamicBroadcastInDimShape(bcast);
        } else if (auto reshape =
                       value.getDefiningOp<mhlo::DynamicReshapeOp>()) {
          backwardDynamicReshapeShape(reshape);
        } else if (value.getDefiningOp<mhlo::ReduceOp>()) {
          backwardReduceShape(value);
        } else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) {
          backwardTransposeShape(transpose);
        } else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) {
          backwardSelectShape(select);
        } else if (auto arg = value.dyn_cast<BlockArgument>()) {
          backwardBlockArgumentShape(arg);
        } else if (value.getDefiningOp() &&
                   value.getDefiningOp()
                       ->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
          backwardSameOperandsAndResultShape(value);
        } else {
          backwardUnknownShape(value);
        }
        continue;
      }

      // Skip irrelevant cases early.
      auto rankedTy = ty.dyn_cast<RankedTensorType>();
      bool isPossiblyInterestingScalar = ty.isIntOrIndex();
      bool isPossiblyInterestingTensor =
          rankedTy && rankedTy.getRank() <= 1 && rankedTy.hasStaticShape();
      if (!isPossiblyInterestingScalar && !isPossiblyInterestingTensor) {
        continue;
      }

      // Handle values.
      assert(transitivelyRequestedInfo.isValueInfo() &&
             "Expect value info at this point.");
      if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) {
        backwardShapeOf(shapeof);
      } else if (auto bcast = value.getDefiningOp<shape::BroadcastOp>()) {
        backwardBroadcast(bcast);
      } else if (auto numElements =
                     value.getDefiningOp<shape::NumElementsOp>()) {
        backwardNumElements(numElements);
      } else if (auto dim = value.getDefiningOp<tensor::DimOp>()) {
        backwardDim(dim);
      } else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) {
        backwardIndexCast(cast);
      } else if (auto fromElements =
                     value.getDefiningOp<tensor::FromElementsOp>()) {
        backwardTensorFromElements(fromElements);
      } else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
        backwardTensorExtract(extract);
      } else if (auto add = value.getDefiningOp<mhlo::AddOp>()) {
        backwardBinOp(add);
      } else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) {
        backwardBinOp(mul);
      } else if (auto add = value.getDefiningOp<arith::AddIOp>()) {
        backwardBinOp(add);
      } else if (auto mul = value.getDefiningOp<arith::MulIOp>()) {
        backwardBinOp(mul);
      } else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) {
        backwardConcatenate(concat);
      } else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) {
        backwardReshape(reshape);
      } else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) {
        backwardSlice(slice);
      } else if (matchPattern(value, m_Constant())) {
        backwardConstant(value);
      } else {
        backwardUnknown(value);
      }
    }

    // Second, we walk down from the defs to the uses, building symbolic
    // expressions for shape and value components.
    while (!forwardsWorklist.empty()) {
      auto transitivelyRequestedInfo = forwardsWorklist.pop_back_val();

      // Skip if already processed.
      if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue;

      // Handle shapes.
      Value value = transitivelyRequestedInfo.value();
      if (!transitivelyRequestedInfo.isValueInfo()) {
        if (value.getDefiningOp<shape::AssumingOp>()) {
          forwardAssumingShape(value);
        } else if (auto broadcast =
                       value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
          forwardDynamicBroadcastInDimShape(broadcast);
        } else if (auto reshape =
                       value.getDefiningOp<mhlo::DynamicReshapeOp>()) {
          forwardDynamicReshapeShape(reshape);
        } else if (value.getDefiningOp<mhlo::ReduceOp>()) {
          forwardReduceShape(value);
        } else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) {
          forwardTransposeShape(transpose);
        } else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) {
          forwardSelectShape(select);
        } else if (value.getDefiningOp() &&
                   value.getDefiningOp()
                       ->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
          forwardSameOperandsShape(value);
        } else {
          forwardUnknownShape(value);
        }
        continue;
      }

      // Handle values.
      assert(transitivelyRequestedInfo.isValueInfo() &&
             "Expect value info at this point.");
      if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) {
        forwardShapeOf(shapeof);
      } else if (auto bcast = value.getDefiningOp<shape::BroadcastOp>()) {
        forwardBroadcast(bcast);
      } else if (auto numElements =
                     value.getDefiningOp<shape::NumElementsOp>()) {
        forwardNumElements(numElements);
      } else if (auto dim = value.getDefiningOp<tensor::DimOp>()) {
        forwardDim(dim);
      } else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) {
        forwardIndexCast(cast);
      } else if (auto fromElements =
                     value.getDefiningOp<tensor::FromElementsOp>()) {
        forwardTensorFromElements(fromElements);
      } else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
        forwardTensorExtract(extract);
      } else if (auto add = value.getDefiningOp<mhlo::AddOp>()) {
        forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; });
      } else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) {
        forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; });
      } else if (auto add = value.getDefiningOp<arith::AddIOp>()) {
        forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; });
      } else if (auto mul = value.getDefiningOp<arith::MulIOp>()) {
        forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; });
      } else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) {
        forwardConcatenate(concat);
      } else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) {
        forwardReshape(reshape);
      } else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) {
        forwardSlice(slice);
      } else if (matchPattern(value, m_Constant())) {
        forwardConstant(value);
      } else {
        forwardUnknown(value);
      }
    }
  }

 private:
  // ===
  // Functions that traverse the shapes of operations.
  // ===

  void backwardAssumingShape(Value op) {
    auto assumingOp = op.getDefiningOp<shape::AssumingOp>();
    auto number = op.cast<OpResult>().getResultNumber();
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(
        cast<shape::AssumingYieldOp>(
            assumingOp.getDoRegion().back().getTerminator())
            .getOperand(number)));
  }
  void forwardAssumingShape(Value op) {
    auto assumingOp = op.getDefiningOp<shape::AssumingOp>();
    auto number = op.cast<OpResult>().getResultNumber();
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    dims = lookup(ShapeOrValueInfo::getShapeInfoOf(
        cast<shape::AssumingYieldOp>(
            assumingOp.getDoRegion().back().getTerminator())
            .getOperand(number)));
  }
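  // Illustrative sketch of the broadcast rule implemented below (assumed
  // symbolic shapes, not from a real test): broadcasting [1, s0] with
  // [s1, 1] yields [s1, s0] because size-one dimensions are neutral, whereas
  // broadcasting [s0] with [s1] yields a fresh symbol because the two
  // expressions may or may not agree at runtime.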
  void backwardBroadcast(shape::BroadcastOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    for (Value s : op.getShapes())
      backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(s));
  }
  void forwardBroadcast(shape::BroadcastOp op) {
    auto *ctx = op.getContext();

    // Get operands' info.
    SmallVector<ArrayRef<SymbolicExpr>> argsInfo =
        llvm::to_vector(llvm::map_range(op.getShapes(), [&](Value s) {
          return lookup(ShapeOrValueInfo::getValueInfoOf(s));
        }));

    // Determine broadcasted rank.
    size_t rank = 0;
    for (auto &info : argsInfo) rank = std::max(rank, info.size());

    // Evaluate broadcast per result dimension.
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    for (size_t i = 0; i < rank; ++i) {
      // Initialize with the neutral element.
      SymbolicExpr bcastedExpr;
      bcastedExpr.expr = getAffineConstantExpr(1, ctx);

      // Consider all the operands.
      for (auto &info : argsInfo) {
        // Find the corresponding symbolic expression for the ith result
        // dimension, if the operand contributes.
        size_t argRank = info.size();
        if (i + argRank < rank) continue;
        size_t j = i + argRank - rank;
        SymbolicExpr expr = info[j];

        // Dimensions of size one are neutral.
        if (expr.isConstant(1)) continue;

        // If a dimension is known not to be 1, we can use this expression.
        if (expr.isKnownNotOne()) {
          bcastedExpr = expr;
          break;
        }

        // If all other dimensions were neutral, try using this expression.
        if (bcastedExpr.isConstant(1)) {
          bcastedExpr = expr;
          continue;
        }

        // If we have contradicting expressions, give up and create a new
        // symbol.
        if (bcastedExpr != expr) {
          bcastedExpr.expr = getAffineSymbolExpr(0, ctx);
          bcastedExpr.symbols = {{ShapeOrValueInfo::getValueInfoOf(op), i}};
          break;
        }
      }

      dims.push_back(bcastedExpr);
    }
    assert(dims.size() == rank && "expect one expression per dimension");
  }
  void backwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.output_dimensions()));
  }
  void forwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    dims = lookup(ShapeOrValueInfo::getValueInfoOf(op.output_dimensions()));
  }
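  // Illustrative example for the dynamic_reshape handling below (hypothetical
  // IR): for
  //   %0 = "mhlo.dynamic_reshape"(%arg, %shape)
  //       : (tensor<?xf32>, tensor<2xindex>) -> tensor<4x?xf32>
  // the static dimension 4 is taken from the result type, and the dynamic
  // dimension is taken from the symbolic value computed for %shape, so the
  // result shape becomes [4, e] where e is the expression for %shape[1].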
  void backwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.output_shape()));
  }
  void forwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) {
    auto rankedTy = op.getResult().getType().cast<RankedTensorType>();
    auto shapeDims =
        lookup(ShapeOrValueInfo::getValueInfoOf(op.output_shape()));
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    dimsFromStaticShape(rankedTy, shapeDims, &dims);
  }
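  // Illustrative example for the reduce handling below (assumed symbolic
  // shape): reducing an operand with shape [s0, 128, s1] over dimension 1
  // keeps the remaining expressions, so the result shape is [s0, s1].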
  void backwardReduceShape(Value op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    auto reduceOp = op.getDefiningOp<mhlo::ReduceOp>();
    if (reduceOp.operands().size() == 1) {
      backwardsWorklist.push_back(
          ShapeOrValueInfo::getShapeInfoOf(reduceOp.operands().back()));
    }
  }
  void forwardReduceShape(Value op) {
    auto reduceOp = op.getDefiningOp<mhlo::ReduceOp>();
    if (reduceOp.operands().size() != 1) return forwardUnknownShape(op);
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    for (const auto &dim : llvm::enumerate(lookup(
             ShapeOrValueInfo::getShapeInfoOf(reduceOp.operands().back())))) {
      if (!llvm::is_contained(reduceOp.dimensions(), dim.index()))
        dims.push_back(dim.value());
    }
  }
  void backwardTransposeShape(mhlo::TransposeOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.operand()));
  }
  void forwardTransposeShape(mhlo::TransposeOp op) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.operand()));
    auto elem = op.permutation().cast<DenseIntElementsAttr>();
    for (const auto &val : elem) dims.push_back(in[val.getZExtValue()]);
  }
  void backwardSelectShape(mhlo::SelectOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.on_true()));
  }
  void forwardSelectShape(mhlo::SelectOp op) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    // Forward the `on_true` operand; it has the same shape as the output.
    dims = lookup(ShapeOrValueInfo::getShapeInfoOf(op.on_true()));
  }
  void backwardSameOperandsAndResultShape(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(v));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getShapeInfoOf(v.getDefiningOp()->getOperand(0)));
  }
  void forwardSameOperandsShape(Value v) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(v));
    dims = lookup(
        ShapeOrValueInfo::getShapeInfoOf(v.getDefiningOp()->getOperand(0)));
  }
  void backwardBlockArgumentShape(BlockArgument argument) {
    // JitRT uses rt.symbolic_shape to describe identical dimensions. Make
    // use of that when it exists.
    //
    // Example:
    //   func @compute(
    //     %arg0: tensor<?xf32> {rt.symbolic_shape = dense<-2> :
    //     tensor<1xi64>}, %arg1: tensor<?xf32> {rt.symbolic_shape =
    //     dense<-2> : tensor<1xi64>})
    //   } { ... }
    //
    // A symbolic dimension is encoded as a negative value smaller than `-1`.
    // The concrete value is not known at compile time, and in this particular
    // example it is only known that both arguments have the same shape.
    //
    // TODO(ezhulenev): Add symbolic shape attribute verifier to the jitrt
    // dialect.
    if (auto func = dyn_cast_or_null<func::FuncOp>(
            argument.getOwner()->getParentOp())) {
      if (auto shape = func.getArgAttrOfType<DenseIntElementsAttr>(
              argument.getArgNumber(), "rt.symbolic_shape")) {
        auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(argument));
        auto id = getAffineSymbolExpr(0, argument.getContext());
        for (const auto &symbol : llvm::enumerate(shape.getValues<ssize_t>())) {
          dims.emplace_back();
          auto &dim = dims.back();
          if (symbol.value() >= 0) {
            dim.expr =
                getAffineConstantExpr(symbol.value(), argument.getContext());
          } else {
            auto it = symbolicShapeConstraintsMap->try_emplace(
                symbol.value(),
                Symbol{ShapeOrValueInfo::getShapeInfoOf(argument),
                       symbol.index()});
            dim.symbols.push_back(it.first->second);
            dim.expr = id;
          }
        }
        return;
      }
    }
    forwardUnknownShape(argument);
  }
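  // For ops the analysis does not understand, every dynamic result dimension
  // simply gets a fresh symbol, while static dimensions are still taken from
  // the type. For example (hypothetical), an unknown op returning
  // tensor<?x8x?xf32> is assigned the shape [s0, 8, s1].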
  void backwardUnknownShape(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(v));
  }
  void forwardUnknownShape(Value v) {
    auto rankedTy = v.getType().dyn_cast<RankedTensorType>();
    if (!rankedTy) return;
    auto id = getAffineSymbolExpr(0, v.getContext());
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(v));
    return dimsFromStaticShape(
        rankedTy,
        [&](size_t i) {
          SymbolicExpr d;
          d.symbols.push_back({ShapeOrValueInfo::getShapeInfoOf(v), i});
          d.expr = id;
          return d;
        },
        &dims);
  }

  // ===
  // Functions that traverse values. These can be shape tensors (e.g., of type
  // tensor<3xindex>) or interesting scalars (e.g., of type index).
  // ===

  void backwardShapeOf(shape::ShapeOfOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.getArg()));
  }
  void forwardShapeOf(shape::ShapeOfOp op) {
    auto rankedTy = op.getArg().getType().cast<RankedTensorType>();
    auto arg = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getArg()));
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    return dimsFromStaticShape(rankedTy, arg, &dims);
  }
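  // Illustrative example for num_elements (assumed symbolic shape): for a
  // shape whose elements are [3, s0, 4, s1], the code below folds the
  // constants and multiplies the remaining factors symbolically, yielding
  // the single expression 12 * s0 * s1.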
  void backwardNumElements(shape::NumElementsOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.getShape()));
  }
  void forwardNumElements(shape::NumElementsOp op) {
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getShape()));

    // Accumulate the product symbolically, and concretely where possible.
    int64_t concreteProduct = 1;
    SymbolicExpr dim;
    for (auto &it : in) {
      // For constant expressions, we can accumulate a concrete product.
      if (auto cexpr = it.expr.dyn_cast<AffineConstantExpr>()) {
        assert(cexpr.getValue() > 0 && "shape value must be positive");
        concreteProduct *= cexpr.getValue();
        continue;
      }

      // Simply copy the first symbolic factor.
      if (!dim.expr) {
        dim = it;
        continue;
      }

      // Multiply remaining symbolic factors.
      dim.expr = dim.expr *
                 it.expr.shiftSymbols(dim.symbols.size(), it.symbols.size());
      dim.symbols.append(it.symbols);
    }

    // Combine the concrete and symbolic product.
    if (concreteProduct != 1 || !dim.expr) {
      auto cexpr = getAffineConstantExpr(concreteProduct, op.getContext());
      if (dim.expr)
        dim.expr = cexpr * dim.expr;
      else
        dim.expr = cexpr;
    }

    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    dims.push_back(dim);
  }
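  // Illustrative example (hypothetical IR): for
  //   %d = tensor.dim %arg, %c1 : tensor<?x?xf32>
  // with a constant index, the handling below forwards the expression of the
  // second dimension from %arg's shape info; a non-constant index falls back
  // to an unknown value.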
  void backwardDim(tensor::DimOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getShapeInfoOf(op.getSource()));
  }
  void forwardDim(tensor::DimOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    if (auto index = op.getIndex().getDefiningOp<arith::ConstantOp>()) {
      int64_t i = index.getValue().cast<IntegerAttr>().getInt();
      auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getSource()));
      dims.push_back({in[i].symbols, in[i].expr});
    } else {
      forwardUnknown(op);
    }
  }
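  // Illustrative example for the binary op handling below (assumed symbolic
  // values): adding shape-tensor elements with the expressions [s0 + 1] and
  // [s1] yields [(s0 + 1) + s1]; the rhs symbols are shifted so that both
  // operands' symbols can share one combined symbol list.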
  template <typename Op>
  void backwardBinOp(Op op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    // TODO(jpienaar): Switch to named accessors when MHLO uses prefixed form.
    backwardsWorklist.append(
        {ShapeOrValueInfo::getValueInfoOf(op.getOperand(0)),
         ShapeOrValueInfo::getValueInfoOf(op.getOperand(1))});
  }
  template <typename Op, typename Combiner>
  void forwardBinOp(Op op, Combiner &&combiner) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    // TODO(jpienaar): Switch to named accessors when MHLO uses prefixed form.
    auto lhs = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand(0)));
    auto rhs = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand(1)));
    for (int64_t i = 0, e = dim0size(op.getType()); i != e; ++i) {
      dims.emplace_back();
      auto &dim = dims.back();
      dim.symbols.append(lhs[i].symbols);
      dim.symbols.append(rhs[i].symbols);
      dim.expr = combiner(lhs[i].expr,
                          rhs[i].expr.shiftSymbols(rhs[i].symbols.size(),
                                                   lhs[i].symbols.size()));
    }
  }
  void backwardIndexCast(arith::IndexCastOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.getIn()));
  }
  void forwardIndexCast(arith::IndexCastOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getIn()));
    for (int64_t i = 0, e = dim0size(op.getType()); i != e; ++i) {
      // This intentionally does not model the truncation/zero extension of
      // index_cast. While not strictly correct, it does not matter for shape
      // computations.
      dims.push_back({in[i].symbols, in[i].expr});
    }
  }
  void backwardTensorFromElements(tensor::FromElementsOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands())
      backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(operand));
  }
  void forwardTensorFromElements(tensor::FromElementsOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands()) {
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand));
      assert(in.size() == 1);
      dims.push_back({in[0].symbols, in[0].expr});
    }
  }
  void backwardTensorExtract(tensor::ExtractOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.getTensor()));
  }
  void forwardTensorExtract(tensor::ExtractOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    assert(op.getIndices().size() == 1);
    if (auto index =
            op.getIndices().front().getDefiningOp<arith::ConstantOp>()) {
      int64_t i = index.getValue().cast<IntegerAttr>().getInt();
      // We assume this is in bounds.
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getTensor()));
      dims.push_back({in[i].symbols, in[i].expr});
    } else {
      forwardUnknown(op);
    }
  }
  void backwardConstant(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(v));
  }
  void forwardConstant(Value v) {
    IntegerAttr intAttr;
    DenseIntElementsAttr denseAttr;
    if (matchPattern(v, m_Constant(&denseAttr))) {
      auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v));
      for (uint64_t i = 0, e = dim0size(v.getType()); i != e; ++i) {
        dims.emplace_back();
        auto &dim = dims.back();
        dim.expr = getAffineConstantExpr(
            denseAttr.getValues<APInt>()[i].getSExtValue(), v.getContext());
      }
    } else if (matchPattern(v, m_Constant(&intAttr))) {
      auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v));
      dims.emplace_back();
      auto &dim = dims.back();
      dim.expr = getAffineConstantExpr(intAttr.getInt(), v.getContext());
    } else {
      forwardUnknown(v);
    }
  }
  void backwardConcatenate(mhlo::ConcatenateOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands())
      backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(operand));
  }
  void forwardConcatenate(mhlo::ConcatenateOp op) {
    for (auto operand : op.getOperands()) {
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand));
      if (in.size() != 1) return forwardUnknown(op);
    }
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands()) {
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand));
      dims.push_back({in[0].symbols, in[0].expr});
    }
  }
  void backwardReshape(mhlo::ReshapeOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.operand()));
  }
  void forwardReshape(mhlo::ReshapeOp op) {
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.operand()));
    if (in.size() != 1) return forwardUnknown(op);
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    dims.push_back({in[0].symbols, in[0].expr});
  }
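  // Illustrative example (hypothetical IR): a slice such as
  //   %0 = "mhlo.slice"(%shape) {start_indices = dense<1> : tensor<1xi64>,
  //        limit_indices = dense<2> : tensor<1xi64>,
  //        strides = dense<1> : tensor<1xi64>}
  //       : (tensor<3xindex>) -> tensor<1xindex>
  // behaves like extracting element 1 of %shape, so the handling below simply
  // forwards that element's symbolic expression.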
  void backwardSlice(mhlo::SliceOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.operand()));
  }
  void forwardSlice(mhlo::SliceOp op) {
    // Only handle slices equivalent to an extract.
    if (!op.getType().hasStaticShape({1})) {
      return forwardUnknown(op);
    }
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.operand()));
    auto elem = op.start_indices().cast<DenseIntElementsAttr>();
    auto i = (*elem.begin()).getZExtValue();
    if (i >= in.size()) {  // Bounds check.
      return forwardUnknown(op);
    }
    dims.push_back({in[i].symbols, in[i].expr});
  }
  void backwardUnknown(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(v));
  }
  void forwardUnknown(Value v) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v));
    auto id = getAffineSymbolExpr(0, v.getContext());
    for (size_t i = 0, e = dim0size(v.getType()); i != e; ++i) {
      dims.emplace_back();
      auto &dim = dims.back();
      dim.symbols.push_back({ShapeOrValueInfo::getValueInfoOf(v), i});
      dim.expr = id;
    }
  }

  // ===
  // Helpers
  // ===

  static void dimsFromStaticShape(
      RankedTensorType rankedTy,
      llvm::function_ref<SymbolicExpr(int64_t)> fallback,
      std::vector<SymbolicExpr> *mergedDims) {
    auto *ctx = rankedTy.getContext();
    for (int64_t i = 0, e = rankedTy.getRank(); i != e; ++i) {
      if (rankedTy.isDynamicDim(i)) {
        mergedDims->push_back(fallback(i));
      } else {
        mergedDims->emplace_back();
        auto &d = mergedDims->back();
        d.expr = getAffineConstantExpr(rankedTy.getDimSize(i), ctx);
      }
    }
  }

  static void dimsFromStaticShape(RankedTensorType rankedTy,
                                  ArrayRef<SymbolicExpr> fallback,
                                  std::vector<SymbolicExpr> *mergedDims) {
    return dimsFromStaticShape(
        rankedTy, [&](int64_t i) { return fallback[i]; }, mergedDims);
  }

  // Returns the size of the first dimension, or 1 for scalars.
  static int64_t dim0size(Type type) {
    if (auto rankedType = type.dyn_cast<RankedTensorType>())
      return rankedType.getRank() == 0 ? 1 : rankedType.getDimSize(0);
    return 1;
  }

  // Retrieves the existing information from the cache.
  ArrayRef<SymbolicExpr> lookup(ShapeOrValueInfo requestedInfo) {
    auto i = symbolicExprsMap->find(requestedInfo);
    assert(i != symbolicExprsMap->end() && "op not processed yet?");
    return llvm::makeArrayRef(i->second);
  }

  // Inserts a new entry into the cache and returns a reference to its result
  // components.
  std::vector<SymbolicExpr> &insert(ShapeOrValueInfo requestedInfo) {
    auto i = symbolicExprsMap->try_emplace(requestedInfo);
    assert(i.second && "op already processed?");
    return i.first->second;
  }

  SymbolicExprsMap *symbolicExprsMap;
  SymbolicShapeConstraintsMap *symbolicShapeConstraintsMap;

  // Worklists for the forward and backward passes.
  SmallVector<ShapeOrValueInfo> backwardsWorklist;
  SmallVector<ShapeOrValueInfo> forwardsWorklist;
};
}  // namespace

void ShapeComponentAnalysis::compute(ShapeOrValueInfo requestedInfo) {
  ShapeVisitor(&symbolicExprsMap, &symbolicShapeConstraintsMap)
      .visit(requestedInfo);
}

Optional<ArrayRef<SymbolicExpr>>
ShapeComponentAnalysis::GetShapeInfo(Value value) {
  auto request = ShapeOrValueInfo::getShapeInfoOf(value);
  compute(request);
  auto found = symbolicExprsMap.find(request);
  if (found == symbolicExprsMap.end()) return {};
  return llvm::makeArrayRef(found->second);
}

Optional<ArrayRef<SymbolicExpr>>
ShapeComponentAnalysis::GetValueInfo(Value shape) {
  auto request = ShapeOrValueInfo::getValueInfoOf(shape);
  compute(request);
  auto found = symbolicExprsMap.find(request);
  if (found == symbolicExprsMap.end()) return {};
  return llvm::makeArrayRef(found->second);
}

void ShapeComponentAnalysis::reset() {
  symbolicExprsMap.clear();
  symbolicShapeConstraintsMap.clear();
}

bool SymbolicExpr::isConstant(int64_t value) const {
  return expr.isa<AffineConstantExpr>() &&
         expr.cast<AffineConstantExpr>().getValue() == value;
}

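// Example (illustrative): an expression such as s0 * 4 is known not to be -1
// when s0 stems from a shape dimension, because shape dimensions are never
// negative and a product of non-negative factors is non-negative.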
bool SymbolicExpr::isKnownNotNegativeOne() const {
  // If the symbol comes from a shape it can't be a -1. Also allow results
  // of shape_of, compute_reshape_shape, and num_elements. This is correct,
  // not complete.
  auto isGoodSymbol = [](const Symbol &symbol) {
    if (symbol.source.isShapeInfo()) return true;
    Operation *op = symbol.source.value().getDefiningOp();
    if (op == nullptr) return false;
    return llvm::isa<shape::ShapeOfOp, mhlo::ComputeReshapeShapeOp,
                     shape::NumElementsOp>(op);
  };

  // For constants we know whether they are -1 or not. Checking the sign is
  // sufficient here and allows for reuse below. This is correct, not
  // complete.
  auto isGoodSymbolOrGoodConstantExpr = [&](AffineExpr expr) {
    if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      return isGoodSymbol(symbols[symExpr.getPosition()]);
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() >= 0;
    return false;
  };

  if (isGoodSymbolOrGoodConstantExpr(expr)) return true;

  // Multiplying non-negative symbols and non-negative constants will always
  // give a non-negative result. This is correct, not complete.
  // TODO(kramerb): Could the analysis provide a generic interface for this?
  if (auto bexpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
    return bexpr.getKind() == AffineExprKind::Mul &&
           isGoodSymbolOrGoodConstantExpr(bexpr.getLHS()) &&
           isGoodSymbolOrGoodConstantExpr(bexpr.getRHS());
  }

  return false;
}

bool SymbolicExpr::isKnownNotOne() const {
  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
    return constExpr.getValue() != 1;
  }
  return false;
}

llvm::Optional<Symbol> SymbolicExpr::singleton() const {
  if (expr.isa<AffineSymbolExpr>() &&
      expr.cast<AffineSymbolExpr>().getPosition() == 0) {
    assert(symbols.size() == 1);
    return symbols[0];
  }
  return llvm::None;
}

void SymbolicExpr::dump(llvm::raw_ostream &os) const {
  expr.print(os);
  if (!symbols.empty()) os << " with";
  os << "\n";
  if (symbols.empty()) return;
  for (const auto &sym : llvm::enumerate(symbols)) {
    os.indent(4);
    os << 's' << sym.index() << " = ";
    if (!sym.value().source.isValueInfo()) os << "shapeof(";
    sym.value().source.value().print(os);
    if (!sym.value().source.isValueInfo()) os << ")";
    os << '[' << sym.value().index << "]\n";
  }
}