/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Analysis/shape_component_analysis.h"

#include <algorithm>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;

using SymbolicShapeConstraintsMap =
    ShapeComponentAnalysis::SymbolicShapeConstraintsMap;
using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo;
using Symbol = ShapeComponentAnalysis::Symbol;
using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr;
using SymbolicExprsMap = ShapeComponentAnalysis::SymbolicExprsMap;

namespace {
// Shape visitor. This implements a symbolic interpreter for MHLO with some
// shape and tensor dialect ops mixed in. We are interested in shapes (e.g.,
// the dimensions of a tensor) and values (e.g., the elements of a shape
// tensor). The goal is to assign every component of a shape or value either
// a symbol, a constant, or a symbolic expression. We propagate these symbolic
// expressions through the various operations. Later passes can use this
// information for optimizations, e.g., exploiting the equality of dimensions.
//
// The visitation happens in two phases:
//   1. Find the sources of a value's shape or value. This climbs up the
//      operations from a given value until an unknown op or a function
//      argument is found. These sources are assigned the initial symbols for
//      each of their components.
//   2. Propagate the initial symbols downwards. This builds symbolic
//      expressions so users of the analysis can pattern match things like
//      "two dimensions are multiplied".
//
// Conceptually, this is defined recursively. For each op, we compute the
// required shape or value information for the operands and then derive the
// resulting symbolic expression.
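//
// As an illustrative sketch (a hypothetical snippet, not taken from a test
// case), for IR roughly like
//
//   %shape = shape.shape_of %arg : tensor<?x?xf32> -> tensor<2xindex>
//   %size = shape.num_elements %shape : tensor<2xindex> -> index
//
// the analysis would assign the two dynamic dimensions of %arg the initial
// symbols s0 and s1, describe the value of %shape as [s0, s1], and describe
// the value of %size as the symbolic expression s0 * s1.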
struct ShapeVisitor {
  ShapeVisitor(SymbolicExprsMap *symbolicExprsMap,
               SymbolicShapeConstraintsMap *symbolicShapeConstraintsMap)
      : symbolicExprsMap(symbolicExprsMap),
        symbolicShapeConstraintsMap(symbolicShapeConstraintsMap) {}

  void visit(ShapeOrValueInfo requestedInfo) {
    backwardsWorklist.push_back(requestedInfo);

    // First, we climb up the operations so we get the set of all ops taking
    // part in this shape or value computation. An alternative would be
    // analyzing everything eagerly. This backwards pass allows us to be lazy.
    while (!backwardsWorklist.empty()) {
      // Skip if already processed.
      ShapeOrValueInfo transitivelyRequestedInfo =
          backwardsWorklist.pop_back_val();
      if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue;

      // Skip irrelevant cases early.
      Value value = transitivelyRequestedInfo.value();
      Type ty = value.getType();
      if (!ty.isIntOrIndexOrFloat() && !ty.isa<RankedTensorType>()) continue;

      // Handle shapes.
      if (transitivelyRequestedInfo.isShapeInfo()) {
        if (value.getDefiningOp<shape::AssumingOp>()) {
          backwardAssumingShape(value);
        } else if (auto bcast =
                       value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
          backwardDynamicBroadcastInDimShape(bcast);
        } else if (auto reshape =
                       value.getDefiningOp<mhlo::DynamicReshapeOp>()) {
          backwardDynamicReshapeShape(reshape);
        } else if (value.getDefiningOp<mhlo::ReduceOp>()) {
          backwardReduceShape(value);
        } else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) {
          backwardTransposeShape(transpose);
        } else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) {
          backwardSelectShape(select);
        } else if (auto arg = value.dyn_cast<BlockArgument>()) {
          backwardBlockArgumentShape(arg);
        } else if (value.getDefiningOp() &&
                   value.getDefiningOp()
                       ->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
          backwardSameOperandsAndResultShape(value);
        } else {
          backwardUnknownShape(value);
        }
        continue;
      }

      // Skip irrelevant cases early.
      auto rankedTy = ty.dyn_cast<RankedTensorType>();
      bool isPossiblyInterestingScalar = ty.isIntOrIndex();
      bool isPossiblyInterestingTensor =
          rankedTy && rankedTy.getRank() <= 1 && rankedTy.hasStaticShape();
      if (!isPossiblyInterestingScalar && !isPossiblyInterestingTensor) {
        continue;
      }

      // Handle values.
      assert(transitivelyRequestedInfo.isValueInfo() &&
             "Expect value info at this point.");
      if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) {
        backwardShapeOf(shapeof);
      } else if (auto bcast = value.getDefiningOp<shape::BroadcastOp>()) {
        backwardBroadcast(bcast);
      } else if (auto numElements =
                     value.getDefiningOp<shape::NumElementsOp>()) {
        backwardNumElements(numElements);
      } else if (auto dim = value.getDefiningOp<tensor::DimOp>()) {
        backwardDim(dim);
      } else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) {
        backwardIndexCast(cast);
      } else if (auto fromElements =
                     value.getDefiningOp<tensor::FromElementsOp>()) {
        backwardTensorFromElements(fromElements);
      } else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
        backwardTensorExtract(extract);
      } else if (auto add = value.getDefiningOp<mhlo::AddOp>()) {
        backwardBinOp(add);
      } else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) {
        backwardBinOp(mul);
      } else if (auto add = value.getDefiningOp<arith::AddIOp>()) {
        backwardBinOp(add);
      } else if (auto mul = value.getDefiningOp<arith::MulIOp>()) {
        backwardBinOp(mul);
      } else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) {
        backwardConcatenate(concat);
      } else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) {
        backwardReshape(reshape);
      } else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) {
        backwardSlice(slice);
      } else if (matchPattern(value, m_Constant())) {
        backwardConstant(value);
      } else {
        backwardUnknown(value);
      }
    }

    // Second, we walk down from the defs to the uses, building symbolic
    // expressions for shape and value components.
    while (!forwardsWorklist.empty()) {
      auto transitivelyRequestedInfo = forwardsWorklist.pop_back_val();

      // Skip if already processed.
      if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue;

      // Handle shapes.
      Value value = transitivelyRequestedInfo.value();
      if (!transitivelyRequestedInfo.isValueInfo()) {
        if (value.getDefiningOp<shape::AssumingOp>()) {
          forwardAssumingShape(value);
        } else if (auto broadcast =
                       value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
          forwardDynamicBroadcastInDimShape(broadcast);
        } else if (auto reshape =
                       value.getDefiningOp<mhlo::DynamicReshapeOp>()) {
          forwardDynamicReshapeShape(reshape);
        } else if (value.getDefiningOp<mhlo::ReduceOp>()) {
          forwardReduceShape(value);
        } else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) {
          forwardTransposeShape(transpose);
        } else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) {
          forwardSelectShape(select);
        } else if (value.getDefiningOp() &&
                   value.getDefiningOp()
                       ->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
          forwardSameOperandsShape(value);
        } else {
          forwardUnknownShape(value);
        }
        continue;
      }

      // Handle values.
      assert(transitivelyRequestedInfo.isValueInfo() &&
             "Expect value info at this point.");
      if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) {
        forwardShapeOf(shapeof);
      } else if (auto bcast = value.getDefiningOp<shape::BroadcastOp>()) {
        forwardBroadcast(bcast);
      } else if (auto numElements =
                     value.getDefiningOp<shape::NumElementsOp>()) {
        forwardNumElements(numElements);
      } else if (auto dim = value.getDefiningOp<tensor::DimOp>()) {
        forwardDim(dim);
      } else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) {
        forwardIndexCast(cast);
      } else if (auto fromElements =
                     value.getDefiningOp<tensor::FromElementsOp>()) {
        forwardTensorFromElements(fromElements);
      } else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
        forwardTensorExtract(extract);
      } else if (auto add = value.getDefiningOp<mhlo::AddOp>()) {
        forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; });
      } else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) {
        forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; });
      } else if (auto add = value.getDefiningOp<arith::AddIOp>()) {
        forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; });
      } else if (auto mul = value.getDefiningOp<arith::MulIOp>()) {
        forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; });
      } else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) {
        forwardConcatenate(concat);
      } else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) {
        forwardReshape(reshape);
      } else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) {
        forwardSlice(slice);
      } else if (matchPattern(value, m_Constant())) {
        forwardConstant(value);
      } else {
        forwardUnknown(value);
      }
    }
  }

 private:
  // ===
  // Functions that traverse the shapes of operations.
  // ===

  void backwardAssumingShape(Value op) {
    auto assumingOp = op.getDefiningOp<shape::AssumingOp>();
    auto number = op.cast<OpResult>().getResultNumber();
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(
        cast<shape::AssumingYieldOp>(
            assumingOp.getDoRegion().back().getTerminator())
            .getOperand(number)));
  }
  void forwardAssumingShape(Value op) {
    auto assumingOp = op.getDefiningOp<shape::AssumingOp>();
    auto number = op.cast<OpResult>().getResultNumber();
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    dims = lookup(ShapeOrValueInfo::getShapeInfoOf(
        cast<shape::AssumingYieldOp>(
            assumingOp.getDoRegion().back().getTerminator())
            .getOperand(number)));
  }
  void backwardBroadcast(shape::BroadcastOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    for (Value s : op.getShapes())
      backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(s));
  }
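  // A worked sketch of the rule implemented below (hypothetical shapes, not
  // from a test case): broadcasting the shape tensors [1, s0] and [s1] is
  // evaluated per result dimension. Dimension 0 only receives the constant 1,
  // the neutral element, so it stays 1. Dimension 1 receives s0 and s1, which
  // are neither equal nor known to differ from 1, so the analysis gives up on
  // that component and assigns it a fresh symbol.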
  void forwardBroadcast(shape::BroadcastOp op) {
    auto *ctx = op.getContext();

    // Get operands' info.
    SmallVector<ArrayRef<SymbolicExpr>> argsInfo =
        llvm::to_vector(llvm::map_range(op.getShapes(), [&](Value s) {
          return lookup(ShapeOrValueInfo::getValueInfoOf(s));
        }));

    // Determine broadcasted rank.
    size_t rank = 0;
    for (auto &info : argsInfo) rank = std::max(rank, info.size());

    // Evaluate broadcast per result dimension.
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    for (size_t i = 0; i < rank; ++i) {
      // Initialize with the neutral element.
      SymbolicExpr bcastedExpr;
      bcastedExpr.expr = getAffineConstantExpr(1, ctx);

      // Consider all the operands.
      for (auto &info : argsInfo) {
        // Find the corresponding symbolic expression for the ith result
        // dimension, if the operand contributes.
        size_t argRank = info.size();
        if (i + argRank < rank) continue;
        size_t j = i + argRank - rank;
        SymbolicExpr expr = info[j];

        // Dimensions of size one are neutral.
        if (expr.isConstant(1)) continue;

        // If a dimension is known not to be 1, we can use this expression.
        if (expr.isKnownNotOne()) {
          bcastedExpr = expr;
          break;
        }

        // If all other dimensions were neutral, try using this expression.
        if (bcastedExpr.isConstant(1)) {
          bcastedExpr = expr;
          continue;
        }

        // If we have contradicting expressions, give up and create a new
        // symbol.
        if (bcastedExpr != expr) {
          bcastedExpr.expr = getAffineSymbolExpr(0, ctx);
          bcastedExpr.symbols = {{ShapeOrValueInfo::getValueInfoOf(op), i}};
          break;
        }
      }

      dims.push_back(bcastedExpr);
    }
    assert(dims.size() == rank && "expect one expression per dimension");
  }
  void backwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.output_dimensions()));
  }
  void forwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    dims = lookup(ShapeOrValueInfo::getValueInfoOf(op.output_dimensions()));
  }
  void backwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.output_shape()));
  }
  void forwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) {
    auto rankedTy = op.getResult().getType().cast<RankedTensorType>();
    auto shapeDims =
        lookup(ShapeOrValueInfo::getValueInfoOf(op.output_shape()));
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    dimsFromStaticShape(rankedTy, shapeDims, &dims);
  }
  void backwardReduceShape(Value op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    auto reduceOp = op.getDefiningOp<mhlo::ReduceOp>();
    if (reduceOp.operands().size() == 1) {
      backwardsWorklist.push_back(
          ShapeOrValueInfo::getShapeInfoOf(reduceOp.operands().back()));
    }
  }
  void forwardReduceShape(Value op) {
    auto reduceOp = op.getDefiningOp<mhlo::ReduceOp>();
    if (reduceOp.operands().size() != 1) return forwardUnknownShape(op);
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    for (const auto &dim : llvm::enumerate(lookup(
             ShapeOrValueInfo::getShapeInfoOf(reduceOp.operands().back())))) {
      if (!llvm::is_contained(reduceOp.dimensions(), dim.index()))
        dims.push_back(dim.value());
    }
  }
  void backwardTransposeShape(mhlo::TransposeOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.operand()));
  }
  void forwardTransposeShape(mhlo::TransposeOp op) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.operand()));
    auto elem = op.permutation().cast<DenseIntElementsAttr>();
    for (const auto &val : elem) dims.push_back(in[val.getZExtValue()]);
  }
  void backwardSelectShape(mhlo::SelectOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.on_true()));
  }
  void forwardSelectShape(mhlo::SelectOp op) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op));
    // Forward the `on_true` operand, it has the same shape as the output.
    dims = lookup(ShapeOrValueInfo::getShapeInfoOf(op.on_true()));
  }
  void backwardSameOperandsAndResultShape(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(v));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getShapeInfoOf(v.getDefiningOp()->getOperand(0)));
  }
  void forwardSameOperandsShape(Value v) {
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(v));
    dims = lookup(
        ShapeOrValueInfo::getShapeInfoOf(v.getDefiningOp()->getOperand(0)));
  }
  void backwardBlockArgumentShape(BlockArgument argument) {
    // JitRT uses `rt.symbolic_shape` to describe identical dimensions. Make
    // use of that when it exists.
    //
    // Example:
    //   func @compute(
    //     %arg0: tensor<?xf32> {rt.symbolic_shape = dense<-2> : tensor<1xi64>},
    //     %arg1: tensor<?xf32> {rt.symbolic_shape = dense<-2> : tensor<1xi64>})
    //   { ... }
    //
    // A symbolic dimension is encoded as a negative value smaller than `-1`.
    // The concrete value is not known at compile time; in this particular
    // example it is only known that both arguments have the same shape.
    //
    // TODO(ezhulenev): Add symbolic shape attribute verifier to the jitrt
    // dialect.
    if (auto func = dyn_cast_or_null<func::FuncOp>(
            argument.getOwner()->getParentOp())) {
      if (auto shape = func.getArgAttrOfType<DenseIntElementsAttr>(
              argument.getArgNumber(), "rt.symbolic_shape")) {
        auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(argument));
        auto id = getAffineSymbolExpr(0, argument.getContext());
        for (const auto &symbol : llvm::enumerate(shape.getValues<ssize_t>())) {
          dims.emplace_back();
          auto &dim = dims.back();
          if (symbol.value() >= 0) {
            dim.expr =
                getAffineConstantExpr(symbol.value(), argument.getContext());
          } else {
            auto it = symbolicShapeConstraintsMap->try_emplace(
                symbol.value(),
                Symbol{ShapeOrValueInfo::getShapeInfoOf(argument),
                       symbol.index()});
            dim.symbols.push_back(it.first->second);
            dim.expr = id;
          }
        }
        return;
      }
    }
    forwardUnknownShape(argument);
  }
  void backwardUnknownShape(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(v));
  }
  void forwardUnknownShape(Value v) {
    auto rankedTy = v.getType().dyn_cast<RankedTensorType>();
    if (!rankedTy) return;
    auto id = getAffineSymbolExpr(0, v.getContext());
    auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(v));
    return dimsFromStaticShape(
        rankedTy,
        [&](size_t i) {
          SymbolicExpr d;
          d.symbols.push_back({ShapeOrValueInfo::getShapeInfoOf(v), i});
          d.expr = id;
          return d;
        },
        &dims);
  }

  // ===
  // Functions that traverse values. These can be shape tensors (e.g., of type
  // tensor<3xindex>) or interesting scalars (e.g., of type index).
  // ===

  void backwardShapeOf(shape::ShapeOfOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.getArg()));
  }
  void forwardShapeOf(shape::ShapeOfOp op) {
    auto rankedTy = op.getArg().getType().cast<RankedTensorType>();
    auto arg = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getArg()));
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    return dimsFromStaticShape(rankedTy, arg, &dims);
  }
  void backwardNumElements(shape::NumElementsOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.getShape()));
  }
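  // A worked sketch of the computation below (hypothetical shape, not from a
  // test case): for a shape described as [2, s0, 3, s1], the constant factors
  // fold into the concrete product 6, the symbolic factors accumulate into
  // s0 * s1, and the resulting single value component is 6 * (s0 * s1).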
  void forwardNumElements(shape::NumElementsOp op) {
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getShape()));

    // Accumulate the product symbolically, and concretely where possible.
    int64_t concreteProduct = 1;
    SymbolicExpr dim;
    for (auto &it : in) {
      // For constant expressions, we can accumulate a concrete product.
      if (auto cexpr = it.expr.dyn_cast<AffineConstantExpr>()) {
        assert(cexpr.getValue() > 0 && "shape value must be positive");
        concreteProduct *= cexpr.getValue();
        continue;
      }

      // Simply copy the first symbolic factor.
      if (!dim.expr) {
        dim = it;
        continue;
      }

      // Multiply in the remaining symbolic factors, shifting their symbols
      // past those already accumulated in `dim`.
      dim.expr = dim.expr * it.expr.shiftSymbols(it.symbols.size(),
                                                 dim.symbols.size());
      dim.symbols.append(it.symbols);
    }

    // Combine concrete and symbolic product.
    if (concreteProduct != 1 || !dim.expr) {
      auto cexpr = getAffineConstantExpr(concreteProduct, op.getContext());
      if (dim.expr)
        dim.expr = cexpr * dim.expr;
      else
        dim.expr = cexpr;
    }

    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    dims.push_back(dim);
  }
  void backwardDim(tensor::DimOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getShapeInfoOf(op.getSource()));
  }
  void forwardDim(tensor::DimOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    if (auto index = op.getIndex().getDefiningOp<arith::ConstantOp>()) {
      int64_t i = index.getValue().cast<IntegerAttr>().getInt();
      auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getSource()));
      dims.push_back({in[i].symbols, in[i].expr});
    } else {
      forwardUnknown(op);
    }
  }
  template <typename Op>
  void backwardBinOp(Op op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    // TODO(jpienaar): Switch to named accessors when MHLO uses prefixed form.
    backwardsWorklist.append(
        {ShapeOrValueInfo::getValueInfoOf(op.getOperand(0)),
         ShapeOrValueInfo::getValueInfoOf(op.getOperand(1))});
  }
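  // A brief sketch of the symbol bookkeeping below (hypothetical operands,
  // not from a test case): if both operand components are written as s0 over
  // their own symbol lists, the rhs expression is first shifted to s1 so that
  // the two symbol lists can be concatenated, and the combined component of
  // an addition becomes s0 + s1.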
  template <typename Op, typename Combiner>
  void forwardBinOp(Op op, Combiner &&combiner) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    // TODO(jpienaar): Switch to named accessors when MHLO uses prefixed form.
    auto lhs = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand(0)));
    auto rhs = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand(1)));
    for (int64_t i = 0, e = dim0size(op.getType()); i != e; ++i) {
      dims.emplace_back();
      auto &dim = dims.back();
      dim.symbols.append(lhs[i].symbols);
      dim.symbols.append(rhs[i].symbols);
      dim.expr = combiner(lhs[i].expr,
                          rhs[i].expr.shiftSymbols(rhs[i].symbols.size(),
                                                   lhs[i].symbols.size()));
    }
  }
  void backwardIndexCast(arith::IndexCastOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.getIn()));
  }
  void forwardIndexCast(arith::IndexCastOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getIn()));
    for (int64_t i = 0, e = dim0size(op.getType()); i != e; ++i) {
      // This intentionally does not model the truncation/zero-extension
      // behavior of index_cast. While not strictly correct, it does not
      // matter for shape computations.
      dims.push_back({in[i].symbols, in[i].expr});
    }
  }
  void backwardTensorFromElements(tensor::FromElementsOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands())
      backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(operand));
  }
  void forwardTensorFromElements(tensor::FromElementsOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands()) {
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand));
      assert(in.size() == 1);
      dims.push_back({in[0].symbols, in[0].expr});
    }
  }
  void backwardTensorExtract(tensor::ExtractOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(
        ShapeOrValueInfo::getValueInfoOf(op.getTensor()));
  }
  void forwardTensorExtract(tensor::ExtractOp op) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    assert(op.getIndices().size() == 1);
    if (auto index =
            op.getIndices().front().getDefiningOp<arith::ConstantOp>()) {
      int64_t i = index.getValue().cast<IntegerAttr>().getInt();
      // We assume this is in bounds.
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getTensor()));
      dims.push_back({in[i].symbols, in[i].expr});
    } else {
      forwardUnknown(op);
    }
  }
  void backwardConstant(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(v));
  }
  void forwardConstant(Value v) {
    IntegerAttr intAttr;
    DenseIntElementsAttr denseAttr;
    if (matchPattern(v, m_Constant(&denseAttr))) {
      auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v));
      for (uint64_t i = 0, e = dim0size(v.getType()); i != e; ++i) {
        dims.emplace_back();
        auto &dim = dims.back();
        dim.expr = getAffineConstantExpr(
            denseAttr.getValues<APInt>()[i].getSExtValue(), v.getContext());
      }
    } else if (matchPattern(v, m_Constant(&intAttr))) {
      auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v));
      dims.emplace_back();
      auto &dim = dims.back();
      dim.expr = getAffineConstantExpr(intAttr.getInt(), v.getContext());
    } else {
      forwardUnknown(v);
    }
  }
  void backwardConcatenate(mhlo::ConcatenateOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands())
      backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(operand));
  }
  void forwardConcatenate(mhlo::ConcatenateOp op) {
    for (auto operand : op.getOperands()) {
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand));
      if (in.size() != 1) return forwardUnknown(op);
    }
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    for (auto operand : op.getOperands()) {
      auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand));
      dims.push_back({in[0].symbols, in[0].expr});
    }
  }
  void backwardReshape(mhlo::ReshapeOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.operand()));
  }
  void forwardReshape(mhlo::ReshapeOp op) {
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.operand()));
    if (in.size() != 1) return forwardUnknown(op);
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    dims.push_back({in[0].symbols, in[0].expr});
  }
  void backwardSlice(mhlo::SliceOp op) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op));
    backwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.operand()));
  }
  void forwardSlice(mhlo::SliceOp op) {
    // Only handle slices equivalent to an extract.
    if (!op.getType().hasStaticShape({1})) {
      return forwardUnknown(op);
    }
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op));
    auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.operand()));
    auto elem = op.start_indices().cast<DenseIntElementsAttr>();
    auto i = (*elem.begin()).getZExtValue();
    if (i >= in.size()) {  // Bounds check.
      return forwardUnknown(op);
    }
    dims.push_back({in[i].symbols, in[i].expr});
  }
  void backwardUnknown(Value v) {
    forwardsWorklist.push_back(ShapeOrValueInfo::getValueInfoOf(v));
  }
  void forwardUnknown(Value v) {
    auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v));
    auto id = getAffineSymbolExpr(0, v.getContext());
    for (size_t i = 0, e = dim0size(v.getType()); i != e; ++i) {
      dims.emplace_back();
      auto &dim = dims.back();
      dim.symbols.push_back({ShapeOrValueInfo::getValueInfoOf(v), i});
      dim.expr = id;
    }
  }

  // ===
  // Helpers
  // ===

  static void dimsFromStaticShape(
      RankedTensorType rankedTy,
      llvm::function_ref<SymbolicExpr(int64_t)> fallback,
      std::vector<SymbolicExpr> *mergedDims) {
    auto *ctx = rankedTy.getContext();
    for (int64_t i = 0, e = rankedTy.getRank(); i != e; ++i) {
      if (rankedTy.isDynamicDim(i)) {
        mergedDims->push_back(fallback(i));
      } else {
        mergedDims->emplace_back();
        auto &d = mergedDims->back();
        d.expr = getAffineConstantExpr(rankedTy.getDimSize(i), ctx);
      }
    }
  }

  static void dimsFromStaticShape(RankedTensorType rankedTy,
                                  ArrayRef<SymbolicExpr> fallback,
                                  std::vector<SymbolicExpr> *mergedDims) {
    return dimsFromStaticShape(
        rankedTy, [&](int64_t i) { return fallback[i]; }, mergedDims);
  }

  // Return the size of the first dimension. Returns 1 for scalars.
  static int64_t dim0size(Type type) {
    if (auto rankedType = type.dyn_cast<RankedTensorType>())
      return rankedType.getRank() == 0 ? 1 : rankedType.getDimSize(0);
    return 1;
  }

  // Retrieves the existing information from the cache.
  ArrayRef<SymbolicExpr> lookup(ShapeOrValueInfo requestedInfo) {
    auto i = symbolicExprsMap->find(requestedInfo);
    assert(i != symbolicExprsMap->end() && "op not processed yet?");
    return llvm::makeArrayRef(i->second);
  }

  // Inserts a new entry into the cache and returns a reference to its result
  // components.
  std::vector<SymbolicExpr> &insert(ShapeOrValueInfo requestedInfo) {
    auto i = symbolicExprsMap->try_emplace(requestedInfo);
    assert(i.second && "op already processed?");
    return i.first->second;
  }

  SymbolicExprsMap *symbolicExprsMap;
  SymbolicShapeConstraintsMap *symbolicShapeConstraintsMap;

  // Worklists for the forward and backward passes.
  SmallVector<ShapeOrValueInfo> backwardsWorklist;
  SmallVector<ShapeOrValueInfo> forwardsWorklist;
};
}  // namespace

void ShapeComponentAnalysis::compute(ShapeOrValueInfo requestedInfo) {
  ShapeVisitor(&symbolicExprsMap, &symbolicShapeConstraintsMap)
      .visit(requestedInfo);
}

Optional<ArrayRef<SymbolicExpr>>
ShapeComponentAnalysis::GetShapeInfo(Value value) {
  auto request = ShapeOrValueInfo::getShapeInfoOf(value);
  compute(request);
  auto found = symbolicExprsMap.find(request);
  if (found == symbolicExprsMap.end()) return {};
  return llvm::makeArrayRef(found->second);
}

Optional<ArrayRef<SymbolicExpr>>
ShapeComponentAnalysis::GetValueInfo(Value shape) {
  auto request = ShapeOrValueInfo::getValueInfoOf(shape);
  compute(request);
  auto found = symbolicExprsMap.find(request);
  if (found == symbolicExprsMap.end()) return {};
  return llvm::makeArrayRef(found->second);
}

void ShapeComponentAnalysis::reset() {
  symbolicExprsMap.clear();
  symbolicShapeConstraintsMap.clear();
}

bool SymbolicExpr::isConstant(int64_t value) const {
  return expr.isa<AffineConstantExpr>() &&
         expr.cast<AffineConstantExpr>().getValue() == value;
}
bool SymbolicExpr::isKnownNotNegativeOne() const {
  // If the symbol comes from a shape, it cannot be -1. Also allow results of
  // shape_of, compute_reshape_shape, and num_elements. This is correct, not
  // complete.
  auto isGoodSymbol = [](const Symbol &symbol) {
    if (symbol.source.isShapeInfo()) return true;
    Operation *op = symbol.source.value().getDefiningOp();
    if (op == nullptr) return false;
    return llvm::isa<shape::ShapeOfOp, mhlo::ComputeReshapeShapeOp,
                     shape::NumElementsOp>(op);
  };

  // For constants we know whether they are -1 or not. Checking the sign is
  // sufficient here and allows for reuse below. This is correct, not complete.
  auto isGoodSymbolOrGoodConstantExpr = [&](AffineExpr expr) {
    if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      return isGoodSymbol(symbols[symExpr.getPosition()]);
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() >= 0;
    return false;
  };

  if (isGoodSymbolOrGoodConstantExpr(expr)) return true;

  // Multiplying non-negative symbols and non-negative constants always gives
  // a non-negative result, which in particular cannot be -1. This is correct,
  // not complete.
  // TODO(kramerb): Could the analysis provide a generic interface for this?
  if (auto bexpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
    return bexpr.getKind() == AffineExprKind::Mul &&
           isGoodSymbolOrGoodConstantExpr(bexpr.getLHS()) &&
           isGoodSymbolOrGoodConstantExpr(bexpr.getRHS());
  }

  return false;
}

bool SymbolicExpr::isKnownNotOne() const {
  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
    return constExpr.getValue() != 1;
  }
  return false;
}

llvm::Optional<Symbol> SymbolicExpr::singleton() const {
  if (expr.isa<AffineSymbolExpr>() &&
      expr.cast<AffineSymbolExpr>().getPosition() == 0) {
    assert(symbols.size() == 1);
    return symbols[0];
  }
  return llvm::None;
}

void SymbolicExpr::dump(llvm::raw_ostream &os) const {
  expr.print(os);
  if (!symbols.empty()) os << " with";
  os << "\n";
  if (symbols.empty()) return;
  for (const auto &sym : llvm::enumerate(symbols)) {
    os.indent(4);
    os << 's' << sym.index() << " = ";
    if (!sym.value().source.isValueInfo()) os << "shapeof(";
    sym.value().source.value().print(os);
    if (!sym.value().source.isValueInfo()) os << ")";
    os << '[' << sym.value().index << "]\n";
  }
}