/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the MHLO dialect.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <set>
#include <unordered_map>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"

namespace mlir {
#include "hlo_patterns.cc.inc"
}  // namespace mlir

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.cc.inc"

namespace mlir {
namespace mhlo {
namespace {
void createArgs(ArrayRef<OpAsmParser::UnresolvedOperand> operands,
                ArrayRef<Type> types,
                SmallVector<OpAsmParser::Argument>& args) {
  for (auto argAndType : llvm::zip(operands, types)) {
    auto& arg = args.emplace_back();
    arg.ssaName = std::get<0>(argAndType);
    arg.type = std::get<1>(argAndType);
  }
}

const auto hasDuplicates = [](SmallVector<int64_t>& nums) {
  if (!llvm::is_sorted(nums)) std::sort(nums.begin(), nums.end());
  auto* last = std::unique(nums.begin(), nums.end());
  return last != nums.end();
};

//===----------------------------------------------------------------------===//
// Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===//

// This is an upper limit on how many elements can be folded by an op folder.
// This limit doesn't apply to some special cases like adding a zero,
// multiplying by one, or performing many operations with splats.
constexpr int64_t kFoldOpEltLimit = 65536;

// Clamps value to the range [lower, upper]. Requires lower <= upper.
template <typename T>
static T clamp(const T& value, const T& lower, const T& upper) {
  assert(lower <= upper);
  return std::max(lower, std::min(value, upper));
}

// Verifies that the dimension attribute for the op correctly indexes into the
// operand or result shape.
template <typename OpT>
static LogicalResult verifyDimAttr(OpT op) {
  int64_t rank = -1;
  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else {
    return success();
  }

  int64_t dim = op.dimension();
  if (dim < 0 || dim >= rank)
    return op.emitOpError() << "requires dimension attribute in range [0, "
                            << rank << "); found (" << dim << ")";
  return success();
}

// Given the start indices and slice sizes for a dynamic-slice that can be
// converted to a static slice, returns the limits for the static slice.
DenseIntElementsAttr buildSliceLimits(DenseIntElementsAttr startIndices,
                                      DenseIntElementsAttr sliceSizes,
                                      Builder* builder) {
  SmallVector<int64_t, 4> sliceLimits;
  for (int64_t i = 0; i < sliceSizes.getNumElements(); ++i) {
    int64_t startIndex = startIndices.getValues<IntegerAttr>()[i].getInt();
    int64_t sliceSize = sliceSizes.getValues<IntegerAttr>()[i].getInt();
    sliceLimits.push_back(startIndex + sliceSize);
  }
  return builder->getI64TensorAttr(sliceLimits);
}
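
// For intuition, a worked example of the helper above (illustrative values):
// with startIndices = dense<[1, 2]> and sliceSizes = dense<[2, 3]>, the
// computed limits are [1 + 2, 2 + 3], i.e. dense<[3, 5]>.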

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
                                Region& region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block* block = &region.front();
  Operation* terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.mergeBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
}

#include "mhlo_canonicalize.inc"

// Check if the dimension size is dynamic.
inline static bool isDynamicDimSize(int64_t val) {
  return val == ShapedType::kDynamicSize;
}

// Common shape function helper for RngNormal and RngUniform.
static LogicalResult rngInferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  if (operands.size() != 3)
    return emitOptionalError(location, "expected 3 operands");

  SmallVector<int64_t> shapeVector;
  Value shapeOperand = operands[2];
  auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
  Type elementType = getElementTypeOrSelf(operands[1]);

  // Operand `shape` (1D by ODS) may or may not be a constant. If `shape` is:
  // 1. not constant and has a dynamic dim (tensor<?x>): infer tensor<*x>.
  // 2. not constant nor dynamic (e.g. tensor<3xi64>): infer tensor<?x?x?x>.
  // 3. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>.

  // Match to check whether the `shape` operand is a constant.
  DenseIntElementsAttr shape;
  if (!matchPattern(shapeOperand, m_Constant(&shape))) {
    int size = shapeOperandType.getDimSize(0);
    if (isDynamicDimSize(size)) {
      inferredReturnShapes.emplace_back(elementType);
      return success();
    }
    shapeVector.resize(size, ShapedType::kDynamicSize);
    inferredReturnShapes.emplace_back(shapeVector, elementType);
    return success();
  }

  // `shape` operand is a constant.
  shapeVector.reserve(shape.size());
  for (const APInt& dimSize : shape.getValues<APInt>())
    shapeVector.push_back(dimSize.getSExtValue());
  inferredReturnShapes.emplace_back(shapeVector, elementType);
  return success();
}

// Returns a new scalar integer value having type `type`. Here `type` must be
// an integer or index type.
Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
  if (type == value.getType()) return value;
  assert(type.isIndex() || value.getType().isIndex());
  return b.create<arith::IndexCastOp>(loc, type, value);
}

DenseElementsAttr reshape(DenseElementsAttr attr, ShapedType newType) {
  // TODO(b/232866626): DenseElementsAttr::reshape is broken for bool splats.
  // Once that ticket is fixed, we can remove this conditional.
  if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) {
    auto splatValue = attr.getValues<bool>()[0];
    return DenseElementsAttr::get(newType, {splatValue});
  }
  return attr.reshape(newType);
}

//===----------------------------------------------------------------------===//
// Utilities for verifiers
//===----------------------------------------------------------------------===//

// Convert a 1D dense int64 attribute to a list of values.
SmallVector<int64_t> convertDenseIntAttr(
    llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr) {
  if (!optionalAttr.has_value()) return SmallVector<int64_t>{};

  mlir::DenseIntElementsAttr attr = *optionalAttr;
  auto values = attr.getValues<int64_t>();
  return {values.begin(), values.end()};
}

// Convert a 1D or Nx2 dense int64 attribute to a list of tuples.
FailureOr<SmallVector<std::pair<int64_t, int64_t>>> convertNx2Attribute(
    llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr, Location loc) {
  if (!optionalAttr.has_value())
    return SmallVector<std::pair<int64_t, int64_t>>{};
  mlir::DenseIntElementsAttr attr = *optionalAttr;

  auto attrType = attr.getType().cast<RankedTensorType>();  // ensured by ODS.
  if (attrType.getRank() > 1) {
    if (attrType.getRank() != 2 || attrType.getShape()[1] != 2)
      return (mlir::emitError(loc) << "expects the shape of padding-attribute "
                                      "to be {N, 2}, but got {"
                                   << attrType.getShape() << "}.",
              failure());
  } else {
    // Padding values can be provided as a 1D vector as well.
    if (attr.getValues<int64_t>().size() % 2 != 0)
      return (mlir::emitError(loc)
                  << "expects the padding-entries to have even number of "
                     "elements, but got "
                  << attr.getValues<int64_t>().size() << " elements.",
              failure());
  }

  auto it = attr.getValues<int64_t>().begin();
  SmallVector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
  for (auto& item : out) {
    int64_t first = *it;
    ++it;
    int64_t second = *it;
    ++it;
    item = {first, second};
  }
  return out;
}
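
// For intuition, a worked example of the conversion above (illustrative
// values): dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> and its flattened 1D
// form dense<[0, 1, 2, 3]> : tensor<4xi64> both yield {(0, 1), (2, 3)}.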

// If a window with the given bound in some dimension is dilated with the given
// dilation factor in that dimension, then the value returned is the bound for
// the array in that dimension after dilation.
//
// For a 1D array with 3 entries 1, 2, 3, a dilation factor of 2 yields a new
// window with values 1, x, 2, x, 3, where x indicates holes left by the
// dilation. So DilatedBound(3, 2) == 5.
int64_t dilatedBound(int64_t bound, int64_t dilation) {
  assert(bound >= 0 && "The dimension to dilate must be >= 0");
  if (bound == 0) return 0;

  // Suppose the array has three entries 123 and the dilation factor is 4. Then
  // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
  // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
  // add 1 to account for the final input element.
  return (bound - 1) * dilation + 1;
}

// Returns the number of valid positions of a window with the given size and
// stride within an array with the given bound. This is the bound of an output
// array with one element per valid position of the window.
//
// For example, for arguments of (bound=5, window_size=2, stride=2), the
// returned value is 2. There are valid positions at offset 0 and offset 2,
// while offset 4 is not valid since the window's last entry would be at 5,
// which is beyond the bound of 5.
int64_t stridedBound(int64_t bound, int64_t windowSize, int64_t stride) {
  assert(windowSize >= 0 && "Expected window size to be >= 0");
  assert(bound >= 0 && "Expected bound to be >= 0");

  if (bound == 0 || windowSize > bound) return 0;

  // Without considering stride, the maximum valid offset is bound -
  // window_size. Taking stride into account, the valid offsets then have the
  // form q * stride for q = 0, ..., Q such that q * stride <= bound -
  // window_size. This implies that Q equals floor((bound - window_size) /
  // stride). There are Q + 1 valid values of q, yielding the formula below.
  return (bound - windowSize) / stride + 1;
}

// WindowDimension describes how the kernel window moves across the base area
// in a particular dimension, as in an operation such as convolution: the
// window is moved across the base area, and for each position of the window a
// computation is performed. The fields below describe the window and its
// movement across the base area in one dimension.
struct WindowDimension {
  int64_t size = 0;
  int64_t stride = 1;
  int64_t paddingLow = 0;
  int64_t paddingHigh = 0;
  int64_t windowDilation = 1;
  int64_t baseDilation = 1;
  bool windowReversal = false;
};

// Verifies various properties of window-attributes (viz., stride, padding,
// lhs_dilation and rhs_dilation) and collects all the window-attributes for
// each kernel spatial dimension.
FailureOr<SmallVector<WindowDimension>>
verifyWindowAttributesAndInferWindowDimensions(
    ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides,
    ArrayRef<std::pair<int64_t, int64_t>> padding,
    ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
    Location loc) {
  const auto verifySize = [&](const size_t attrSize,
                              StringRef attrName) -> LogicalResult {
    if (attrSize == 0 || attrSize == windowDimensions.size()) return success();
    return mlir::emitError(loc)
           << "expects " << attrName
           << " to have same dimension-size as size of window dimensions ("
           << windowDimensions.size() << "), but got: " << attrSize << ".";
  };

  if (failed(verifySize(windowStrides.size(), "window-strides")))
    return failure();
  if (failed(verifySize(lhsDilation.size(), "base-dilation factors")))
    return failure();
  if (failed(verifySize(rhsDilation.size(), "window-dilation factors")))
    return failure();
  if (failed(verifySize(padding.size(), "padding-entries"))) return failure();

  SmallVector<WindowDimension> window(windowDimensions.size());
  for (size_t i = 0; i < windowDimensions.size(); i++) {
    WindowDimension& dim = window[i];

    dim.size = windowDimensions[i];
    if (!isDynamicDimSize(dim.size) && dim.size <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive value for " << i
                  << "-th window dimension, but got " << dim.size << ".",
              failure());

    if (!windowStrides.empty()) dim.stride = windowStrides[i];
    if (dim.stride <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive stride for " << i
                  << "-th window dimension, but got " << dim.stride << ".",
              failure());

    if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i];
    if (dim.baseDilation <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive base dilation factor "
                     "for "
                  << i << "-th window dimension, but got " << dim.baseDilation
                  << ".",
              failure());

    if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i];
    if (dim.windowDilation <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive window dilation factor "
                     "for "
                  << i << "-th window dimension, but got " << dim.windowDilation
                  << ".",
              failure());

    if (!padding.empty()) {
      dim.paddingLow = padding[i].first;
      dim.paddingHigh = padding[i].second;
    }
  }

  return window;
}

// Infer the shape of the output window.
// For each dimension d,
//   output-window-shape[d] =
//       stridedBound(padding_low + dilatedBound(base_shape[d], base_dilation)
//                        + padding_high,
//                    dilatedBound(window_shape[d], window_dilation),
//                    stride)
// where (padding_low, padding_high) is the padding-pair for d.
SmallVector<int64_t> inferWindowOutputShape(
    const ArrayRef<int64_t> baseShape, const ArrayRef<WindowDimension> window) {
  assert(baseShape.size() == window.size() &&
         "Size of window dimensions must match the size of base shape.");

  SmallVector<int64_t> outputDimensions(window.size());
  for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i) {
    if (isDynamicDimSize(baseShape[i]) || isDynamicDimSize(window[i].size)) {
      outputDimensions[i] = ShapedType::kDynamicSize;
    } else {
      const auto& dim = window[i];

      const int64_t dilatedBase = dilatedBound(baseShape[i], dim.baseDilation);
      const int64_t paddedDilatedBase =
          dim.paddingLow + dilatedBase + dim.paddingHigh;
      const int64_t dilatedWindow = dilatedBound(dim.size, dim.windowDilation);

      outputDimensions[i] =
          stridedBound(paddedDilatedBase, dilatedWindow, dim.stride);
    }
  }

  return outputDimensions;
}
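
// For intuition, a minimal sketch of how the helpers above compose
// (illustrative values only; not part of the dialect logic):
//
//   WindowDimension dim;
//   dim.size = 2;
//   dim.stride = 2;
//   dim.paddingLow = 1;
//   dim.paddingHigh = 1;
//   // dilatedBound(4, 1) = 4; padded base = 1 + 4 + 1 = 6;
//   // dilatedBound(2, 1) = 2; stridedBound(6, 2, 2) = (6 - 2) / 2 + 1 = 3.
//   SmallVector<int64_t> out = inferWindowOutputShape({4}, {dim});  // {3}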

// Return true if type1 and type2 are tensors and have the same
// element-type, else return false. With float element-types, ignore comparing
// floating-point precision if ignoreFpPrecision is true.
bool tensorsHaveSameElType(Type type1, Type type2, bool ignoreFpPrecision) {
  auto tensorTy1 = type1.dyn_cast<TensorType>();
  auto tensorTy2 = type2.dyn_cast<TensorType>();

  if (!tensorTy1 || !tensorTy2) return false;

  if (ignoreFpPrecision && tensorTy1.getElementType().isa<FloatType>() &&
      tensorTy2.getElementType().isa<FloatType>())
    return true;

  return tensorTy1.getElementType() == tensorTy2.getElementType();
}

// Return true if type1 and type2 are shape-compatible and have the same
// element type. If 'ignoreFpPrecision' is true, then allow floats with
// different precisions while checking element-types.
bool compatibleShapeAndElementType(Type type1, Type type2,
                                   bool ignoreFpPrecision = false) {
  if (failed(verifyCompatibleShape(type1, type2))) return false;
  return tensorsHaveSameElType(type1.cast<ShapedType>(),
                               type2.cast<ShapedType>(), ignoreFpPrecision);
}

LogicalResult verifyReducerShape(
    Location loc, Block& block, ArrayRef<TensorType> inputArgTypes,
    ArrayRef<TensorType> initValueTypes, int64_t numInputs,
    ArrayRef<int64_t> allowedDimensions, bool allInputsUnranked,
    SmallVectorImpl<TensorType>& accumulatorSubShapes) {
  // Check that the number of reduction-region arguments matches with that of
  // reduce-op's arguments.
  if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2)
    return mlir::emitError(loc)
           << "Reduction-region must take " << numInputs * 2
           << " parameters, but takes " << block.getArguments().size()
           << " parameter(s)";

  // Check if the reduction-region produces non-zero outputs.
  if (block.getTerminator()->getOperands().empty())
    return mlir::emitError(loc)
           << "The reduction-region expected to return some value(s)";

  // Check that the reduction-region returns a list of tensors.
  // The number of result-tensors must match `numInputs`.
  if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) !=
      numInputs)
    return mlir::emitError(loc)
           << "Reduction-region here must produce " << numInputs
           << " tensors, but produces "
           << block.getTerminator()->getOperands().size() << " instead";

  for (Value retOperand : block.getTerminator()->getOperands()) {
    auto tensorTy = retOperand.getType().dyn_cast<TensorType>();
    if (!tensorTy)
      return mlir::emitError(loc) << "Reduction-region here must produce "
                                     "tensor-typed result(s), but produces "
                                  << retOperand.getType() << " instead";

    accumulatorSubShapes.push_back(tensorTy);
  }

  // Consider typical reduce-* op syntax:
  //
  //      op(I(i), V(j)):
  //       block(BI(i), BV(j)):
  //         ... some computation ...
  //         return(R(i))
  //
  // where
  //  I(i)  : i-th input of op
  //  V(j)  : j-th init-value of op
  //  BI(i) : i-th input of reducer-function
  //  BV(j) : j-th init-value of reducer-function
  //  R(i)  : i-th return-type
  //
  //  Note that: |I(i)| == |V(j)| == |BI(i)| == |BV(j)| == |R(i)|
  //
  //  Here are the type-constraints among V(j), BI(i), BV(j), and R(i).
  //    C1 : Check that BI(i) and R(i) have same shape and element-type.
  //    C2 : Check that BV(j) and R(i) have same shape and element-type.
  //    C3 : Check that V(j) and R(i) have same shape and element-type.
  //
  //  From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all
  //  have compatible shapes and element-types.
  //  The next check, C4, adds constraints on how the type of I(i) is related
  //  to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j):
  //
  //  C4.1 : Check that I(i) and BV(j) have same element-type.
  //  C4.2 : Check that shape of BV(j) is a 'sub-sequence' of
  //         'allowedDimensions'. 'allowedDimensions' is a list of dimensions
  //         which any of BI(i), BV(j), and R(i) is allowed to have.
  for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    // Check C1.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       block.getArgument(inputIdx).getType()))
      return mlir::emitError(loc)
             << "The type of reduction-region's parameter at index " << inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C2.
    if (!compatibleShapeAndElementType(
            accumulatorSubShapes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(),
            /*ignoreFpPrecision=*/true))
      return mlir::emitError(loc)
             << "The type of reduction-region's parameter at index "
             << numInputs + inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(numInputs + inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C3.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       initValueTypes[inputIdx],
                                       /*ignoreFpPrecision=*/true))
      return mlir::emitError(loc)
             << "The type of reduction-region's result type at index "
             << inputIdx
             << " differs from the op's corresponding init-value type: "
             << accumulatorSubShapes[inputIdx] << " vs "
             << initValueTypes[inputIdx];

    // Check C4.1.
    if (!tensorsHaveSameElType(
            inputArgTypes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(), true))
      return mlir::emitError(loc)
             << "The element-type of reduction-region's argument at index "
             << numInputs + inputIdx << " is expected to be "
             << inputArgTypes[inputIdx].getElementType() << ", but got "
             << block.getArgument(numInputs + inputIdx).getType()
             << " as its type.";

    // Check C4.2.
    Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
    auto blockArgTensorTy = blockArgType.cast<TensorType>();

    if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success();

    auto argShape = blockArgTensorTy.getShape();
    if (argShape.size() > allowedDimensions.size())
      return mlir::emitError(loc)
             << "The rank of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is expected to be <= " << allowedDimensions.size() << ", got "
             << argShape.size();

    int64_t argShapeIdx = 0;
    for (int64_t outputShapeIdx = 0;
         outputShapeIdx < static_cast<int64_t>(allowedDimensions.size()) &&
         argShapeIdx < static_cast<int64_t>(argShape.size());
         outputShapeIdx++)
      if (allowedDimensions[outputShapeIdx] == argShape[argShapeIdx])
        argShapeIdx++;

    if (argShapeIdx != static_cast<int64_t>(argShape.size()))
      return mlir::emitError(loc)
             << "The shape of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is not compatible with that of reduce-op's input-parameter "
                "at index "
             << inputIdx;
  }

  return success();
}
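
// For intuition on check C4.2 above (illustrative values): with
// allowedDimensions = [4, 5, 6], a reducer argument of shape [4, 6] is a
// valid sub-sequence, whereas shape [6, 4] is rejected: the single forward
// scan over allowedDimensions matches the leading 6 but can never match the
// trailing 4.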

unsigned potentiallyComplexBitwidth(Type type) {
  auto complexTy = type.dyn_cast<ComplexType>();
  return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
                   : type.getIntOrFloatBitWidth();
}
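
// For example (illustrative): a complex<f32> type yields 2 * 32 = 64, while a
// plain f32 yields 32.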
}  // namespace

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

void AllReduceOp::build(
    ::mlir::OpBuilder& ods_builder, ::mlir::OperationState& ods_state,
    ::mlir::Type result_type, ::mlir::Value operand,
    ::mlir::DenseIntElementsAttr replica_groups,
    /*optional*/ ::mlir::mhlo::ChannelHandleAttr channel_handle) {
  AllReduceOp::build(ods_builder, ods_state, result_type, operand,
                     replica_groups, channel_handle, nullptr);
}

//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//

LogicalResult ReduceScatterOp::verify() {
  if (failed(mlir::hlo::verifyReplicaGroups(*this, /*is_uniform_sized=*/true)))
    return failure();
  auto operandType = operand().getType().cast<TensorType>();
  bool operandTypeRanked = operandType.isa<RankedTensorType>();
  Block& block = computation().front();
  SmallVector<TensorType> accumulatorSubshapes;
  if (failed(verifyReducerShape(
          this->getLoc(), block, {operandType},
          {RankedTensorType::get({}, operandType.getElementType())},
          /*numInputs=*/1, /*allowedDimensions=*/{},
          /*allInputsUnranked=*/!operandTypeRanked, accumulatorSubshapes)))
    return failure();

  return mlir::hlo::verifyReduceScatter(
      *this,
      /*operand_types=*/{operand().getType()},
      /*result_types=*/{getType()},
      /*scatter_dimension=*/scatter_dimension());
}

//===----------------------------------------------------------------------===//
// CompatibleOperandsAndResultType
//===----------------------------------------------------------------------===//

// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op)                         \
  LogicalResult Op::inferReturnTypeComponents(                                 \
      MLIRContext* context, Optional<Location> location,                       \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {           \
    return inferReturnTypeComponentsFromOperands(context, location, operands,  \
                                                 attributes, regions,          \
                                                 inferredReturnShapes);        \
  }

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CopyOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DomainOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ExpOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Expm1Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Log1pOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogisticOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MaxOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MinOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NotOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(OrOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PopulationCountOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PowOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReducePrecisionOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RemOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReverseOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundNearestEvenOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RsqrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftLeftOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightArithmeticOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightLogicalOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SignOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubtractOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp)

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}

// Builds a constant op with the specified attribute `value`.
void ConstantOp::build(OpBuilder& /*builder*/, OperationState& result,
                       Attribute value) {
  Type type;
  if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
    type = elemAttr.getType();
  } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
    // All XLA types must be tensor types. In the build() method, we want to
    // provide more flexibility by allowing attributes of scalar types. But we
    // need to wrap it up with ElementsAttr to construct valid XLA constants.
    type =
        RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType());
    value = DenseElementsAttr::get(type.cast<TensorType>(), value);
  } else if (auto complexAttr = value.dyn_cast<complex::NumberAttr>()) {
    type = RankedTensorType::get(/*shape=*/{},
                                 complexAttr.cast<TypedAttr>().getType());
    value =
        DenseElementsAttr::get(type.cast<TensorType>(), complexAttr.getValue());
  }

  // TODO: support other XLA specific types.
  assert(type && "unsupported attribute type for building mhlo.constant");
  result.types.push_back(type);
  result.addAttribute("value", value);
}

LogicalResult ConstantOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands,
    DictionaryAttr attributes, RegionRange,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  ConstantOpAdaptor adaptor(operands, attributes);
  Type type = adaptor.value().getType();
  inferredReturnTypes.push_back(type);
  return success();
}

bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1) return false;
  auto lhsTy = l.front().cast<TensorType>();
  auto rhsTy = r.front().cast<TensorType>();
  // For tensor types with a uniform-quantized element type, compare using the
  // storage type, since the constant value is stored through the underlying
  // storage type.
  if (auto rhsElemTy =
          rhsTy.getElementType().dyn_cast<quant::QuantizedType>()) {
    rhsTy = getSameShapeTensorType(rhsTy, rhsElemTy.getStorageType());
  }
  return lhsTy == rhsTy;
}

ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
  // Parse the generic form.
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRParen()) return failure();
    if (parser.parseOptionalAttrDict(result.attributes)) return failure();
    if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
        parser.parseArrow())
      return failure();
    Type resultTy;
    if (parser.parseType(resultTy)) {
      return failure();
    }
    result.addTypes(resultTy);
    return success();
  }

  ElementsAttr valueAttr;
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value",
                                              result.attributes)) {
    return failure();
  }
  result.addTypes(valueAttr.getType());
  return success();
}

/// Print a `constant` op.
///
/// op ::= attr-dict $value
///
/// When `value` and the output have different types, fall back to the default
/// operator assembly format.
void ConstantOp::print(::mlir::OpAsmPrinter& p) {
  // If not all types are the same, use the generic form.
  if (value().getType() != getType()) {
    p.printGenericOp(getOperation(), /*printOpName=*/false);
    return;
  }

  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
  p << ' ';
  p.printStrippedAttrOrType(valueAttr());
}

//===----------------------------------------------------------------------===//
// CustomCallOp
//===----------------------------------------------------------------------===//

LogicalResult CustomCallOp::verify() {
  // If neither the operand nor the result layout attribute is specified, there
  // is nothing to verify.
  if (!operand_layouts().has_value() && !result_layouts().has_value())
    return success();

  // Layout constraints must be specified either for both operands and results,
  // or for neither.
  if (operand_layouts().has_value() != result_layouts().has_value())
    return emitOpError() << "Layout attributes should be specified for "
                            "either both operands and results or none.";

  // Helper function to verify types and the corresponding layouts.
  auto verifyTypesAndLayouts =
      [this](TypeRange types, mlir::ArrayAttr layouts,
             const std::string& valueName) -> LogicalResult {
    if (types.size() != layouts.size())
      return emitOpError() << "Number of " << valueName
                           << "s must match the number of " << valueName
                           << " layouts, " << types.size()
                           << " != " << layouts.size();

    for (const auto& indexedTypeAndLayout :
         llvm::enumerate(llvm::zip(types, layouts))) {
      // Get the index for a more descriptive error message.
      auto index = indexedTypeAndLayout.index();

      auto type = std::get<0>(indexedTypeAndLayout.value());
      auto layout = std::get<1>(indexedTypeAndLayout.value())
                        .cast<DenseIntElementsAttr>();

      if (type.isa<TupleType>())
        return emitOpError() << "Tuple types are not fully supported with "
                                "layout constraints yet";
      auto tensorType = type.dyn_cast<TensorType>();

      // For non-tensor types such as !mhlo.token, the layout should be empty.
      if (!tensorType) {
        if (layout.empty()) continue;
        return emitOpError()
               << "Only tensor types can have non-empty layout: " << valueName
               << " #" << index << " of type " << type << " has layout "
               << layout;
      }

      // For unranked tensors, we cannot verify the compatibility with layout
      // any further.
      if (!tensorType.hasRank()) continue;

      // Layout must be a permutation of [0, N) where N is the rank of the
      // tensor type.
      std::vector<int64_t> range(tensorType.getRank());
      std::iota(range.begin(), range.end(), 0);
      if (tensorType.getRank() != layout.size() ||
          !std::is_permutation(range.begin(), range.end(), layout.begin()))
        return emitOpError() << "incorrect layout " << layout << " for type "
                             << type << ", layout must be a permutation of [0, "
                             << tensorType.getRank() << ")";
    }
    return success();
  };

  // At this point both `operand_layouts` and `result_layouts` are defined.
  ArrayAttr operandLayouts = this->operand_layouts().value();
  ArrayAttr resultLayouts = this->result_layouts().value();

  // Full support for layouts for arbitrary nesting of tuples is not
  // supported yet.
  //
  // If the result does not have any tuples, then the i-th element of
  // `result_layouts` specifies the layout constraints on the i-th result.
  //
  // For the common case of a single tuple result packing non-tuple values, the
  // i-th element of `result_layouts` specifies the layout for the i-th element
  // of the result tuple.
  TypeRange resultTypes;
  if (getNumResults() == 1 && getResult(0).getType().isa<TupleType>())
    resultTypes = getResult(0).getType().cast<TupleType>().getTypes();
  else
    resultTypes = getResultTypes();

  // Verify that operands and operand layouts match.
  if (failed(
          verifyTypesAndLayouts(getOperandTypes(), operandLayouts, "operand")))
    return failure();

  // Verify that results and result layouts match.
  return verifyTypesAndLayouts(resultTypes, resultLayouts, "result");
}

void CustomCallOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
        effects) {
  // CustomCall has "all possible effects" unless the has_side_effect attribute
  // is present and set to false.
  auto hasSideEffect = (*this)->getAttrOfType<BoolAttr>("has_side_effect");
  if (hasSideEffect && !hasSideEffect.getValue()) return;
  effects.emplace_back(MemoryEffects::Allocate::get());
  effects.emplace_back(MemoryEffects::Free::get());
  effects.emplace_back(MemoryEffects::Write::get());
  effects.emplace_back(MemoryEffects::Read::get());
}

//===----------------------------------------------------------------------===//
// CholeskyOp
//===----------------------------------------------------------------------===//

// The following properties are already enforced by the ODS:
//   P0. a.element_type is floating or complex
// We intend to verify the following properties:
//   P1. The 'a' argument to Cholesky must have rank >= 2, got shape %s
//   P2. The two minor dimensions of 'a' must have equal size, got %s.
LogicalResult CholeskyOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  CholeskyOp::Adaptor adaptor(operands, attributes, regions);
  Type aType = adaptor.a().getType();
  RankedTensorType aRankedType = aType.dyn_cast<RankedTensorType>();
  if (!aRankedType) {
    inferredReturnShapes.emplace_back(
        aType.cast<TensorType>().getElementType());
    return success();
  }

  ArrayRef<int64_t> aShape = aRankedType.getShape();
  if (aShape.size() < 2) {
    return emitOptionalError(
        location, "argument 'a' must have rank >= 2, got shape ", aShape, ".");
  }

  int64_t lastDim = aShape[aShape.size() - 1];
  int64_t penultimateDim = aShape[aShape.size() - 2];
  if (!isDynamicDimSize(lastDim) && !isDynamicDimSize(penultimateDim) &&
      lastDim != penultimateDim) {
    return emitOptionalError(
        location, "minor dimensions of 'a' must have equal size, got shape ",
        aShape, ".");
  }
  inferredReturnShapes.emplace_back(aRankedType.getShape(),
                                    aRankedType.getElementType());
  return success();
}

//===----------------------------------------------------------------------===//
// DotOp
//===----------------------------------------------------------------------===//
namespace {
bool dimCompatible(int64_t a, int64_t b) {
  return isDynamicDimSize(a) || isDynamicDimSize(b) || a == b;
}

ShapedType inferDotReturnType(ShapedType lhs, ShapedType rhs) {
  auto elementType = lhs.getElementType();
  if (!lhs.hasRank() || !rhs.hasRank()) {
    return UnrankedTensorType::get(elementType);
  }

  // vector dot vector
  if (1 == lhs.getRank() && 1 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
    return RankedTensorType::get({}, elementType);
  }
  // matrix dot vector
  if (2 == lhs.getRank() && 1 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
    return RankedTensorType::get({lhs.getDimSize(0)}, elementType);
  }
  // vector dot matrix
  if (1 == lhs.getRank() && 2 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
    return RankedTensorType::get({rhs.getDimSize(1)}, elementType);
  }
  // matrix dot matrix
  if (2 == lhs.getRank() && 2 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
    int64_t shape[2] = {lhs.getDimSize(0), rhs.getDimSize(1)};
    return RankedTensorType::get(shape, elementType);
  }
  return {};
}
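
// For intuition, the shapes inferred above (illustrative element type f32):
//   vector dot vector: tensor<3xf32>   dot tensor<3xf32>   -> tensor<f32>
//   matrix dot vector: tensor<2x3xf32> dot tensor<3xf32>   -> tensor<2xf32>
//   vector dot matrix: tensor<3xf32>   dot tensor<3x4xf32> -> tensor<4xf32>
//   matrix dot matrix: tensor<2x3xf32> dot tensor<3x4xf32> -> tensor<2x4xf32>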
}  // namespace

LogicalResult DotOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  DotOp::Adaptor op(operands);
  auto lhsType = op.lhs().getType().cast<ShapedType>();
  auto rhsType = op.rhs().getType().cast<ShapedType>();
  inferredReturnTypes.push_back(inferDotReturnType(lhsType, rhsType));
  return success();
}

LogicalResult DotOp::verify() {
  auto lhsType = lhs().getType().cast<ShapedType>();
  auto rhsType = rhs().getType().cast<ShapedType>();
  auto resultType = getType().cast<ShapedType>();
  auto expectReturnType = inferDotReturnType(lhsType, rhsType);
  if (!expectReturnType) {
    return emitError() << "Unexpected operands type: " << lhsType << " and "
                       << rhsType;
  }
  if (resultType.hasRank() && expectReturnType.hasRank()) {
    if (resultType.getShape() != expectReturnType.getShape()) {
      return emitError() << "Unexpected result type: has " << resultType
                         << " but inferred " << expectReturnType
                         << " from operands " << lhsType << " and " << rhsType;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//

LogicalResult DotGeneralOp::verify() {
  auto dimNumbers = this->dot_dimension_numbers();

  ArrayRef<int64_t> lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
  ArrayRef<int64_t> rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
  ArrayRef<int64_t> lhsContractingDims =
      dimNumbers.getLhsContractingDimensions();
  ArrayRef<int64_t> rhsContractingDims =
      dimNumbers.getRhsContractingDimensions();

  if (lhsBatchingDims.size() != rhsBatchingDims.size()) {
    return emitOpError() << "lhs and rhs should have the same number of "
                            "batching dimensions";
  }
  if (lhsContractingDims.size() != rhsContractingDims.size()) {
    return emitOpError() << "lhs and rhs should have the same number of "
                            "contracting dimensions";
  }

  llvm::SmallDenseSet<int64_t> dimSet;

  auto checkDimsDistinct =
      [this](ArrayRef<int64_t> batchingDims, ArrayRef<int64_t> contractingDims,
             llvm::SmallDenseSet<int64_t>& dimSet, llvm::StringRef lhs,
             llvm::StringRef rhs) -> LogicalResult {
    auto dims = llvm::concat<const int64_t>(batchingDims, contractingDims);
    for (auto dim : dims) {
      auto [_, wasInserted] = dimSet.insert(dim);
      if (!wasInserted) {
        return emitOpError() << "has duplicated dimension from " << lhs
                             << " and " << rhs << ": " << dim;
      }
    }
    return success();
  };

  if (failed(checkDimsDistinct(lhsBatchingDims, lhsContractingDims, dimSet,
                               "lhs_batching_dimensions",
                               "lhs_contracting_dimensions"))) {
    return failure();
  }
  dimSet.clear();
  if (failed(checkDimsDistinct(rhsBatchingDims, rhsContractingDims, dimSet,
                               "rhs_batching_dimensions",
                               "rhs_contracting_dimensions"))) {
    return failure();
  }

  auto checkDimsInRange = [this](int64_t rank, ArrayRef<int64_t> dims,
                                 llvm::StringRef dimName) -> LogicalResult {
    auto inRange = [&](int64_t i) -> bool { return 0 <= i && i < rank; };
    const auto* dimsNotInRange =
        std::find_if_not(dims.begin(), dims.end(), inRange);
    if (dimsNotInRange != dims.end()) {
      return emitOpError() << dimName << " value: " << *dimsNotInRange
                           << " is out of range: "
                           << "[0, " << rank << ")";
    }
    return success();
  };

  auto lhsType = this->lhs().getType().dyn_cast<RankedTensorType>();
  auto rhsType = this->rhs().getType().dyn_cast<RankedTensorType>();

  if (lhsType) {
    if (failed(checkDimsInRange(lhsType.getRank(), lhsBatchingDims,
                                "lhs_batching_dimensions")) ||
        failed(checkDimsInRange(lhsType.getRank(), lhsContractingDims,
                                "lhs_contracting_dimensions"))) {
      return failure();
    }
  }
  if (rhsType) {
    if (failed(checkDimsInRange(rhsType.getRank(), rhsBatchingDims,
                                "rhs_batching_dimensions")) ||
        failed(checkDimsInRange(rhsType.getRank(), rhsContractingDims,
                                "rhs_contracting_dimensions"))) {
      return failure();
    }
  }

  if (lhsType && rhsType) {
    // Dimension sizes must be compatible for lhs/rhs.
    auto lhsShape = lhsType.getShape();
    auto rhsShape = rhsType.getShape();

    for (auto [lhs, rhs] : llvm::zip(lhsBatchingDims, rhsBatchingDims)) {
      if (lhsShape[lhs] != rhsShape[rhs]) {
        return emitOpError() << "batching dimension sizes must match for "
                                "lhs/rhs";
      }
    }
    for (auto [lhs, rhs] : llvm::zip(lhsContractingDims, rhsContractingDims)) {
      if (lhsShape[lhs] != rhsShape[rhs]) {
        return emitOpError() << "contracting dimension sizes must match for "
                                "lhs/rhs";
      }
    }
  }
  return success();
}

namespace {
// Handle the generic case of DotGeneral and convert it to a regular DotOp.
struct DotGeneralToDot : public OpRewritePattern<DotGeneralOp> {
  using OpRewritePattern<DotGeneralOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DotGeneralOp dot,
                                PatternRewriter& rewriter) const override {
    auto lhs = dot.lhs();
    auto rhs = dot.rhs();
    auto lhsTy = lhs.getType().cast<ShapedType>();
    auto rhsTy = rhs.getType().cast<ShapedType>();

    if (lhsTy.getRank() != 2) return failure();
    if (rhsTy.getRank() != 2) return failure();

    auto nums = dot.dot_dimension_numbers();
    if (!nums.getLhsBatchingDimensions().empty()) return failure();
    if (!nums.getRhsBatchingDimensions().empty()) return failure();

    auto lhsContract = nums.getLhsContractingDimensions();
    auto rhsContract = nums.getRhsContractingDimensions();
    if (lhsContract.size() != 1 || rhsContract.size() != 1) return failure();

    if (lhsContract.front() != 1) return failure();
    if (rhsContract.front() != 0) return failure();

    rewriter.replaceOpWithNewOp<mhlo::DotOp>(
        dot, dot.getType(), lhs, rhs, dot.precision_config().value_or(nullptr));

    return success();
  }
};
}  // namespace

void DotGeneralOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                               MLIRContext* context) {
  results.add<DotGeneralToDot>(context);
}

LogicalResult DotGeneralOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  auto lhsType = lhs().getType().dyn_cast<ShapedType>();
  auto rhsType = rhs().getType().dyn_cast<ShapedType>();
  if (!lhsType || !rhsType) {
    return failure();
  }

  Adaptor adaptor(operands);
  auto dimNumbers = dot_dimension_numbers();
  SmallVector<Value> dimensions;
  for (const int64_t lhsDim : dimNumbers.getLhsBatchingDimensions()) {
    dimensions.push_back(
        builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), lhsDim));
  }

  for (int64_t i = 0; i < lhsType.getRank(); i++) {
    if (!llvm::is_contained(dimNumbers.getLhsContractingDimensions(), i) &&
        !llvm::is_contained(dimNumbers.getLhsBatchingDimensions(), i)) {
      dimensions.push_back(
          builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), i));
    }
  }
  for (int64_t i = 0; i < rhsType.getRank(); i++) {
    if (!llvm::is_contained(dimNumbers.getRhsContractingDimensions(), i) &&
        !llvm::is_contained(dimNumbers.getRhsBatchingDimensions(), i)) {
      dimensions.push_back(
          builder.create<tensor::DimOp>(getLoc(), adaptor.rhs(), i));
    }
  }

  reifiedReturnShapes.push_back(
      builder.create<tensor::FromElementsOp>(getLoc(), dimensions));
  return success();
}

//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//

// We intend to verify the following properties:
//   P1. 1 <= rank <= 3
//   P2. Element types agree with fft_type
//   P3. Operand shape dimensions agree with fft_length for the given fft_type
LogicalResult FftOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  FftOp::Adaptor adaptor(operands, attributes, regions);
  auto fftLength = adaptor.fft_length().getValues<int64_t>();
  int64_t fftRank = fftLength.size();

  // P1.
  if (fftRank > 3 || fftRank < 1) {
    return emitOptionalError(location, "rank must be between 1 and 3, but got ",
                             fftRank, ".");
  }

  // P2. Element type agreement:
  //   FFT   : C -> C
  //   IFFT  : C -> C
  //   RFFT  : R -> C
  //   IRFFT : C -> R
  auto fftType = adaptor.fft_type();
  auto operandType = adaptor.operand().getType().cast<TensorType>();
  Type operandElementType = operandType.getElementType();
  // Check the input element type and infer the return element type.
  if (fftType == FftType::RFFT) {
    if (!operandElementType.isF32() && !operandElementType.isF64()) {
      return emitOptionalError(
          location, "RFFT requires f32 or f64 input type, but is given ",
          operandElementType, ".");
    }
  } else {
    if (!operandElementType.isa<ComplexType>()) {
      return emitOptionalError(
          location, stringifyFftType(fftType),
          " takes a complex tensor as input, but is given ", operandType, ".");
    }
  }
  // Generate the output element type.
  Type resultElementType = operandElementType;
  if (fftType == FftType::RFFT) {  // RFFT : R -> C
    resultElementType = ComplexType::get(resultElementType);
  } else if (fftType == FftType::IRFFT) {  // IRFFT : C -> R
    resultElementType = operandElementType.cast<ComplexType>().getElementType();
  }

  // P3. Check the input shape and infer the return shape.
  operandType = operandType.dyn_cast<RankedTensorType>();
  if (!operandType) {
    inferredReturnShapes.emplace_back(resultElementType);
    return success();
  }
  auto operandShape = operandType.getShape();
  if (static_cast<int64_t>(operandShape.size()) < fftRank) {
    return emitOptionalError(
        location, "operand rank must not be less than fft rank of ", fftRank,
        " for operand of type ", operandType, ".");
  }

  SmallVector<int64_t> resultShape = to_vector(operandShape);

  if (fftType == FftType::RFFT) {
    auto shapeBack = operandShape.take_back(fftRank);
    for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
      if (operandDim != fftDim) {
        return emitOptionalError(
            location,
            "RFFT requires innermost dimensions match fft_length. Got: ",
            operandShape, " but wanted ", fftLength, ".");
      }
    }
    if (fftLength[fftRank - 1] != 0) {
      resultShape[resultShape.size() - 1] = fftLength[fftRank - 1] / 2 + 1;
    }
  }
  if (fftType == FftType::IRFFT) {
    auto shapeBack = operandShape.take_back(fftRank).drop_back();
    for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
      if (operandDim != fftDim) {
        return emitOptionalError(location,
                                 "IRFFT requires non-final dimensions "
                                 "match fft_length. Got: ",
                                 operandShape, " but wanted ", fftLength,
                                 ", and ", operandDim, " != ", fftDim, ".");
      }
    }
    if ((operandShape[operandShape.size() - 1] != 0 ||
         fftLength[fftRank - 1] != 0) &&
        operandShape[operandShape.size() - 1] != fftLength[fftRank - 1] / 2 + 1)
      return emitOptionalError(location,
                               "IRFFT requires innermost dimension match "
                               "fft_length[-1]/2+1. Got: ",
                               operandShape, " but fft_length is ", fftLength,
                               ".");
    resultShape[resultShape.size() - 1] = fftLength[fftRank - 1];
  }

  inferredReturnShapes.emplace_back(resultShape, resultElementType);
  return success();
}
1327
1328 //===----------------------------------------------------------------------===//
1329 // GatherOp
1330 //===----------------------------------------------------------------------===//
1331
1332 // Converts gather ops to slice ops in case we have a single set of constant
1333 // indices.
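//
// Illustrative example (dimension numbers abbreviated, syntax approximate):
//
//   %0 = "mhlo.gather"(%operand, %indices) {...}
//       : (tensor<5x6xf32>, tensor<2xi64>) -> tensor<2x3xf32>
//
// with constant %indices = dense<[1, 2]>, index_vector_dim = 0,
// start_index_map = [0, 1], offset_dims = [0, 1], no collapsed dims and
// slice_sizes = [2, 3] becomes an mhlo.slice with start_indices = [1, 2],
// limit_indices = [3, 5] and strides = [1, 1] (starts are clamped below).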
struct GatherSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr index;
    if (!matchPattern(gather.start_indices(), m_Constant(&index)))
      return failure();

    const auto& dnums = gather.dimension_numbers();
    if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1)
      return failure();

    // TODO(tberghammer): Remove when the verifier catches this case, which is
    // invalid if all of the previous conditions hold.
    if (index.getNumElements() !=
        static_cast<int64_t>(dnums.getStartIndexMap().size()))
      return failure();

    RankedTensorType operandType =
        gather->getOperand(0).getType().dyn_cast<RankedTensorType>();
    if (!operandType || !operandType.hasStaticShape()) return failure();

    auto sliceEnd =
        llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
    llvm::SmallVector<int64_t, 8> sliceStart(sliceEnd.size(), 0);
    for (auto it :
         llvm::zip(dnums.getStartIndexMap(), index.getValues<APInt>())) {
      int64_t mapIndex = std::get<0>(it);
      // Clamp the indices within bounds to faithfully mirror gather semantics.
      int64_t offset =
          clamp(std::get<1>(it).getSExtValue(), static_cast<int64_t>(0),
                operandType.getDimSize(mapIndex) - sliceEnd[mapIndex]);
      sliceStart[mapIndex] += offset;
      sliceEnd[mapIndex] += offset;
    }

    llvm::SmallVector<int64_t, 8> sliceStride(sliceEnd.size(), 1);
    llvm::SmallVector<int64_t, 8> sliceShape(sliceEnd.size());
    for (size_t i = 0; i < sliceEnd.size(); ++i) {
      sliceShape[i] = sliceEnd[i] - sliceStart[i];
    }
    Type elementType = gather.getType().cast<TensorType>().getElementType();
    auto sliceType = RankedTensorType::get(sliceShape, elementType);
    Value result = rewriter.create<SliceOp>(
        gather.getLoc(), sliceType, gather.getOperand(0),
        rewriter.getI64TensorAttr(sliceStart),
        rewriter.getI64TensorAttr(sliceEnd),
        rewriter.getI64TensorAttr(sliceStride));

    auto collapsedSliceDims = dnums.getCollapsedSliceDims();
    if (!collapsedSliceDims.empty()) {
      llvm::SmallVector<int64_t, 8> reshapeShape;
      for (size_t i = 0; i < sliceShape.size(); ++i) {
        if (llvm::count(collapsedSliceDims, i) == 0) {
          reshapeShape.push_back(sliceShape[i]);
        }
      }
      auto reshapeType = RankedTensorType::get(reshapeShape, elementType);
      result = rewriter.create<ReshapeOp>(gather.getLoc(), reshapeType, result);
    }

    result.setType(gather.getType());
    rewriter.replaceOp(gather, result);
    return success();
  }
};

void GatherOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                           MLIRContext* context) {
  results.add<GatherSlice>(context);
}

namespace {

// Following https://www.tensorflow.org/xla/operation_semantics#gather.
// The bounds for the output array along dimension i are computed as follows:
// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some
//     k), then we pick the corresponding dimension bounds out of
//     start_indices.shape, skipping index_vector_dim
//     (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and
//     start_indices.shape.dims[k+1] otherwise).
// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some
//     k), then we pick the corresponding bound out of slice_sizes after
//     accounting for collapsed_slice_dims
//     (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is
//     slice_sizes with the bounds at indices collapsed_slice_dims removed).
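//
// Worked example (mirrors the example in the XLA gather docs): with operand
// tensor<16x11xf32>, start_indices tensor<10x2xi64>, index_vector_dim = 1,
// offset_dims = [1], collapsed_slice_dims = [0] and slice_sizes = [1, 8],
// the output rank is 1 + 2 - 1 = 2; dimension 0 is a batch dimension taking
// bound 10 from start_indices, and dimension 1 is an offset dimension taking
// bound 8 from adjusted_slice_sizes = [8], giving tensor<10x8xf32>.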

void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
                        ValueRange operands,
                        SmallVectorImpl<Value>& sliceSizes) {
  for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
    sliceSizes.push_back(builder.create<arith::ConstantIndexOp>(loc, val));
  }
}

void getSliceSizeValues(DynamicGatherOp* /*dGather*/, OpBuilder& builder,
                        Location loc, ValueRange operands,
                        SmallVectorImpl<Value>& sliceSizeValues) {
  DynamicGatherOp::Adaptor adaptor(operands);
  Value sliceSizes = adaptor.slice_sizes();
  auto sliceSizesTy = sliceSizes.getType().cast<ShapedType>();
  for (int64_t i = 0; i < sliceSizesTy.getDimSize(0); ++i) {
    Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
    sliceSizeValues.push_back(
        builder.create<tensor::ExtractOp>(loc, sliceSizes, idx));
  }
}

// Verify the following properties:
// P1. Verify no repeat in start_index_map.
// P2. Verify 0 <= start_index_map[i] < rank(operand), for every i.
// P3. Verify 0 <= index_vector_dim <= rank(start_indices).
// P4. Verify size(start_index_map) == shape(start_indices)[index_vector_dim].
// P5. Verify offset_dims is sorted and has no repeats.
// P6. Verify collapsed_slice_dims is sorted and has no repeats.
// P7. Verify rank(operand) == size(offset_dims) + size(collapsed_slice_dims).
// P8. Verify slice_sizes has rank of 1.
// P9. Verify size(slice_sizes) == rank(operand).
// P10. Verify 0 <= collapsed_slice_dims[i] < size(slice_sizes) for all items.
static LogicalResult verifyGather(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    ShapeAdaptor sliceSizesShape, GatherDimensionNumbersAttr dimensionNumbers,
    llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
  int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();

  // Check startIndexMap.
  auto startIndexMap = to_vector(dimensionNumbers.getStartIndexMap());
  // P1.
  if (hasDuplicates(startIndexMap))
    return errorEmitter() << "expects start_index_map to not repeat, got: ["
                          << startIndexMap << "]";

  // P2.
  for (int i = 0; i < startIndexMap.size(); ++i)
    if (startIndexMap[i] < 0 ||
        (operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank()))
      return errorEmitter()
             << "start_index_map[" << i << "]: " << startIndexMap[i]
             << " is out of bounds for "
             << "operand rank " << operandShape.getRank();

  if (startIndicesShape.hasRank()) {
    // P3.
    // index_vector_dim == start_indices.rank implies a trailing 1 on the shape
    // of start_indices.
    if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0)
      return errorEmitter() << "index_vector_dim " << indexVectorDim
                            << " is out of bounds for start indices with rank "
                            << startIndicesShape.getRank();

    bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank();
    if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) {
      int64_t effectiveDimSize;
      if (impliedTrailingDim)
        effectiveDimSize = 1;
      else
        effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim);
      // P4.
      if (effectiveDimSize !=
          static_cast<int64_t>(dimensionNumbers.getStartIndexMap().size()))
        return errorEmitter() << "start_index_map size ("
                              << dimensionNumbers.getStartIndexMap().size()
                              << ") is not equal to size of index dimension ("
                              << indexVectorDim << ") of start_indices ("
                              << effectiveDimSize << ")";
    }
  }

  // P5.
  auto offsetDims = to_vector(dimensionNumbers.getOffsetDims());
  if (!llvm::is_sorted(offsetDims))
    return errorEmitter() << "expects offset_dims to be sorted, got: ["
                          << offsetDims << "]";
  if (hasDuplicates(offsetDims))
    return errorEmitter() << "expects offset_dims to not repeat, got: ["
                          << offsetDims << "]";

  // P6.
  auto collapsedSliceDims = to_vector(dimensionNumbers.getCollapsedSliceDims());
  if (!llvm::is_sorted(collapsedSliceDims))
    return errorEmitter() << "expects collapsed_slice_dims to be sorted, got: ["
                          << collapsedSliceDims << "]";
  if (hasDuplicates(collapsedSliceDims))
    return errorEmitter()
           << "expects collapsed_slice_dims to not repeat, got: ["
           << collapsedSliceDims << "]";

  // P7.
  int64_t impliedOperandRank = dimensionNumbers.getOffsetDims().size() +
                               dimensionNumbers.getCollapsedSliceDims().size();
  if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank)
    return errorEmitter() << "offset_dims size ("
                          << dimensionNumbers.getOffsetDims().size()
                          << ") plus collapse_slice_dims size ("
                          << dimensionNumbers.getCollapsedSliceDims().size()
                          << ") is not equal to operand rank ("
                          << operandShape.getRank() << ")";

  // P8.
  // This should be fully expressible with type constraints, but it isn't
  // obvious how to do that with the current infrastructure.
  if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1)
    return errorEmitter() << "slice_sizes.rank != 1";
  if (sliceSizesShape.hasStaticShape()) {
    int64_t sliceSize = sliceSizesShape.getNumElements();

    // P9.
    if (sliceSize != impliedOperandRank)
      return errorEmitter() << "slice_sizes size (" << sliceSize
                            << ") not equal to (implied) operand rank ("
                            << impliedOperandRank << ")";

    // P10.
    for (auto dim : dimensionNumbers.getCollapsedSliceDims())
      if (dim < 0 || dim >= sliceSize)
        return errorEmitter() << "collapsed dimension " << dim
                              << " is out of bounds for slice_sizes.size ("
                              << sliceSize << ")";
  }

  return success();
}

// Verify the following properties:
// P1. Verifications by verifyGather().
// P2. Verify slice_sizes[i] <= 1 for i in collapsed_slice_dims.
// P3. Verify 0 <= slice_sizes[i] < shape(operand)[i], for every i.
static LogicalResult verifyStaticGather(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    DenseIntElementsAttr sliceSizes,
    GatherDimensionNumbersAttr dimensionNumbers,
    llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
  // P1.
  // For some reason the getType call is necessary here.
  if (failed(verifyGather(
          /*operandShape=*/operandShape,
          /*startIndicesShape=*/startIndicesShape,
          /*sliceSizesShape=*/sliceSizes.getType(), dimensionNumbers,
          errorEmitter)))
    return failure();

  // P2.
  for (auto dim : dimensionNumbers.getCollapsedSliceDims()) {
    int64_t sliceDimSize = sliceSizes.getValues<int64_t>()[dim];
    if (sliceDimSize > 1) {
      return errorEmitter() << "slice_sizes collapsed dimension " << dim
                            << " should be <= 1 but got " << sliceDimSize;
    }
  }

  // P3.
  if (operandShape.hasRank()) {
    for (const auto& it : llvm::enumerate(sliceSizes.getValues<int64_t>())) {
      if (operandShape.isDynamicDim(it.index())) continue;
      auto operandDimSize = operandShape.getDimSize(it.index());
      auto sliceDimSize = it.value();
      if (sliceDimSize < 0 || sliceDimSize > operandDimSize)
        return errorEmitter() << "slice size (" << sliceDimSize
                              << ") is out of bounds for operand dimension ("
                              << operandDimSize << ") at index " << it.index();
    }
  }
  return success();
}

template <typename dimTy>
static void inferGatherShape(
    int64_t resultRank, llvm::function_ref<dimTy(int64_t)> getStartIndicesDim,
    llvm::function_ref<dimTy(int64_t)> getSliceDim,
    GatherDimensionNumbersAttr dimensionNumbers,
    SmallVectorImpl<dimTy>& shape) {
  ArrayRef<int64_t> collapsedSliceDims =
      dimensionNumbers.getCollapsedSliceDims();
  int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();

  // We don't necessarily know the rank of sliceSizes, but we do know that it
  // can't be larger than the highest collapsed dimension. So go through those
  // and populate the leading dimensions of adjustedSliceSizes. The trailing
  // dimensions can just be adjusted by an offset.
  const auto* maxCollapsedDimIt =
      std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end());
  int64_t maxCollapsedDim = -1;
  if (maxCollapsedDimIt != collapsedSliceDims.end())
    maxCollapsedDim = *maxCollapsedDimIt;

  SmallVector<dimTy> adjustedSliceSizePrefix;
  for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) {
    if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue;
    adjustedSliceSizePrefix.push_back(getSliceDim(dimIndex));
  }
  auto getAdjustedSliceDim = [&](int64_t index) -> dimTy {
    if (index < static_cast<int64_t>(adjustedSliceSizePrefix.size()))
      return adjustedSliceSizePrefix[index];
    return getSliceDim(index + collapsedSliceDims.size());
  };

  ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();

  // Dimensions in the output that aren't offset dimensions are called batch
  // dimensions.
  SmallVector<int64_t> batchDims;
  for (int dim = 0; dim < resultRank; ++dim)
    if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);

  for (int i = 0; i < resultRank; ++i) {
    const auto* offsetDimsIt =
        std::find(offsetDims.begin(), offsetDims.end(), i);
    if (offsetDimsIt != offsetDims.end()) {
      auto index = std::distance(offsetDims.begin(), offsetDimsIt);
      shape.push_back(getAdjustedSliceDim(index));
      continue;
    }
    auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i);
    assert(batchDimsIt != batchDims.end());
    auto index = std::distance(batchDims.begin(), batchDimsIt);
    // This can never run into the special case where start_indices gets
    // implicitly expanded with a trailing 1 if
    // index_vector_dim = start_indices.rank, because then index would equal
    // index_vector_dim, which means we'd be looking at index+1, which would be
    // out of bounds anyway.
    if (index >= indexVectorDim) ++index;
    shape.push_back(getStartIndicesDim(index));
  }
}

// Verify the following properties:
// P1. Verify 0 <= offset_dims[i] < output_shape_rank, for every i.
//     (output_shape_rank = size(offset_dims) + rank(start_indices) - 1)
static LogicalResult inferGatherReturnTypeComponents(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    llvm::function_ref<int64_t(int64_t)> getSliceDim,
    GatherDimensionNumbersAttr dimensionNumbers,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes,
    llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
  Type elementType = operandShape.getElementType();

  // We need this to determine the result rank. We could still place bounds on
  // the result rank if that was something ShapedTypeComponents could express.
  if (!startIndicesShape.hasRank()) {
    inferredReturnShapes.push_back(elementType);
    return success();
  }

  ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
  int64_t startIndicesRank = startIndicesShape.getRank();
  // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is
  // appended to the start_indices shape.
  if (dimensionNumbers.getIndexVectorDim() == startIndicesRank)
    ++startIndicesRank;
  int64_t resultRank = offsetDims.size() + startIndicesRank - 1;
  // P1.
  for (int i = 0; i < offsetDims.size(); ++i)
    if (offsetDims[i] < 0 || offsetDims[i] >= resultRank)
      return errorEmitter() << "offset_dims[" << i << "]: " << offsetDims[i]
                            << " is out of bounds for "
                            << "implied result rank " << resultRank;

  auto getStartIndicesDim = [&](int64_t index) {
    return startIndicesShape.getDimSize(index);
  };

  SmallVector<int64_t> shape;
  inferGatherShape<int64_t>(resultRank, getStartIndicesDim, getSliceDim,
                            dimensionNumbers, shape);

  inferredReturnShapes.emplace_back(shape, elementType);
  return success();
}

template <typename Op>
LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands,
                               SmallVectorImpl<Value>& reifiedReturnShapes) {
  // No support for unranked gather output shape at the moment.
  auto resultTy =
      op->getResult().getType().template dyn_cast<RankedTensorType>();
  if (!resultTy) return failure();

  typename Op::Adaptor adaptor(operands);
  Value startIndices = adaptor.start_indices();

  Location loc = op->getLoc();
  int resultRank = resultTy.getRank();
  Type shapeElTy = startIndices.getType().cast<ShapedType>().getElementType();
  auto toShapeElType = [&](Value v) {
    return maybeCastTo(builder, loc, v, shapeElTy);
  };

  SmallVector<Value, 4> sliceSizes;
  getSliceSizeValues(op, builder, loc, operands, sliceSizes);
  llvm::transform(sliceSizes, sliceSizes.begin(),
                  [&](Value v) { return toShapeElType(v); });

  auto getStartIndicesDim = [&](int64_t index) {
    return toShapeElType(
        builder.create<tensor::DimOp>(loc, startIndices, index));
  };
  SmallVector<Value, 4> shapeValues;
  auto getSliceDim = [&sliceSizes](int64_t index) -> Value {
    return sliceSizes[index];
  };
  inferGatherShape<Value>(resultRank, getStartIndicesDim, getSliceDim,
                          op->dimension_numbers(), shapeValues);

  Value outputShape = builder.create<tensor::FromElementsOp>(
      loc, RankedTensorType::get({resultRank}, shapeElTy), shapeValues);
  reifiedReturnShapes.push_back(outputShape);

  return success();
}

}  // namespace

LogicalResult GatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
}

// The following properties are already enforced by the ODS:
// P0. Verify that start_indices has an integer element type.
// Verify the following properties:
//   Verifications by verifyStaticGather() and verifyGather() inside it.
//   Verifications by inferGatherReturnTypeComponents().
LogicalResult GatherOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  // TODO(zhouxin) remove this comment after the ordering issue is clear.
  // This can get called before other op verify methods, so we have to do a
  // bunch of verification up front. With a better story for ordering and/or
  // multi-phase op verification, this should hopefully all go away.
  Location loc = location.value_or(UnknownLoc::get(context));
  auto errorEmitter = [&loc]() {
    return mlir::emitError(loc)
           << "'" << GatherOp::getOperationName() << "' op ";
  };
  GatherOp::Adaptor adaptor(operands, attributes, regions);
  if (failed(adaptor.verify(loc))) return failure();

  // We want the ShapeAdaptors, so we can't route via the adaptor :-/
  ShapeAdaptor operandShape = operands.getShape(0);
  ShapeAdaptor startIndicesShape = operands.getShape(1);
  GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
  DenseIntElementsAttr sliceSizesAttr = adaptor.slice_sizes();

  if (failed(verifyStaticGather(/*operandShape=*/operandShape,
                                /*startIndicesShape=*/startIndicesShape,
                                /*sliceSizes=*/sliceSizesAttr, dimensionNumbers,
                                errorEmitter)))
    return failure();

  auto getSliceDim = [&sliceSizesAttr](int64_t index) -> int64_t {
    return sliceSizesAttr.getValues<int64_t>()[index];
  };

  return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
                                         getSliceDim, dimensionNumbers,
                                         inferredReturnShapes, errorEmitter);
}

//===----------------------------------------------------------------------===//
// DynamicGatherOp
//===----------------------------------------------------------------------===//

LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
}

LogicalResult DynamicGatherOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  // This can get called before other op verify methods, so we have to do a
  // bunch of verification up front. With a better story for ordering and/or
  // multi-phase op verification, this should hopefully all go away.
  Location loc = location.value_or(UnknownLoc::get(context));
  auto errorEmitter = [&loc]() {
    return mlir::emitError(loc)
           << "'" << DynamicGatherOp::getOperationName() << "' op ";
  };
  DynamicGatherOp::Adaptor adaptor(operands, attributes, regions);
  if (failed(adaptor.verify(loc))) return failure();

  // We want the ShapeAdaptors, so we can't route via the adaptor :-/
  ShapeAdaptor operandShape = operands.getShape(0);
  ShapeAdaptor startIndicesShape = operands.getShape(1);
  ShapeAdaptor sliceSizesShape = operands.getShape(2);
  GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();

  if (failed(verifyGather(/*operandShape=*/operandShape,
                          /*startIndicesShape=*/startIndicesShape,
                          /*sliceSizesShape=*/sliceSizesShape, dimensionNumbers,
                          errorEmitter)))
    return failure();

  auto getSliceDim = [](int64_t index) { return ShapedType::kDynamicSize; };
  return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
                                         getSliceDim, dimensionNumbers,
                                         inferredReturnShapes, errorEmitter);
}

//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//

LogicalResult GetDimensionSizeOp::verify() { return verifyDimAttr(*this); }

/// Fold get_dimension_size when the said shape dimension is a constant.
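/// For example (illustrative): with operand type tensor<4x?xf32>,
/// dimension = 0 folds to dense<4> : tensor<i32>, while dimension = 1 does
/// not fold because that size is dynamic.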
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
  RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
  if (!type) return {};

  int32_t dim = dimension();
  if (type.isDynamicDim(dim)) return {};
  // The result type is always a 0-d i32 tensor.
  return DenseIntElementsAttr::get<int32_t>(
      getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//

LogicalResult IotaOp::verify() {
  auto shape = getType().cast<ShapedType>();
  if (!shape.hasRank()) return success();

  if (shape.getRank() == 0) return emitOpError() << "does not support scalars.";

  auto iotaDimension = this->iota_dimension();
  if (static_cast<int64_t>(iotaDimension) >= shape.getRank() ||
      iotaDimension < 0)
    return emitOpError()
           << "iota dimension cannot go beyond the output rank or be negative.";
  return success();
}

// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
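//
// Illustrative example (syntax approximate):
//
//   %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32>
//
// becomes a 1-D iota over dimension 0 of type tensor<4xi32> followed by an
// "mhlo.broadcast_in_dim" with broadcast_dimensions = [1] back to
// tensor<5x4xi32>.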
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
  using OpRewritePattern<IotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto resultTy = iota.getType().cast<ShapedType>();
    if (!resultTy.hasRank() || resultTy.getRank() < 2) {
      return failure();
    }

    auto iotaDimension = iota.iota_dimension();

    auto iotaType = RankedTensorType::get({resultTy.getDimSize(iotaDimension)},
                                          resultTy.getElementType());

    auto newIota = rewriter.create<IotaOp>(iota.getLoc(), iotaType,
                                           rewriter.getI64IntegerAttr(0));

    auto broadcastAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iotaDimension});
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, resultTy, newIota,
                                                  broadcastAttr);
    return success();
  }
};

void IotaOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                         MLIRContext* context) {
  results.add<IotaBroadcast>(context);
}

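// If the iota dimension has extent 1, every element along it is 0, so the op
// folds to a zero attribute. For example (illustrative), an iota with
// iota_dimension = 0 and result type tensor<1x8xi32> folds to
// dense<0> : tensor<1x8xi32>.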
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
  auto dimension = iota_dimension();
  auto resultTy = getResult().getType().cast<ShapedType>();
  if (resultTy.hasRank() && resultTy.getDimSize(dimension) == 1) {
    Builder builder(getContext());
    return builder.getZeroAttr(resultTy);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//

// Does the same as PatternRewriter::replaceOpWithNewOp, but with a twist.
//
// Sometimes, we want to replace an op with a new op and simultaneously refine
// the result type from a dynamically-shaped type to a statically-shaped type.
// (Search for usages of this function for examples).
//
// Oftentimes, this works just fine because MHLO is designed to accommodate
// this kind of type refinement. But sometimes it doesn't work - when the op
// is used outside of the MHLO dialect (e.g. in func.return). In these cases,
// we insert a tensor.cast to smooth things out.
template <typename OpTy, typename... Args>
OpTy refineOpWithNewOp(PatternRewriter& rewriter, Operation* op,
                       Args&&... args) {
  auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);

  llvm::SmallVector<Value> replacementResults;
  assert(op->getNumResults() == newOp->getNumResults() &&
         "replacement op doesn't match results of original op");
  for (auto [opResult, newOpResult] :
       llvm::zip(op->getResults(), newOp->getResults())) {
    Value replacementResult = newOpResult;
    if (llvm::any_of(opResult.getUsers(), [&](Operation* user) {
          return user->getDialect() != op->getDialect();
        })) {
      replacementResult = rewriter.create<tensor::CastOp>(
          op->getLoc(), opResult.getType(), newOpResult);
    }
    replacementResults.push_back(replacementResult);
  }

  rewriter.replaceOp(op, replacementResults);
  return newOp;
}

namespace {

struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    // Result type has static shape, replace with iota.
    auto resultTy = iota.getType().cast<ShapedType>();
    if (resultTy.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<IotaOp>(iota, resultTy,
                                          iota.iota_dimension());
      return success();
    }

    // Output shape is constant, compute result type with static shape, then
    // replace with iota.
    DenseIntElementsAttr outputShapeAttr;
    if (matchPattern(iota.output_shape(), m_Constant(&outputShapeAttr))) {
      SmallVector<int64_t> outputShape;
      for (APInt dim : outputShapeAttr.getValues<APInt>()) {
        outputShape.push_back(dim.getSExtValue());
      }
      resultTy = RankedTensorType::get(outputShape, resultTy.getElementType());
      refineOpWithNewOp<IotaOp>(rewriter, iota, resultTy,
                                iota.iota_dimension());
      return success();
    }

    return rewriter.notifyMatchFailure(
        iota, "requires static shape or constant output shape");
  }
};

// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
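//
// Illustrative sketch (ops abbreviated): a rank-2 dynamic_iota with
// iota_dimension = 1 is rewritten to (1) an index_cast plus mhlo.slice that
// extracts the extent at position 1 from the output_shape operand, (2) a 1-D
// dynamic_iota over that sliced extent, and (3) an
// mhlo.dynamic_broadcast_in_dim with broadcast_dimensions = [1] back to the
// original result type.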
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto resultTy = iota.getType().cast<ShapedType>();
    if (!resultTy.hasRank() || resultTy.getRank() < 2) {
      return failure();
    }

    auto iotaDimension = iota.iota_dimension();
    auto iotaDimensionInt = iotaDimension;

    auto convertedShape = rewriter.create<arith::IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            iota.output_shape().getType().cast<ShapedType>().getShape(),
            rewriter.getI64Type()),
        iota.output_shape());

    auto slicedShape = rewriter.create<SliceOp>(
        iota.getLoc(), convertedShape,
        rewriter.getI64TensorAttr(iotaDimensionInt),
        rewriter.getI64TensorAttr(iotaDimensionInt + 1),
        rewriter.getI64TensorAttr(1));

    auto convertedSlicedShape = rewriter.create<arith::IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            {1},
            iota.output_shape().getType().cast<ShapedType>().getElementType()),
        slicedShape);

    auto iotaType = RankedTensorType::get(
        {resultTy.getDimSize(iotaDimensionInt)}, resultTy.getElementType());

    auto newIota = rewriter.create<DynamicIotaOp>(
        iota.getLoc(), iotaType, convertedSlicedShape,
        rewriter.getI64IntegerAttr(0));

    auto broadcastAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iotaDimension});
    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        iota, resultTy, newIota, iota.output_shape(), broadcastAttr);
    return success();
  }
};

}  // namespace

void DynamicIotaOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                MLIRContext* context) {
  results.add<DynamicIotaIsStatic>(context);
  results.add<DynamicIotaBroadcast>(context);
}

static Value castToIndexTensor(OpBuilder& builder, Location loc,
                               Value shapeOp) {
  ShapedType resultTy = shape::getExtentTensorType(
      builder.getContext(), shapeOp.getType().cast<ShapedType>().getDimSize(0));
  if (shapeOp.getType() == resultTy) return shapeOp;  // Nothing to do.
  return builder.create<arith::IndexCastOp>(loc, resultTy, shapeOp);
}

LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicIotaOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
  return success();
}

//===----------------------------------------------------------------------===//
// DynamicUpdateSliceOp
//===----------------------------------------------------------------------===//

LogicalResult DynamicUpdateSliceOp::verify() {
  OperandRange indices = start_indices();
  if (indices.size() <= 1) return success();

  // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
  // is OK to cast indices to ShapedType here.
  auto idxTensor = indices.take_front().front().getType().cast<ShapedType>();
  Type firstElemTy = idxTensor.getElementType();
  Type elemTy;

  for (auto idx : llvm::drop_begin(indices, 1)) {
    idxTensor = idx.getType().cast<ShapedType>();
    elemTy = idxTensor.getElementType();

    if (firstElemTy != elemTy) {
      return emitOpError() << "start indices must have same element type "
                              "(encountered mismatch: "
                           << firstElemTy << " vs " << elemTy << ")";
    }
  }
  return success();
}

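// Two fold cases, both illustrative: if any update dimension has extent 0 the
// op is a no-op and folds to the operand; if operand and update share the
// same static shape and every start index is a constant 0, the whole operand
// is overwritten and the op folds to the update.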
OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef<Attribute> operands) {
  auto operandShape = this->operand().getType().cast<RankedTensorType>();
  auto updateShape = this->update().getType().cast<RankedTensorType>();

  // If any of the dimensions are length-0, the update does nothing.
  for (auto dim : updateShape.getShape()) {
    if (dim == 0) {
      return this->operand();
    }
  }

  if (operandShape != updateShape || !operandShape.hasStaticShape()) {
    return {};
  }

  // Ensure that indices are 0 constants. The 0 check mostly ensures
  // correctness. For non-constants, the pattern does not fold to avoid hiding
  // the behavior of incorrect user input.
  for (Value index : this->start_indices()) {
    DenseIntElementsAttr deAttr;
    if (!matchPattern(index, m_Constant(&deAttr))) return {};
    if (!deAttr.getSplatValue<IntegerAttr>().getValue().isZero()) return {};
  }
  return this->update();
}

//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//

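// The result element type matches the operand, except for complex inputs,
// where abs returns the underlying real type. For example (illustrative),
// mhlo.abs maps tensor<2xcomplex<f32>> to tensor<2xf32> and leaves
// tensor<2xf32> unchanged.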
LogicalResult AbsOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto operandTy = (*operands.begin()).getType().cast<ShapedType>();
  Type elementTy = operandTy.getElementType();
  if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
    elementTy = complexTy.getElementType();
  }

  Type resultTy;
  if (auto rankedOperandTy = operandTy.dyn_cast<RankedTensorType>()) {
    resultTy = RankedTensorType::get(operandTy.getShape(), elementTy,
                                     rankedOperandTy.getEncoding());
  } else if (operandTy.hasRank()) {
    resultTy = RankedTensorType::get(operandTy.getShape(), elementTy);
  } else {
    resultTy = UnrankedTensorType::get(elementTy);
  }
  inferredReturnTypes.push_back(resultTy);
  return success();
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

LogicalResult CollectivePermuteOp::verify() {
  return mlir::hlo::verifyCollectivePermuteSourceTargetPairs(
      *this, source_target_pairs());
}

//===----------------------------------------------------------------------===//
// ConvolutionOp
//===----------------------------------------------------------------------===//

namespace {
// Checks:
// P1. Same sizes for input, kernel and output spatial_dims.
// P2. Spatial and non-spatial dimensions (for input, kernel, & output) should
//     be unique and in range [0, num_dims), where num_dims = rank of input
//     (lhs/rhs) tensors.
//
// Note that the spatial + non-spatial dimensions may not cover all the
// dimensions in the range [0, num) because of the presence of 'unknown'
// dimensions (ref. cl/415132294).
LogicalResult isSpatialDimensionsValid(ConvolutionOp op) {
  auto inputSpatialDimensions =
      op.dimension_numbers().getInputSpatialDimensions();
  auto kernelSpatialDimensions =
      op.dimension_numbers().getKernelSpatialDimensions();
  auto outputSpatialDimensions =
      op.dimension_numbers().getOutputSpatialDimensions();

  // P1.
  if ((inputSpatialDimensions.size() != kernelSpatialDimensions.size()) ||
      (inputSpatialDimensions.size() != outputSpatialDimensions.size()))
    return op.emitOpError() << "expects the same size for input, kernel and "
                               "output spatial-dimensions, but got "
                            << inputSpatialDimensions.size() << ", "
                            << kernelSpatialDimensions.size() << ", and "
                            << outputSpatialDimensions.size() << " resp.";

  // P2.
  SmallVector<int64_t> inputDnums(inputSpatialDimensions.size() + 2);
  inputDnums[0] = op.dimension_numbers().getInputBatchDimension();
  inputDnums[1] = op.dimension_numbers().getInputFeatureDimension();
  std::copy(inputSpatialDimensions.begin(), inputSpatialDimensions.end(),
            inputDnums.begin() + 2);

  SmallVector<int64_t> windowDnums(kernelSpatialDimensions.size() + 2);
  windowDnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
  windowDnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
  std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(),
            windowDnums.begin() + 2);

  SmallVector<int64_t> outputDnums(outputSpatialDimensions.size() + 2);
  outputDnums[0] = op.dimension_numbers().getOutputBatchDimension();
  outputDnums[1] = op.dimension_numbers().getOutputFeatureDimension();
  std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(),
            outputDnums.begin() + 2);

  auto numDims = op.lhs().getType().cast<RankedTensorType>().getRank();
  const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; };

  if (!llvm::all_of(inputDnums, inRange) ||
      !llvm::all_of(windowDnums, inRange) ||
      !llvm::all_of(outputDnums, inRange))
    return op.emitOpError() << "expects input, kernel, and output "
                               "dimension-numbers to be in-range [0, "
                            << numDims << ").";

  if (hasDuplicates(inputDnums))
    return op.emitOpError()
           << "expects input dimension-numbers to be unique, got {"
           << inputDnums << "}.";

  if (hasDuplicates(windowDnums))
    return op.emitOpError()
           << "expects kernel dimension-numbers to be unique, got {"
           << windowDnums << "}.";

  if (hasDuplicates(outputDnums))
    return op.emitOpError()
           << "expects output dimension-numbers to be unique, got {"
           << outputDnums << "}.";

  return success();
}

// Verifies the following properties:
// P1. The input, kernel, and output spatial-dimensions are valid.
// P2. Given,
//       input-dimensions: b * input-spatial-dims * f
//       kernel-dimensions: kernel-spatial-dims * i * o
//       output-dimensions: b' * out-spatial-dims * f'
//     where b  = input-batch-dim
//           f  = input-feature-dim
//           i  = kernel-input-feature-dim
//           o  = kernel-output-feature-dim
//           b' = output-batch-dim
//           f' = output-feature-dim
//     check the following properties w.r.t. feature_group_count (fgc) and
//     batch_group_count (bgc):
//       fgc > 0, bgc > 0, and !(fgc > 1 && bgc > 1)
//       b % bgc == 0
//       f % fgc == 0 and i = f / fgc
//       o (or f') % bgc == 0 and o (or f') % fgc == 0
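//
// Worked example (illustrative): with fgc = 2 and bgc = 1, an input feature
// dimension f = 8 forces the kernel input feature dimension i = 8 / 2 = 4,
// and a kernel output feature dimension o = 6 is accepted since 6 % 2 == 0,
// while o = 5 would be rejected.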
LogicalResult verifyConvolutionAttributes(ConvolutionOp op) {
  // P1.
  if (failed(isSpatialDimensionsValid(op))) return failure();

  // P2.
  const int64_t featureGroupCount = op.feature_group_count();
  const int64_t batchGroupCount = op.batch_group_count();

  if (featureGroupCount <= 0)
    return op.emitOpError()
           << "expects feature_group_count to be a positive number, got "
           << featureGroupCount << ".";

  if (batchGroupCount <= 0)
    return op.emitOpError()
           << "expects batch_group_count to be a positive number, got "
           << batchGroupCount << ".";

  if (batchGroupCount > 1 && featureGroupCount > 1)
    return op.emitOpError()
           << "expects batch_group_count and feature_group_count not to be "
              "both greater than 1. Got "
           << batchGroupCount << " and " << featureGroupCount << " resp.";

  auto lhsType = op.lhs().getType().cast<RankedTensorType>();
  const int64_t inputFeatures =
      lhsType.getShape()[op.dimension_numbers().getInputFeatureDimension()];
  const int64_t inputBatch =
      lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];

  auto rhsType = op.rhs().getType().cast<RankedTensorType>();
  const int64_t kernelInputFeatures =
      rhsType
          .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
  const int64_t kernelOutputFeatures =
      rhsType
          .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];

  if (!isDynamicDimSize(kernelOutputFeatures)) {
    if (kernelOutputFeatures % batchGroupCount != 0)
      return op.emitOpError() << "expects output feature dimension size ("
                              << kernelOutputFeatures
                              << ") to be a multiple of "
                                 "batch_group_count. Got batch_group_count = "
                              << batchGroupCount << ".";

    if (kernelOutputFeatures % featureGroupCount != 0)
      return op.emitOpError()
             << "expects kernel output feature dimension ("
             << kernelOutputFeatures
             << ") to be divisible by "
                "feature_group_count. Got feature_group_count = "
             << featureGroupCount << ".";
  }

  if (!isDynamicDimSize(inputFeatures)) {
    if (inputFeatures % featureGroupCount != 0)
      return op.emitOpError()
             << "expects input feature dimension (" << inputFeatures
             << ") to be a multiple of "
                "feature_group_count. Got feature_group_count = "
             << featureGroupCount << ".";

    if (!isDynamicDimSize(kernelInputFeatures) &&
        inputFeatures / featureGroupCount != kernelInputFeatures)
      return op.emitOpError()
             << "expects input feature dimension (" << inputFeatures
             << ") / "
                "feature_group_count = kernel input feature dimension ("
             << kernelInputFeatures
             << "). Got feature_group_count = " << featureGroupCount << ".";
  }

  if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0)
    return op.emitOpError() << "expects input batch dimension (" << inputBatch
                            << ") to be divisible by "
                               "batch_group_count. Got batch_group_count = "
                            << batchGroupCount << ".";

  return success();
}

// Infer the return-shape of ConvolutionOp.
// Preconditions:
//  1. Input args to ConvolutionOp 'op' are RankedTypes.
//  2. rank-of(input-type) == rank-of(output-type)
SmallVector<int64_t> inferConvolutionOpReturnShape(
    ConvolutionOp op, const ArrayRef<WindowDimension> window) {
  // We keep the 'unknown' dimensions (cl/415132294) as they are in the
  // output-shape. To do that we initialize the output dimensions with the
  // shape of the return-type and update only the spatial + non-spatial
  // dimensions. Precondition 2 ensures that size of output-shape == size of
  // input-shape.
  SmallVector<int64_t> outputDimensions =
      to_vector(op.getResult().getType().cast<ShapedType>().getShape());

  // Infer the output spatial dimensions.
  auto lhsType = op.lhs().getType().cast<RankedTensorType>();
  auto inputSpatialDims = op.dimension_numbers().getInputSpatialDimensions();
  auto numSpatialDims = inputSpatialDims.size();
  SmallVector<int64_t> inputSpatialDimVals(numSpatialDims);
  for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
    inputSpatialDimVals[i] = lhsType.getShape()[inputSpatialDims[i]];

  auto windowOutputShape = inferWindowOutputShape(inputSpatialDimVals, window);

  for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i)
    outputDimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
        windowOutputShape[i];

  // Infer the output-batch-dimension and output-feature-dimension.
  auto rhsType = op.rhs().getType().cast<RankedTensorType>();
  const int64_t inputBatch =
      lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];
  const int64_t kernelOutputFeatures =
      rhsType
          .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];

  outputDimensions[op.dimension_numbers().getOutputBatchDimension()] =
      isDynamicDimSize(inputBatch) ? ShapedType::kDynamicSize
                                   : inputBatch / op.batch_group_count();
  outputDimensions[op.dimension_numbers().getOutputFeatureDimension()] =
      kernelOutputFeatures;

  return outputDimensions;
}

// Some mhlo.convolutions are dot products, specifically when there is no
// padding and no spatial dimensions. DotGeneralOp is general enough that it
// can sufficiently describe it.
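//
// Illustrative example (dimension numbers abbreviated): with
// feature_group_count = 1, batch dim 0 and feature dim 1 on both operands, a
// rank-2 convolution
//
//   (tensor<5x6xf32>, tensor<6x7xf32>) -> tensor<5x7xf32>
//
// is just a matmul, and becomes an mhlo.dot_general contracting lhs dim 1
// with rhs dim 0.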
struct ConvolutionIsDot : public OpRewritePattern<mhlo::ConvolutionOp> {
  using OpRewritePattern<mhlo::ConvolutionOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::ConvolutionOp op,
                                PatternRewriter& rewriter) const override {
    auto lhs = op.lhs();
    auto rhs = op.rhs();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();
    auto resultTy = op.getType().cast<RankedTensorType>();

    if (lhsTy.getRank() != 2) return failure();
    if (rhsTy.getRank() != 2) return failure();

    if (op.batch_group_count() != 1) return failure();

    // There should not be any padding if this is a matmul.
    auto dNums = op.dimension_numbers();
    assert(!op.padding() || op.padding()->empty());
    assert(dNums.getKernelSpatialDimensions().empty());

    auto lhsBatchDim = dNums.getInputBatchDimension();
    auto rhsBatchDim = dNums.getKernelOutputFeatureDimension();
    auto lhsContractDim = dNums.getInputFeatureDimension();
    auto rhsContractDim = dNums.getKernelInputFeatureDimension();
    auto outBatchDim = dNums.getOutputBatchDimension();
    auto outFeatureDim = dNums.getOutputFeatureDimension();

    // If the input features are not grouped then we can directly convert to an
    // mhlo.dot_general.
    if (op.feature_group_count() == 1) {
      // We can swap the lhs and rhs sides to avoid a transpose.
      if (outBatchDim == 1 && outFeatureDim == 0) {
        std::swap(lhs, rhs);
        std::swap(outBatchDim, outFeatureDim);
        std::swap(lhsContractDim, rhsContractDim);
      }

      auto dotNums = DotDimensionNumbersAttr::get(
          op.getContext(), {}, {}, {lhsContractDim}, {rhsContractDim});
      auto dotOp = rewriter.create<mhlo::DotGeneralOp>(
          op.getLoc(), op.getType(), lhs, rhs, dotNums,
          op.precision_config().value_or(nullptr));

      rewriter.replaceOp(op, dotOp.getResult());
      return success();
    }

    int64_t featureGroupCount = op.feature_group_count();
    int64_t lhsBatchSize = lhsTy.getDimSize(lhsBatchDim);
    int64_t lhsContractSize = lhsTy.getDimSize(lhsContractDim);
    int64_t rhsBatchSize = rhsTy.getDimSize(rhsBatchDim);
    int64_t rhsContractSize = rhsTy.getDimSize(rhsContractDim);

    llvm::SmallVector<int64_t> lhsShape;
    llvm::SmallVector<int64_t> rhsShape;
    lhsShape.resize(3, lhsBatchSize);
    rhsShape.resize(3, rhsContractSize);
    lhsShape[lhsContractDim] = featureGroupCount;
    lhsShape[lhsContractDim + 1] = lhsContractSize / featureGroupCount;
    rhsShape[rhsContractDim] = featureGroupCount;
    rhsShape[rhsContractDim + 1] = rhsBatchSize / featureGroupCount;

    lhsTy = RankedTensorType::get(lhsShape, lhsTy.getElementType());
    rhsTy = RankedTensorType::get(rhsShape, rhsTy.getElementType());

    lhs = rewriter.create<mhlo::ReshapeOp>(op.getLoc(), lhsTy, lhs);
    rhs = rewriter.create<mhlo::ReshapeOp>(op.getLoc(), rhsTy, rhs);

    auto dotTy = RankedTensorType::get(
        {featureGroupCount, lhsBatchSize, rhsBatchSize / featureGroupCount},
        resultTy.getElementType());

    auto dotNums = DotDimensionNumbersAttr::get(
        op.getContext(), {lhsContractDim}, {rhsContractDim},
        {lhsContractDim + 1}, {rhsContractDim == 0 ? 2 : 0});
    auto dotOp = rewriter.create<mhlo::DotGeneralOp>(
        op.getLoc(), dotTy, lhs, rhs, dotNums,
        op.precision_config().value_or(nullptr));

    llvm::SmallVector<int64_t> perms;
    perms.resize(3, dNums.getOutputBatchDimension() == 0 ? 0 : 2);
    perms[0] = dNums.getOutputFeatureDimension();
    perms[2] = dNums.getOutputFeatureDimension() + 1;

    auto transposeTy = RankedTensorType::get(
        {dotTy.getDimSize(perms[0]), dotTy.getDimSize(perms[1]),
         dotTy.getDimSize(perms[2])},
        dotTy.getElementType());
    auto transposeOp = rewriter.create<mhlo::TransposeOp>(
        op.getLoc(), transposeTy, dotOp, rewriter.getI64TensorAttr(perms));

    rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, resultTy, transposeOp);
    return success();
  }
};

}  // namespace

void ConvolutionOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                MLIRContext* context) {
  results.add<ConvolutionIsDot>(context);
}

/*
 * We intend to verify the following properties:
 * P1. Verify the input, kernel types.
 * P2. Verify the convolution attributes.
 * P3. Verify and collect the window attributes.
 * P4. Verify the return shape.
 * TODO(b/232574102): Verify the element-type of return-value.
 */
LogicalResult ConvolutionOp::verify() {
  auto lhsType = lhs().getType().dyn_cast<RankedTensorType>();
  auto rhsType = rhs().getType().dyn_cast<RankedTensorType>();

  if (!lhsType || !rhsType) return success();

  // P1.
  int numDims = lhsType.getRank();
  if (numDims != rhsType.getRank())
    return emitOpError()
           << "expects convolution arguments to have the same number of "
              "dimensions. Got: "
           << lhsType << " and " << rhsType << ".";

  if (numDims < 2)
    return emitOpError()
           << "expects convolution arguments to have >= 2 dimensions. "
              "Got: "
           << lhsType << " and " << rhsType << ".";

  // P2.
  if (failed(verifyConvolutionAttributes(*this))) return failure();

  // P3.
  auto kernelSpatialDimensions =
      dimension_numbers().getKernelSpatialDimensions();
  SmallVector<int64_t> windowDimensions(kernelSpatialDimensions.size());
  for (size_t i = 0; i < windowDimensions.size(); i++)
    windowDimensions[i] = rhsType.getShape()[kernelSpatialDimensions[i]];

  auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
  if (failed(paddingOrErr)) return failure();
  SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;

  auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
      windowDimensions, convertDenseIntAttr(window_strides()), padding,
      convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
      getLoc());
  if (failed(windowOrErr)) return failure();

  // P4.
  auto actualReturnType = getResult().getType().cast<TensorType>();
  auto actualReturnElementType = actualReturnType.getElementType();
  if (!actualReturnType.hasRank()) return success();

  auto actualReturnRankedType = actualReturnType.cast<RankedTensorType>();
  if (numDims != actualReturnRankedType.getRank())
    return emitOpError() << "expects rank of convolution return-type to be "
                            "equal to input-ranks ("
                         << numDims << "), but got "
                         << actualReturnRankedType.getRank() << ".";

  auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr);
  auto expectedReturnType =
      RankedTensorType::get(expectedReturnShape, actualReturnElementType);
  if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType)))
    return emitOpError()
           << "has shape mismatch between the expected return-type ("
           << expectedReturnType << ") and actual return-type ("
           << actualReturnRankedType << ").";

  return success();
}

//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//

void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
                      Type resultElementTy) {
  Type resultTy;
  Type operandTy = operand.getType();
  if (auto rankedTy = operandTy.dyn_cast<RankedTensorType>()) {
    resultTy = RankedTensorType::get(rankedTy.getShape(), resultElementTy);
  } else {
    resultTy = UnrankedTensorType::get(resultElementTy);
  }
  build(builder, result, resultTy, operand);
}

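// Folding is illustrated by (syntax approximate): converting a constant
// dense<[1, 2]> : tensor<2xi32> to element type f32 folds directly to
// dense<[1.0, 2.0]> : tensor<2xf32>, while an identity convert folds to its
// operand.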
fold(ArrayRef<Attribute> operands)2573 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
2574 auto operandTy = getOperand().getType().cast<TensorType>();
2575 auto resultTy = getResult().getType().cast<TensorType>();
2576 if (operandTy == resultTy) return getOperand();
2577
2578 // If the result has non-static shape, a convert op is necessary to go from
2579 // static shape to non-static shape.
2580 if (!resultTy.hasStaticShape()) return {};
2581
2582 // If the operand is constant, we can do the conversion now.
2583 auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>();
2584 if (!elementsAttr) return {};
2585
2586 // Prevent folding if the result is too large.
2587 if (elementsAttr.getNumElements() > kFoldOpEltLimit) return {};
2588 return hlo::convertElementsAttr(elementsAttr,
2589 getElementTypeOrSelf(getResult()));
2590 }
2591
2592 namespace {
2593
2594 struct EliminateRedundantConvert : public OpRewritePattern<ConvertOp> {
2595 using OpRewritePattern<ConvertOp>::OpRewritePattern;
2596 LogicalResult matchAndRewrite(ConvertOp op,
2597 PatternRewriter& rewriter) const override {
2598 auto convertOp = op.operand().getDefiningOp<ConvertOp>();
2599 if (!convertOp) {
2600 return failure();
2601 }
2602 auto firstType =
2603 convertOp.operand().getType().cast<TensorType>().getElementType();
2604 auto secondType =
2605 op.operand().getType().cast<TensorType>().getElementType();
2606 auto thirdType =
2607 op.getResult().getType().cast<TensorType>().getElementType();
2608 auto loc = rewriter.getFusedLoc({convertOp->getLoc(), op->getLoc()});
2609 if (firstType.isa<FloatType>() && secondType.isa<FloatType>() &&
2610 thirdType.isa<FloatType>()) {
2611 // Fold when the second float type is wider than the first,
2612 // e.g. fp16 -> fp32 -> fp64, bf16 -> fp32 -> fp16.
2613 if (secondType.cast<FloatType>().getWidth() >
2614 firstType.cast<FloatType>().getWidth()) {
2615 Value result = rewriter.create<ConvertOp>(loc, op.getResult().getType(),
2616 convertOp.operand());
2617 rewriter.replaceOp(op, result);
2618 return success();
2619 }
2620 } else if (firstType.isa<IntegerType>() && secondType.isa<IntegerType>() &&
2621 thirdType.isa<IntegerType>()) {
2622 // Fold when the second integer type is wider than the first,
2623 // e.g. i16 -> i32 -> i64, u16 -> i32 -> u32.
2624 if (secondType.cast<IntegerType>().getWidth() >
2625 firstType.cast<IntegerType>().getWidth()) {
2626 Value result = rewriter.create<ConvertOp>(loc, op.getResult().getType(),
2627 convertOp.operand());
2628 rewriter.replaceOp(op, result);
2629 return success();
2630 }
2631 }
2632 return failure();
2633 }
2634 };
2635
2636 } // namespace
2637
2638 void ConvertOp::getCanonicalizationPatterns(RewritePatternSet& results,
2639 MLIRContext* context) {
2640 results.add<EliminateIdentityConvert>(context);
2641 results.add<EliminateRedundantConvert>(context);
2642 }
2643
2644 //===----------------------------------------------------------------------===//
2645 // GetTupleElementOp
2646 //===----------------------------------------------------------------------===//
2647
2648 LogicalResult GetTupleElementOp::verify() {
2649 auto indexVal = index();
2650 auto operandType = getOperand().getType().cast<TupleType>();
2651 if (indexVal >= operandType.size()) {
2652 return emitOpError(
2653 llvm::formatv("index {0} is out of bounds of operand with size {1}",
2654 indexVal, operandType.size()));
2655 }
2656
2657 auto expectedType = operandType.getType(indexVal);
2658 if (getType() != expectedType) {
2659 return emitOpError(llvm::formatv("has return type {0}, but expected {1}",
2660 getType(), expectedType));
2661 }
2662 return success();
2663 }
2664
2665 OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
2666 if (auto tupleOp = getOperand().getDefiningOp<mhlo::TupleOp>()) {
2667 return tupleOp.getOperand(index());
2668 }
2669
2670 return {};
2671 }
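// Illustrative sketch of the fold above (hand-written, hypothetical values):
//   %t = "mhlo.tuple"(%a, %b)
//       : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
//   %e = "mhlo.get_tuple_element"(%t) {index = 1 : i32}
//       : (tuple<tensor<f32>, tensor<i32>>) -> tensor<i32>
// folds %e directly to %b.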
2672
2673 //===----------------------------------------------------------------------===//
2674 // TupleOp
2675 //===----------------------------------------------------------------------===//
2676
2677 LogicalResult TupleOp::verify() {
2678 auto opType = getType().dyn_cast<TupleType>();
2679 if (!opType) return emitOpError("tuple op with non-tuple result");
2680 if (getNumOperands() != opType.size())
2681 return emitOpError(
2682 "number of operands to tuple expected to match number of types in "
2683 "resultant tuple type");
2684 for (const auto& it :
2685 llvm::enumerate(llvm::zip_first(getOperandTypes(), opType.getTypes()))) {
2686 if (std::get<0>(it.value()) != std::get<1>(it.value()))
2687 return emitOpError("has return type mismatch at ")
2688 << it.index() << "th value (" << std::get<0>(it.value())
2689 << " != " << std::get<1>(it.value()) << ")";
2690 }
2691 return success();
2692 }
2693
2694 namespace {
2695
2696 // Pattern for unpacking and repacking the same tuple.
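// E.g. (hand-written sketch): given %t : tuple<tensor<f32>, tensor<i32>>,
//   %0 = "mhlo.get_tuple_element"(%t) {index = 0 : i32}
//   %1 = "mhlo.get_tuple_element"(%t) {index = 1 : i32}
//   %t2 = "mhlo.tuple"(%0, %1)
// the repacked %t2 is replaced by %t itself.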
2697 struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
2698 using OpRewritePattern<TupleOp>::OpRewritePattern;
2699
2700 LogicalResult matchAndRewrite(TupleOp op,
2701 PatternRewriter& rewriter) const override {
2702 if (op.val().empty()) return failure();
2703
2704 Value firstElement = op.val().front();
2705 auto firstElementOp = firstElement.getDefiningOp<GetTupleElementOp>();
2706 if (!firstElementOp || firstElementOp.indexAttr().getInt() != 0)
2707 return failure();
2708
2709 Value tuplePredecessor = firstElementOp.getOperand();
2710 if (tuplePredecessor.getType() != op.getType()) return failure();
2711
2712 for (const auto& elementAndIdx : llvm::enumerate(op.val().drop_front(1))) {
2713 auto elementOp = elementAndIdx.value().getDefiningOp<GetTupleElementOp>();
2714 if (!elementOp ||
2715 elementOp.indexAttr().getInt() !=
2716 static_cast<int64_t>(elementAndIdx.index() + 1) ||
2717 elementOp.getOperand() != tuplePredecessor)
2718 return failure();
2719 }
2720
2721 rewriter.replaceOp(op, tuplePredecessor);
2722 return success();
2723 }
2724 };
2725
2726 } // namespace
2727
2728 void TupleOp::getCanonicalizationPatterns(RewritePatternSet& results,
2729 MLIRContext* context) {
2730 results.add<UnpackRepackSameTuple>(context);
2731 }
2732
2733 //===----------------------------------------------------------------------===//
2734 // AllToAllOp
2735 //===----------------------------------------------------------------------===//
2736
2737 LogicalResult AllToAllOp::inferReturnTypeComponents(
2738 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
2739 DictionaryAttr attributes, RegionRange regions,
2740 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2741 AllToAllOp::Adaptor adaptor(operands, attributes, regions);
2742 Type operandType = adaptor.operand().getType();
2743 RankedTensorType operandRankedType = operandType.dyn_cast<RankedTensorType>();
2744 if (!operandRankedType) {
2745 inferredReturnShapes.emplace_back(
2746 operandType.cast<TensorType>().getElementType());
2747 return success();
2748 }
2749
2750 int64_t inputRank = operandRankedType.getRank();
2751 int64_t splitDimension = static_cast<int64_t>(adaptor.split_dimension());
2752 int64_t concatDimension = static_cast<int64_t>(adaptor.concat_dimension());
2753 if (splitDimension >= inputRank || splitDimension < 0) {
2754 return emitOptionalError(location, "AllToAll split_dimension ",
2755 splitDimension,
2756 " is out-of-bounds for input rank ", inputRank);
2757 }
2758 if (concatDimension >= inputRank || concatDimension < 0) {
2759 return emitOptionalError(location, "AllToAll concat_dimension ",
2760 concatDimension,
2761 " is out-of-bounds for input rank ", inputRank);
2762 }
2763
2764 // If the operand is ranked, the size of the split dimension should be a
2765 // multiple of the split count.
2766 int64_t splitCount = adaptor.split_count();
2767 auto splitDimSize = operandRankedType.getDimSize(splitDimension);
2768 if (splitDimSize % splitCount != 0) {
2769 return emitOptionalError(
2770 location, "split dimension has size ", splitDimSize,
2771 ", expected to be a multiple of split_count ", splitCount);
2772 }
2773 SmallVector<int64_t> resultShape(operandRankedType.getShape().begin(),
2774 operandRankedType.getShape().end());
2775 resultShape[splitDimension] /= splitCount;
2776 resultShape[concatDimension] *= splitCount;
2777 inferredReturnShapes.emplace_back(resultShape,
2778 operandRankedType.getElementType());
2779 return success();
2780 }
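// Shape arithmetic sketch for the inference above (hand-written example):
// with operand tensor<4x16xf32>, split_dimension = 1, concat_dimension = 0,
// and split_count = 4, the inferred result is tensor<16x4xf32>
// (dim 1: 16 / 4 = 4; dim 0: 4 * 4 = 16).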
2781
2782 //===----------------------------------------------------------------------===//
2783 // AllGatherOp
2784 //===----------------------------------------------------------------------===//
2785
2786 LogicalResult AllGatherOp::verify() {
2787 // If operand and result are both ranked, then the size of the gather
2788 // dimension in the result should be a multiple of the size of the gather
2789 // dimension in the operand.
2790 auto operandType = operand().getType().dyn_cast<RankedTensorType>();
2791 auto resultType = getType().dyn_cast<RankedTensorType>();
2792 uint64_t allGatherDimIndex = all_gather_dim();
2793 if (!operandType || !resultType ||
2794 operandType.isDynamicDim(allGatherDimIndex) ||
2795 resultType.isDynamicDim(allGatherDimIndex))
2796 return success();
2797 if (operandType.getDimSize(allGatherDimIndex) == 0)
2798 return emitOpError() << "operand gather dimension cannot be zero.";
2799 if ((resultType.getDimSize(allGatherDimIndex) %
2800 operandType.getDimSize(allGatherDimIndex)) != 0)
2801 return emitOpError()
2802 << "result gather dimension has size "
2803 << resultType.getDimSize(allGatherDimIndex)
2804 << ", expected to be a multiple of operand gather dimension size "
2805 << operandType.getDimSize(allGatherDimIndex);
2806
2807 return success();
2808 }
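// E.g. (hand-written sketch): an all_gather with all_gather_dim = 0 over 8
// replicas takes tensor<2x4xf32> to tensor<16x4xf32>; the check above only
// requires 16 to be a multiple of 2, not any particular replica-group layout.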
2809
2810 //===----------------------------------------------------------------------===//
2811 // BatchNormGradOp
2812 //===----------------------------------------------------------------------===//
2813
2814 LogicalResult BatchNormGradOp::verify() {
2815 // The following properties are already enforced by the ODS:
2816 // 1. Inputs 'operand' & 'grad_output' and outputs 'grad_operand',
2817 // are ranked-tensors with floating-point (fp) type.
2818 // 2. The shapes of inputs 'operand' & 'grad_output' match.
2819 // 3. Inputs 'scale', 'mean', 'variance' and Outputs 'grad_scale',
2820 // 'grad_offset' are all 1D fp tensors with same shape.
2821 // 4. The element-types of input 'operand' and outputs 'grad_scale',
2822 // 'grad_offset' match.
2823 // 5. The type of input 'operand' and output 'grad_operand' match.
2824 //
2825 // We intend to verify the following properties
2826 // P1. Inputs 'operand' & 'grad_output' have the same shape with fp
2827 // element-types, ignoring fp-precision : Inferred from (1) & (2).
2828 // P2. The feature dimension 'feature_index' is a valid index in 'operand':
2829 // Inferred from check C2 below.
2830 // P3. Inputs 'scale', 'mean', 'variance' must be 1D tensors with same shape
2831 // and fp element-type (ignoring precision) and the number of elements
2832 // in its sole-dimension == number of features in the 'operand's
2833 // feature-dimension 'feature_index': Inferred from (3) and check C3
2834 // below.
2835 // P4. Outputs 'grad_scale' & 'grad_offset' are 1D tensors with
2836 // element-type == element-type of(operand) and same shape as any of
2837 // the inputs 'scale', 'mean', or 'variance': Inferred from (3), (4) and
2838 // check C3 below.
2839 // P5. The type (shape + element-type) of input 'operand' and
2840 // output 'grad_operand' must match: Inferred from (5).
2841
2842 // C2.
2843 auto operandType = operand().getType().cast<RankedTensorType>();
2844 if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2845 return emitOpError() << "expects feature_index to be smaller "
2846 "than the rank of operand type; got feature_index "
2847 << feature_index() << ", and rank "
2848 << operandType.getRank() << ".";
2849
2850 if (static_cast<int64_t>(feature_index()) < 0)
2851 return emitOpError() << "expects feature_index to be a "
2852 << "non-negative number, got "
2853 << static_cast<int64_t>(feature_index()) << ".";
2854
2855 auto gradOutputType = grad_output().getType().cast<RankedTensorType>();
2856 if (operandType.getRank() != gradOutputType.getRank())
2857 return emitOpError() << "expects 'operand' and 'grad_output' to have the "
2858 "same rank. but got rank(oprand) "
2859 << operandType.getRank() << " and rank(grad_output) "
2860 << gradOutputType.getRank() << ".";
2861
2862 // C3.
2863 const int64_t featureCount = operandType.getShape()[feature_index()];
2864 const int64_t scaleShape =
2865 scale().getType().cast<RankedTensorType>().getShape()[0];
2866 if (scaleShape != featureCount)
2867 return emitOpError() << "expects the size of scale factor to be "
2868 "same as the feature count,"
2869 " but the size of scale factor is "
2870 << scaleShape << " and the feature count is "
2871 << featureCount << ".";
2872
2873 return success();
2874 }
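// E.g. (hand-written sketch): for operand tensor<2x2x2x2xf32> with
// feature_index = 2, the feature count is 2, so 'scale', 'mean', and
// 'variance' must all be tensor<2xf32>.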
2875
2876 //===----------------------------------------------------------------------===//
2877 // BatchNormTrainingOp
2878 //===----------------------------------------------------------------------===//
2879
2880 LogicalResult BatchNormTrainingOp::verify() {
2881 // The following properties are already enforced by the ODS:
2882 // 1. 'operand' and 'output' are ranked tensors.
2883 // 2. 'scale', 'offset', 'batch_mean', 'batch_var' are 1D tensors.
2884 // 3. Types of 'operand' and 'output' matches.
2885 // 4. Same element-types for 'operand', 'batch_mean', & 'batch_var'.
2886 // 5. Same shapes for 'scale', 'offset', 'batch_mean', & 'batch_var'.
2887
2888 auto operandType = operand().getType().cast<RankedTensorType>();
2889 if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2890 return emitOpError() << "expects feature_index to be smaller "
2891 "than the rank of operand type; got feature_index "
2892 << feature_index() << ", and rank "
2893 << operandType.getRank() << ".";
2894
2895 if (static_cast<int64_t>(feature_index()) < 0)
2896 return emitOpError() << "expects feature_index to be a "
2897 << "non-negative number, got "
2898 << static_cast<int64_t>(feature_index()) << ".";
2899
2900 // Note: A valid value of feature_index implies 'operandType.getRank() >= 1'.
2901
2902 const int64_t featureCount = operandType.getShape()[feature_index()];
2903 const int64_t scaleShape =
2904 scale().getType().cast<RankedTensorType>().getShape()[0];
2905 // Check number of elements in input 'scale' equals feature_count.
2906 // Together with (5) implies that 'scale', 'offset', 'batch_mean', &
2907 // 'batch_var' all have the same shape.
2908 if (scaleShape != featureCount)
2909 return emitOpError() << "expects the size of scale factor to be "
2910 "same as the feature count,"
2911 " but the size of scale factor is "
2912 << scaleShape << " and the feature count is "
2913 << featureCount << ".";
2914
2915 return success();
2916 }
2917
2918 //===----------------------------------------------------------------------===//
2919 // BatchNormInferenceOp
2920 //===----------------------------------------------------------------------===//
2921
2922 LogicalResult BatchNormInferenceOp::verify() {
2923 // The following properties are already enforced by the ODS:
2924 // 1. 'operand' and 'result' are ranked tensors.
2925 // 2. 'scale', 'offset', 'mean', 'variance' are 1D tensors.
2926 // 3. Types of 'operand' and 'result' matches.
2927 // 4. Same shapes for 'scale', 'offset', 'mean', & 'variance'.
2928
2929 auto operandType = operand().getType().cast<RankedTensorType>();
2930 if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2931 return emitOpError() << "expects feature_index to be smaller "
2932 "than the rank of operand type; got feature_index "
2933 << feature_index() << ", and rank "
2934 << operandType.getRank() << ".";
2935
2936 if (static_cast<int64_t>(feature_index()) < 0)
2937 return emitOpError() << "expects feature_index to be a "
2938 << "non-negative number, got "
2939 << static_cast<int64_t>(feature_index()) << ".";
2940
2941 // Note: A valid value of feature_index implies 'operandType.getRank() >= 1'.
2942
2943 const int64_t featureCount = operandType.getShape()[feature_index()];
2944 const int64_t scaleSize =
2945 scale().getType().cast<RankedTensorType>().getShape()[0];
2946 // Check number of elements in input 'scale' equals feature_count.
2947 // Together with (4) implies that 'scale', 'offset', 'mean', &
2948 // 'variance' all have the same shape.
2949 if (scaleSize != featureCount)
2950 return emitOpError() << "expects the size of scale factor to be "
2951 "same as the feature count,"
2952 " but the size of scale factor is "
2953 << scaleSize << " and the feature count is "
2954 << featureCount << ".";
2955
2956 return success();
2957 }
2958
2959 //===----------------------------------------------------------------------===//
2960 // BitcastConvertOp
2961 //===----------------------------------------------------------------------===//
2962
2963 LogicalResult BitcastConvertOp::reifyReturnTypeShapes(
2964 OpBuilder& builder, ValueRange operands,
2965 SmallVectorImpl<Value>& reifiedReturnShapes) {
2966 auto operandType = operands[0].getType().dyn_cast<RankedTensorType>();
2967 auto resultType = getType().dyn_cast<RankedTensorType>();
2968
2969 // Only ranked tensors are supported.
2970 if (!operandType || !resultType) return failure();
2971
2972 // Shape-changing bitcast convert is not implemented.
2973 // TODO(kramerb): This could be done by adjusting the last dimension.
2974 DataLayout dataLayout = DataLayout::closest(*this);
2975 unsigned operandElementSize =
2976 dataLayout.getTypeSizeInBits(operandType.getElementType());
2977 unsigned resultElementSize =
2978 dataLayout.getTypeSizeInBits(resultType.getElementType());
2979 if (operandElementSize != resultElementSize) return failure();
2980
2981 return ::mlir::mhlo::deriveShapeFromOperand(
2982 &builder, getOperation(), operands.front(), &reifiedReturnShapes);
2983 }
2984
2985 /*
2986 * We intend to verify the following properties
2987 * P1. We cannot convert between complex and real types (cf xla)
2988 * P2. The dimensions of the operand and the target
2989 * shape must match, except that the shape with the smaller element bitwidth has
2990 * an appropriately-sized additional innermost dimension, e.g.
2991 * ... x f32 => [bitcast_convert] => ... x 4 x i8
2992 * ... x 4 x i8 => [bitcast_convert] => ... x f32
2993 */
2994 LogicalResult BitcastConvertOp::verify() {
2995 auto operandTensorType = operand().getType().cast<TensorType>();
2996 auto targetTensorType = getResult().getType().cast<TensorType>();
2997
2998 // P1.
2999 auto targetElt = targetTensorType.getElementType();
3000 auto operandElt = operandTensorType.getElementType();
3001 if (targetElt.isa<ComplexType>() != operandElt.isa<ComplexType>()) {
3002 return emitOpError()
3003 << "cannot convert between real and complex types, but got: "
3004 << operandTensorType << " and " << targetTensorType;
3005 }
3006
3007 auto targetEltBitwidth = potentiallyComplexBitwidth(targetElt);
3008 auto operandEltBitwidth = potentiallyComplexBitwidth(operandElt);
3009
3010 // P2.
3011 auto operandType = operandTensorType.dyn_cast<RankedTensorType>();
3012 auto targetType = targetTensorType.dyn_cast<RankedTensorType>();
3013 if (!operandType || !targetType) return success();
3014
3015 auto targetShape = targetType.getShape();
3016 auto operandShape = operandType.getShape();
3017 ArrayRef<int64_t> smallerEltShape, biggerEltShape;
3018 Type smallerElt, biggerElt;
3019 if (operandEltBitwidth < targetEltBitwidth) {
3020 smallerEltShape = operandShape;
3021 smallerElt = operandElt;
3022 biggerEltShape = targetShape;
3023 biggerElt = targetElt;
3024 } else {
3025 smallerEltShape = targetShape;
3026 smallerElt = targetElt;
3027 biggerEltShape = operandShape;
3028 biggerElt = operandElt;
3029 }
3030
3031 ArrayRef<int64_t> smallerEltPrefix;
3032 auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth);
3033 auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth);
3034 if (operandEltBitwidth != targetEltBitwidth) {
3035 if (smallerEltShape.empty()) {
3036 return emitOpError() << "does not allow the smaller element type to be "
3037 "part of a 0d tensor, but got: "
3038 << operandType << " and " << targetType << ".";
3039 }
3040 smallerEltPrefix = smallerEltShape.drop_back();
3041 if (!isDynamicDimSize(smallerEltShape.back()) &&
3042 smallerEltShape.back() * smallerEltBitwidth != biggerEltBitwidth) {
3043 return emitOpError() << "requires compatible bitwidths. "
3044 << "Got: " << operandType << " and " << targetType
3045 << ", but " << smallerEltBitwidth << " * "
3046 << smallerEltShape.back()
3047 << " != " << biggerEltBitwidth << ".";
3048 }
3049 } else {
3050 smallerEltPrefix = smallerEltShape;
3051 }
3052
3053 for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) {
3054 auto targetDim = std::get<0>(it);
3055 auto operandDim = std::get<1>(it);
3056 if (!isDynamicDimSize(targetDim) && !isDynamicDimSize(operandDim)) {
3057 if (targetDim != operandDim) {
3058 return emitOpError() << "operand and result shapes must match except "
3059 "for the innermost dimension of the shape with "
3060 "the smaller element type. Got: "
3061 << operandType << " and " << targetType << ".";
3062 }
3063 }
3064 }
3065
3066 return success();
3067 }
3068
3069 //===----------------------------------------------------------------------===//
3070 // BroadcastOp
3071 //===----------------------------------------------------------------------===//
3072
3073 // TODO(b/129012527) These should be expressed as type constraints.
3074 LogicalResult BroadcastOp::verify() {
3075 auto sizes = broadcast_sizes();
3076 auto sizesType = sizes.getType();
3077 auto sizesRank = sizesType.getRank();
3078 if (sizesRank != 1) {
3079 return emitOpError(llvm::formatv(
3080 "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
3081 }
3082
3083 return success();
3084 }
3085
3086 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> attrs) {
3087 auto type = getType().cast<RankedTensorType>();
3088 auto sizesType = broadcast_sizes().getType();
3089 if (sizesType.getNumElements() == 0) {
3090 return getOperand();
3091 }
3092
3093 // Constant fold when an operand is a splat tensor attribute.
3094 if (!attrs[0] || !type.hasStaticShape()) return {};
3095 auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
3096 if (!splatOperandAttr) return {};
3097
3098 // Handle complex type
3099 if (type.getElementType().isa<ComplexType>()) {
3100 ComplexType complex = type.getElementType().cast<ComplexType>();
3101 if (complex.getElementType().isa<FloatType>()) {
3102 return DenseElementsAttr::get(
3103 type, {splatOperandAttr.getSplatValue<std::complex<APFloat>>()});
3104 }
3105 if (complex.getElementType().isa<IntegerType>()) {
3106 return DenseElementsAttr::get(
3107 type, {splatOperandAttr.getSplatValue<std::complex<APInt>>()});
3108 }
3109 return {};
3110 }
3111
3112 return SplatElementsAttr::get(
3113 type, splatOperandAttr.getSplatValue<mlir::Attribute>());
3114 }
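// Illustrative sketch of the fold above (hand-written): broadcasting a splat
// constant folds to a splat of the result type, e.g.
//   %cst = mhlo.constant dense<1.000000e+00> : tensor<f32>
//   %0 = "mhlo.broadcast"(%cst)
//       {broadcast_sizes = dense<[2, 3]> : tensor<2xi64>}
//       : (tensor<f32>) -> tensor<2x3xf32>
// folds to
//   %0 = mhlo.constant dense<1.000000e+00> : tensor<2x3xf32>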
3115
3116 LogicalResult BroadcastOp::inferReturnTypeComponents(
3117 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
3118 DictionaryAttr attributes, RegionRange regions,
3119 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
3120 BroadcastOp::Adaptor adaptor(operands, attributes, regions);
3121 Value operand = adaptor.operand();
3122 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3123 if (!operandType) return failure();
3124
3125 Type elementTy = operandType.getElementType();
3126 auto dimensionAttr = adaptor.broadcast_sizes();
3127 for (int64_t size : dimensionAttr.getValues<int64_t>()) {
3128 if (size < 0)
3129 return emitOptionalError(location,
3130 "Broadcast with negative dimension size ", size);
3131 }
3132 SmallVector<int64_t> shapeValues(dimensionAttr.getValues<int64_t>());
3133 llvm::append_range(shapeValues, operandType.getShape());
3134
3135 inferredReturnShapes.emplace_back(shapeValues, elementTy);
3136 return success();
3137 }
3138
3139 LogicalResult BroadcastOp::reifyReturnTypeShapes(
3140 OpBuilder& builder, ValueRange operands,
3141 SmallVectorImpl<Value>& reifiedReturnShapes) {
3142 BroadcastOp::Adaptor adaptor(operands);
3143 Value operand = adaptor.operand();
3144
3145 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3146 // Unranked tensors are not supported.
3147 if (!operandType) return failure();
3148
3149 Location loc = getLoc();
3150 SmallVector<Value, 4> shapeValues;
3151
3152 // Collect the broadcast sizes.
3153 for (const auto& size : broadcast_sizes()) {
3154 shapeValues.push_back(
3155 builder.create<arith::ConstantIndexOp>(loc, size.getZExtValue()));
3156 }
3157
3158 // Collect the operand sizes.
3159 for (auto index : llvm::seq<int64_t>(0, operandType.getRank())) {
3160 shapeValues.push_back(
3161 builder.createOrFold<tensor::DimOp>(loc, operand, index));
3162 }
3163
3164 reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
3165 loc,
3166 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
3167 builder.getIndexType()),
3168 shapeValues));
3169
3170 return success();
3171 }
3172
3173 //===----------------------------------------------------------------------===//
3174 // BroadcastInDimOp
3175 //===----------------------------------------------------------------------===//
3176
3177 LogicalResult BroadcastInDimOp::verify() {
3178 auto operandType = operand().getType().dyn_cast<RankedTensorType>();
3179 if (!operandType) {
3180 // The following verification checks all depend on knowing the rank of
3181 // the operand. Bail out now if we don't know the rank of the operand.
3182 return success();
3183 }
3184
3185 auto operandRank = operandType.getRank();
3186 if (!broadcast_dimensions()) {
3187 if (operandRank == 0) {
3188 return success();
3189 }
3190 return emitOpError(
3191 llvm::formatv("broadcast_dimensions is absent, but required because "
3192 "operand has non-zero rank ({0})",
3193 operandRank));
3194 }
3195
3196 auto dimensions = broadcast_dimensions();
3197 auto dimensionsType = broadcast_dimensions().getType();
3198 auto dimensionsRank = dimensionsType.getRank();
3199 if (dimensionsRank != 1) {
3200 return emitOpError(llvm::formatv(
3201 "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
3202 }
3203
3204 auto dimensionsSize = dimensionsType.getNumElements();
3205 if (dimensionsSize != operandRank) {
3206 return emitOpError(llvm::formatv(
3207 "broadcast_dimensions size ({0}) does not match operand rank ({1})",
3208 dimensionsSize, operandRank));
3209 }
3210
3211 auto resultType = getResult().getType().cast<RankedTensorType>();
3212 auto resultRank = resultType.getRank();
3213 if (resultRank < operandRank) {
3214 return emitOpError(
3215 llvm::formatv("result rank ({0}) is less than operand rank ({1})",
3216 resultRank, operandRank));
3217 }
3218
3219 for (int i = 0; i != dimensionsSize; ++i) {
3220 auto dimIndex = dimensions.getValues<int64_t>()[i];
3221 if (dimIndex >= resultRank) {
3222 return emitOpError(
3223 llvm::formatv("broadcast_dimensions contains invalid value {0} for "
3224 "result with rank {1}",
3225 dimIndex, resultRank));
3226 }
3227
3228 if (!operandType.isDynamicDim(i)) {
3229 auto dimSize = operandType.getDimSize(i);
3230 auto resultDimSize = resultType.getDimSize(dimIndex);
3231 if (dimSize != 1 && dimSize != resultDimSize) {
3232 return emitOpError(
3233 llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
3234 "1 or size of result dimension {2} ({3})",
3235 i, dimSize, dimIndex, resultDimSize));
3236 }
3237 }
3238 }
3239
3240 return success();
3241 }
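// E.g. (hand-written sketch): this broadcast_in_dim passes the checks above;
// operand dim 0 maps to result dim 1 with matching size:
//   %0 = "mhlo.broadcast_in_dim"(%arg)
//       {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
//       : (tensor<3xf32>) -> tensor<2x3xf32>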
3242
3243 OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
3244 auto type = getType().cast<RankedTensorType>();
3245 if (type == getOperand().getType()) {
3246 auto broadcastValues = broadcast_dimensions().getValues<int64_t>();
3247 if (!std::equal(broadcastValues.begin(), broadcastValues.end(),
3248 llvm::seq<int64_t>(0, type.getRank()).begin())) {
3249 return {};
3250 }
3251 return getOperand();
3252 }
3253
3254 // Constant fold when an operand is a splat tensor attribute.
3255 if (!attrs[0] || !type.hasStaticShape()) return {};
3256 auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
3257 if (!splatOperandAttr) return {};
3258
3259 // Handle complex type
3260 if (type.getElementType().isa<ComplexType>()) {
3261 ComplexType complex = type.getElementType().cast<ComplexType>();
3262 if (complex.getElementType().isa<FloatType>()) {
3263 return DenseElementsAttr::get(
3264 type, {splatOperandAttr.getSplatValue<std::complex<APFloat>>()});
3265 }
3266 if (complex.getElementType().isa<IntegerType>()) {
3267 return DenseElementsAttr::get(
3268 type, {splatOperandAttr.getSplatValue<std::complex<APInt>>()});
3269 }
3270 return {};
3271 }
3272
3273 return SplatElementsAttr::get(
3274 type, splatOperandAttr.getSplatValue<mlir::Attribute>());
3275 }
3276
3277 // Simplifies BroadcastInDim in the following ways: replaces BroadcastInDim
3278 // with Reshape or Transpose when they are equivalent, or replaces
3279 // BroadcastInDim(BroadcastInDim(X)) with a single BroadcastInDim(X).
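// E.g. (hand-written sketch): with equal ranks and element counts,
//   "mhlo.broadcast_in_dim"(%x)
//       {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>}
//       : (tensor<2x3xf32>) -> tensor<3x2xf32>
// is rewritten to "mhlo.transpose"(%x) with permutation dense<[1, 0]>.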
3280 class BroadcastInDimSimplifier : public OpRewritePattern<BroadcastInDimOp> {
3281 public:
3282 using OpRewritePattern<BroadcastInDimOp>::OpRewritePattern;
3283 LogicalResult matchAndRewrite(BroadcastInDimOp op,
3284 PatternRewriter& rewriter) const override {
3285 auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
3286 auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
3287 if (!operandType || !resultType) {
3288 return failure();
3289 }
3290 auto bsDimIndices = op.broadcast_dimensions().getValues<int64_t>();
3291 if (operandType.hasStaticShape() && resultType.hasStaticShape()) {
3292 bool sameTotalElements =
3293 operandType.getNumElements() == resultType.getNumElements();
3294 // BroadcastInDim equivalent to reshape
3295 if (llvm::is_sorted(bsDimIndices) && sameTotalElements) {
3296 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
3297 return success();
3298 }
3299 // BroadcastInDim equivalent to transpose
3300 if (operandType.getRank() == resultType.getRank() && sameTotalElements) {
3301 rewriter.replaceOpWithNewOp<TransposeOp>(op, op.getType(), op.operand(),
3302 op.broadcast_dimensions());
3303 return success();
3304 }
3305 }
3306 // eliminate redundant BroadcastInDim
3307 if (auto broadcastInDimOp = llvm::dyn_cast_or_null<BroadcastInDimOp>(
3308 op.operand().getDefiningOp())) {
3309 auto newIndices =
3310 broadcastInDimOp.broadcast_dimensions()
3311 .mapValues(op.broadcast_dimensions().getElementType(),
3312 [&bsDimIndices](const APInt& dim) -> APInt {
3313 return APInt(dim.getBitWidth(),
3314 bsDimIndices[dim.getSExtValue()], true);
3315 })
3316 .cast<DenseIntElementsAttr>();
3317 rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
3318 op, op.getType(), broadcastInDimOp.operand(), newIndices);
3319 return success();
3320 }
3321 return failure();
3322 }
3323 };
3324
3325 void BroadcastInDimOp::getCanonicalizationPatterns(RewritePatternSet& results,
3326 MLIRContext* context) {
3327 results.add<BroadcastInDimSimplifier>(context);
3328 }
3329
3330 //===----------------------------------------------------------------------===//
3331 // DynamicBroadcastInDimOp
3332 //===----------------------------------------------------------------------===//
3333
3334 LogicalResult DynamicBroadcastInDimOp::verify() {
3335 auto operandType = operand().getType().dyn_cast<RankedTensorType>();
3336 auto resultType = getResult().getType().dyn_cast<RankedTensorType>();
3337
3338 // If either the operand or result are unranked, there is very little
3339 // to verify statically.
3340 if (!operandType || !resultType) {
3341 return success();
3342 }
3343
3344 auto outputDimensionsType =
3345 output_dimensions().getType().cast<RankedTensorType>();
3346 auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
3347 auto operandRank = operandType.getRank();
3348 auto resultRank = resultType.getRank();
3349
3350 // Verify broadcast_dimensions.
3351 auto bcastDimensions = broadcast_dimensions();
3352 auto bcastDimensionsType = broadcast_dimensions().getType();
3353 auto bcastDimensionsRank = bcastDimensionsType.getRank();
3354 // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
3355 if (bcastDimensionsRank != 1) {
3356 return emitOpError(
3357 llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
3358 bcastDimensionsRank));
3359 }
3360
3361 auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
3362 if (bcastDimensionsSize != operandRank) {
3363 return emitOpError(llvm::formatv(
3364 "broadcast_dimensions size ({0}) does not match operand rank ({1})",
3365 bcastDimensionsSize, operandRank));
3366 }
3367
3368 if (resultRank < operandRank) {
3369 return emitOpError(
3370 llvm::formatv("result rank ({0}) is less than operand rank ({1})",
3371 resultRank, operandRank));
3372 }
3373
3374 for (int i = 0; i != bcastDimensionsSize; ++i) {
3375 auto dimIndex = bcastDimensions.getValues<int64_t>()[i];
3376 if (dimIndex >= resultRank) {
3377 return emitOpError(
3378 llvm::formatv("broadcast_dimensions contains invalid value {0} for "
3379 "result with rank {1}",
3380 dimIndex, resultRank));
3381 }
3382
3383 auto dimSize = operandType.getDimSize(i);
3384 auto resultDimSize = resultType.getDimSize(dimIndex);
3385 // Note: verifyCompatibleShape doesn't consider size-1 broadcasting, so we
3386 // add a manual check for this.
3387 if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
3388 return emitOpError(
3389 llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
3390 "with size of result dimension {2} ({3})",
3391 i, dimSize, dimIndex, resultDimSize));
3392 }
3393 }
3394
3395 if (outputDimensionsSize != resultRank) {
3396 return emitOpError(
3397 llvm::formatv("result rank ({0}) is not equal to number of output "
3398 "dimensions ({1})",
3399 resultRank, outputDimensionsSize));
3400 }
3401
3402 // Verify that the known expanding and non-expanding dimensions are a subset
3403 // of the operand's dimensions.
3404 int64_t numKnownExpansionBehavior = 0;
3405 DenseSet<int64_t> knownExpansionBehavior;
3406 auto collectExpansionBehaviorDims =
3407 [&](const Optional<DenseIntElementsAttr>& attr) {
3408 if (!attr) return;
3409 for (const APInt& it : *attr) {
3410 numKnownExpansionBehavior++;
3411 knownExpansionBehavior.insert(it.getLimitedValue());
3412 }
3413 };
3414 collectExpansionBehaviorDims(known_expanding_dimensions());
3415 collectExpansionBehaviorDims(known_nonexpanding_dimensions());
3416 if (knownExpansionBehavior.size() != numKnownExpansionBehavior) {
3417 return emitOpError(
3418 "duplicate expansion hint for at least one operand dimension");
3419 }
3420 for (int64_t i : knownExpansionBehavior) {
3421 if (i < 0 || i >= operandRank) {
3422 return emitOpError(
3423 llvm::formatv("hint for expanding dimension {0} does not refer to a "
3424 "valid operand dimension",
3425 i));
3426 }
3427 }
3428
3429 return success();
3430 }
3431
3432 namespace {
3433 // If a DynamicBroadcastInDimOp is not actually dynamic, use an ordinary
3434 // BroadcastInDimOp.
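// E.g. (hand-written sketch): with a static result type,
//   "mhlo.dynamic_broadcast_in_dim"(%x, %shape)
//       {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
//       : (tensor<4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
// can drop the shape operand and become a plain mhlo.broadcast_in_dim.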
3435 class DynamicBroadcastInDimOpNotActuallyDynamic
3436 : public OpRewritePattern<DynamicBroadcastInDimOp> {
3437 using OpRewritePattern::OpRewritePattern;
3438 LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
3439 PatternRewriter& rewriter) const override {
3440 auto type = op.getType().dyn_cast<RankedTensorType>();
3441 auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
3442 auto* outputDimOp = op.output_dimensions().getDefiningOp();
3443 if (!type || !operandType || !operandType.hasStaticShape()) {
3444 return rewriter.notifyMatchFailure(op, "requires operand static shape");
3445 }
3446 // output has static shape, replace with broadcast_in_dim
3447 if (type.hasStaticShape()) {
3448 rewriter.replaceOpWithNewOp<BroadcastInDimOp>(op, type, op.operand(),
3449 op.broadcast_dimensions());
3450 return success();
3451 }
3452 // output_dimensions are constant, set output shape with output_dimensions,
3453 // then replace with broadcast_in_dim
3454 if (outputDimOp && outputDimOp->hasTrait<mlir::OpTrait::ConstantLike>()) {
3455 DenseIntElementsAttr shapeAttr;
3456 if (matchPattern(outputDimOp, m_Constant(&shapeAttr))) {
3457 SmallVector<int64_t> outputShape;
3458 for (APInt shape : shapeAttr.getValues<APInt>()) {
3459 outputShape.push_back(shape.getZExtValue());
3460 }
3461 refineOpWithNewOp<BroadcastInDimOp>(
3462 rewriter, op,
3463 RankedTensorType::get(outputShape, type.getElementType()),
3464 op.operand(), op.broadcast_dimensions());
3465 return success();
3466 }
3467 }
3468 return rewriter.notifyMatchFailure(
3469 op, "requires output static shape or constant broadcast dimensions");
3470 }
3471 };
3472
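// Collapses a chain of two dynamic_broadcast_in_dim ops into one by composing
// their broadcast dimensions (a sketch of the intent: if the inner op maps
// operand dim i to dim d1[i] and the outer op maps dim j to dim d2[j], the
// fused op maps dim i to d2[d1[i]]), broadcasting the innermost operand
// directly to the outer result shape.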
3473 class ChainedDynamicBroadcastInDimCanonicalization
3474 : public OpRewritePattern<DynamicBroadcastInDimOp> {
3475 using OpRewritePattern::OpRewritePattern;
3476 LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
3477 PatternRewriter& rewriter) const override {
3478 auto precedingBcast =
3479 bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
3480 if (!precedingBcast) return failure();
3481
3482 // Compose broadcast dimensions.
3483 DenseIntElementsAttr precedingBcastDims =
3484 precedingBcast.broadcast_dimensions();
3485 DenseIntElementsAttr bcastDims = bcast.broadcast_dimensions();
3486 SmallVector<APInt, 4> composition;
3487 for (APInt precedingDim : precedingBcastDims) {
3488 composition.push_back(
3489 bcastDims.getValues<APInt>()[precedingDim.getZExtValue()]);
3490 }
3491 auto composedBcastDims =
3492 DenseIntElementsAttr::get(precedingBcastDims.getType(), composition);
3493
3494 rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
3495 bcast, bcast.getType(), precedingBcast.operand(),
3496 bcast.output_dimensions(), composedBcastDims);
3497 return success();
3498 }
3499 };
3500 } // namespace
3501
3502 void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
3503 RewritePatternSet& results, MLIRContext* context) {
3504 results.add<ChainedDynamicBroadcastInDimCanonicalization,
3505 DynamicBroadcastInDimOpNotActuallyDynamic,
3506 DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
3507 DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
3508 context);
3509 }
3510
3511 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
3512 OpBuilder& builder, ValueRange operands,
3513 SmallVectorImpl<Value>& reifiedReturnShapes) {
3514 DynamicBroadcastInDimOp::Adaptor adaptor(operands);
3515 reifiedReturnShapes.push_back(
3516 castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
3517 return success();
3518 }
3519
3520 //===----------------------------------------------------------------------===//
3521 // ClampOp
3522 //===----------------------------------------------------------------------===//
3523
3524 LogicalResult ClampOp::verify() {
3525 auto operandType = operand().getType().cast<RankedTensorType>();
3526 auto operandShape = operandType.getShape();
3527 auto minType = min().getType().cast<RankedTensorType>();
3528
3529 auto minShape = minType.getShape();
3530 if (failed(verifyCompatibleShape(minType, operandType)) &&
3531 minType.getRank() != 0) {
3532 return emitOpError(llvm::formatv(
3533 "min shape [{0}] is not scalar and is not compatible to operand shape "
3534 "[{1}]",
3535 llvm::make_range(minShape.begin(), minShape.end()),
3536 llvm::make_range(operandShape.begin(), operandShape.end())));
3537 }
3538
3539 auto maxType = max().getType().cast<RankedTensorType>();
3540 auto maxShape = maxType.getShape();
3541 if (failed(verifyCompatibleShape(maxType, operandType)) &&
3542 maxType.getRank() != 0) {
3543 return emitOpError(llvm::formatv(
3544 "max shape [{0}] is not scalar and is not compatible to operand shape "
3545 "[{1}]",
3546 llvm::make_range(maxShape.begin(), maxShape.end()),
3547 llvm::make_range(operandShape.begin(), operandShape.end())));
3548 }
3549
3550 return success();
3551 }
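// E.g. (hand-written sketch): 'min' and 'max' may be scalars or match the
// operand shape:
//   %0 = "mhlo.clamp"(%min, %x, %max)
//       : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>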
3552
3553 LogicalResult ClampOp::inferReturnTypeComponents(
3554 MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
3555 DictionaryAttr attributes, RegionRange regions,
3556 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
3557 ClampOp::Adaptor adaptor(operands, attributes, regions);
3558 RankedTensorType operandType =
3559 adaptor.operand().getType().cast<RankedTensorType>();
3560 inferredReturnShapes.emplace_back(operandType.getShape(),
3561 operandType.getElementType());
3562 return success();
3563 }
3564
3565 LogicalResult ClampOp::reifyReturnTypeShapes(
3566 OpBuilder& builder, ValueRange operands,
3567 SmallVectorImpl<Value>& reifiedReturnShapes) {
3568 // For `mhlo.clamp`, the first operand may be a scalar.
3569 return deriveShapeFromOperand(&builder, getOperation(), operands[1],
3570 &reifiedReturnShapes);
3571 }
3572
3573 //===----------------------------------------------------------------------===//
3574 // ComplexOp
3575 //===----------------------------------------------------------------------===//
3576
3577 LogicalResult ComplexOp::inferReturnTypes(
3578 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
3579 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3580 TensorType operandType = operands[0].getType().cast<TensorType>();
3581 ComplexType elementTy = ComplexType::get(operandType.getElementType());
3582 inferredReturnTypes.push_back(getSameShapeTensorType(operandType, elementTy));
3583 return success();
3584 }
3585
3586 OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
3587 auto realOp = getOperand(0).getDefiningOp<mhlo::RealOp>();
3588 auto imagOp = getOperand(1).getDefiningOp<mhlo::ImagOp>();
3589 if (realOp && imagOp && realOp.getOperand() == imagOp.getOperand()) {
3590 return realOp.getOperand();
3591 }
3592
3593 return {};
3594 }
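// Illustrative sketch of the fold above (hand-written): re-assembling the
// real and imaginary parts of the same value is the identity, e.g.
//   %re = "mhlo.real"(%z) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
//   %im = "mhlo.imag"(%z) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
//   %0 = "mhlo.complex"(%re, %im)
//       : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xcomplex<f32>>
// folds %0 to %z.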
3595
3596 //===----------------------------------------------------------------------===//
3597 // ImagOp
3598 //===----------------------------------------------------------------------===//
3599
3600 namespace {
3601 Type createRealType(TensorType type) {
3602 auto elementTy = type.getElementType();
3603 if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
3604 elementTy = complexTy.getElementType();
3605 }
3606 return getSameShapeTensorType(type, elementTy);
3607 }
3608 } // namespace
3609
3610 LogicalResult ImagOp::inferReturnTypes(
3611 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
3612 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3613 inferredReturnTypes.push_back(
3614 createRealType(operands[0].getType().cast<TensorType>()));
3615 return success();
3616 }
3617
3618 OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
3619 if (auto complexOp = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
3620 return complexOp.getOperand(1);
3621 }
3622
3623 return {};
3624 }
3625
3626 //===----------------------------------------------------------------------===//
3627 // IsFiniteOp
3628 //===----------------------------------------------------------------------===//
3629
3630 TensorType getSameShapeTensorType(TensorType tensorType, Type elementType) {
3631 if (auto rankedTensorTy = tensorType.dyn_cast<RankedTensorType>()) {
3632 return RankedTensorType::get(rankedTensorTy.getShape(), elementType,
3633 rankedTensorTy.getEncoding());
3634 }
3635 if (auto unrankedTensorTy = tensorType.dyn_cast<UnrankedTensorType>()) {
3636 return UnrankedTensorType::get(elementType);
3637 }
3638 llvm_unreachable("unhandled type");
3639 }
3640
3641 LogicalResult IsFiniteOp::inferReturnTypes(
3642 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
3643 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3644 auto argTy = operands.front().getType().cast<TensorType>();
3645 Builder b(ctx);
3646 inferredReturnTypes.push_back(getSameShapeTensorType(argTy, b.getI1Type()));
3647 return success();
3648 }
3649
3650 //===----------------------------------------------------------------------===//
3651 // RealOp
3652 //===----------------------------------------------------------------------===//
3653
3654 LogicalResult RealOp::inferReturnTypes(
3655 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
3656 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3657 inferredReturnTypes.push_back(
3658 createRealType(operands[0].getType().cast<TensorType>()));
3659 return success();
3660 }
3661
3662 OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
3663 if (auto complexOp = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
3664 return complexOp.getOperand(0);
3665 }
3666
3667 return {};
3668 }
3669
3670 //===----------------------------------------------------------------------===//
3671 // ConcatenateOp
3672 //===----------------------------------------------------------------------===//
3673
3674 namespace {
3675 class SingleOperandConcatenateToCast : public OpRewritePattern<ConcatenateOp> {
3676 public:
3677 using OpRewritePattern::OpRewritePattern;
3678 LogicalResult matchAndRewrite(ConcatenateOp op,
3679 PatternRewriter& rewriter) const override {
3680 if (op.val().size() != 1) return failure();
3681
3682 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
3683 op.val().front());
3684 return success();
3685 }
3686 };
3687
3688 class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
3689 public:
3690 using OpRewritePattern::OpRewritePattern;
3691 LogicalResult matchAndRewrite(ConcatenateOp op,
3692 PatternRewriter& rewriter) const override {
3693 auto axis = op.dimension();
3694 llvm::SmallVector<Value, 6> newOperands;
3695 for (auto operand : op.getOperands()) {
3696 auto ty = operand.getType().cast<ShapedType>();
3697 if (!ty.hasRank() || ty.getDimSize(axis) != 0) {
3698 newOperands.push_back(operand);
3699 }
3700 }
3701
3702 if (!newOperands.empty() && newOperands.size() < op.getNumOperands()) {
3703 rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
3704 newOperands, op.dimension());
3705 return success();
3706 }
3707
3708 return failure();
3709 }
3710 };
3711
3712 class ConcatenateForwarding : public OpRewritePattern<ConcatenateOp> {
3713 using OpRewritePattern::OpRewritePattern;
3714 LogicalResult matchAndRewrite(ConcatenateOp op,
3715 PatternRewriter& rewriter) const override {
3716 auto getFlattenedOperands = [&](const Value& val) -> ValueRange {
3717 auto definingOp = dyn_cast_or_null<ConcatenateOp>(val.getDefiningOp());
3718 // To avoid inflating the memory footprint, only flatten the ConcatenateOp
3719 // when it has only one use.
3720 if (definingOp && definingOp->hasOneUse() &&
3721 definingOp.dimension() == op.dimension())
3722 return definingOp.val();
3723 return val;
3724 };
3725
3726 bool needToFlatten = false;
3727 int operandCount = 0;
3728 llvm::for_each(op.val(), [&](Value val) {
3729 auto result = getFlattenedOperands(val);
3730 if (result.size() != 1 || result[0] != val) needToFlatten = true;
3731 operandCount += result.size();
3732 });
3733
3734 if (!needToFlatten) return failure();
3735
3736 llvm::SmallVector<Value, 6> newOperands;
3737 newOperands.reserve(operandCount);
3738
3739 for (auto operand : op.val()) {
3740 auto flattenedOperands = getFlattenedOperands(operand);
3741 newOperands.append(flattenedOperands.begin(), flattenedOperands.end());
3742 }
3743
3744 rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
3745 newOperands, op.dimension());
3746 return success();
3747 }
3748 };
3749
3750 } // namespace
3751
3752 LogicalResult ConcatenateOp::inferReturnTypes(
3753 MLIRContext*, Optional<Location> location, ValueRange operands,
3754 DictionaryAttr attributes, RegionRange regions,
3755 SmallVectorImpl<Type>& inferredReturnTypes) {
3756 if (operands.empty()) {
3757 return failure();
3758 }
3759
3760 auto dimensionAttr = attributes.get("dimension").cast<IntegerAttr>();
3761 auto dimension = dimensionAttr.getInt();
3762
3763 auto firstType = (*operands.begin()).getType().cast<ShapedType>();
3764 auto outElement = firstType.getElementType();
3765
3766 // Find the first ranked input to determine the output rank.
3767 for (auto type : operands.getTypes()) {
3768 auto shapedType = type.cast<ShapedType>();
3769 if (shapedType.hasRank()) {
3770 firstType = shapedType;
3771 break;
3772 }
3773 }
3774
3775 // If all inputs are unranked, the result must be unranked.
3776 if (!firstType.hasRank()) {
3777 inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
3778 return success();
3779 }
3780
3781 auto outShape = llvm::to_vector<6>(firstType.getShape());
3782
3783 // Determine what the non-concatenate dimensions should be.
3784 for (auto type : operands.getTypes()) {
3785 auto shapedTy = type.cast<ShapedType>();
3786 if (!shapedTy.hasRank()) {
3787 continue;
3788 }
3789
3790 for (const auto& it : llvm::enumerate(shapedTy.getShape())) {
3791 // If a dimension is not dynamic, the output shape should match.
3792 if (ShapedType::isDynamic(outShape[it.index()])) {
3793 outShape[it.index()] = it.value();
3794 }
3795 }
3796 }
3797
3798 outShape[dimension] = 0;
3799
3800 for (auto operand : operands.getTypes()) {
3801 auto type = operand.cast<ShapedType>();
3802 if (!type.hasRank()) {
3803 inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
3804 return success();
3805 }
3806
3807 // If the dimension is dynamic we know the output dimension is dynamic.
3808 auto dim = type.getShape()[dimension];
3809 if (ShapedType::isDynamic(dim)) {
3810 outShape[dimension] = ShapedType::kDynamicSize;
3811 break;
3812 }
3813
3814 outShape[dimension] += dim;
3815 }
3816
3817 inferredReturnTypes.push_back(RankedTensorType::get(outShape, outElement));
3818
3819 return success();
3820 }
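// Shape arithmetic sketch for the inference above (hand-written example):
// concatenating tensor<2x3xf32> and tensor<2x5xf32> along dimension 1 yields
// tensor<2x8xf32>; if any operand's size along the concat dimension is
// dynamic, the result's concat dimension is dynamic as well.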
3821
3822 void ConcatenateOp::getCanonicalizationPatterns(RewritePatternSet& results,
3823 MLIRContext* context) {
3824 results.add<ConcatenateOperandRemoval, ConcatenateForwarding,
3825 SingleOperandConcatenateToCast>(context);
3826 }
3827
3828 template <typename T>
3829 static Attribute foldConcatenateHelper(ConcatenateOp* op,
3830 ArrayRef<Attribute> operands) {
3831 auto axis = op->dimension();
3832 auto type = op->getType().cast<ShapedType>();
3833 auto shape = type.getShape();
3834
3835 size_t topSize = 1;
3836 for (int i = 0, e = axis; i < e; i++) {
3837 topSize = topSize * shape[i];
3838 }
3839
3840 // Prevent folding if the result is too large.
3841 if (type.getNumElements() > kFoldOpEltLimit) return {};
3842
3843 SmallVector<T, 6> values;
3844 for (size_t i = 0; i < topSize; i++) {
3845 for (auto operand : operands) {
3846 DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
3847 size_t bottomSize = attr.getNumElements() / topSize;
3848 auto iter = attr.getValues<T>().begin() + i * bottomSize;
3849 values.append(iter, iter + bottomSize);
3850 }
3851 }
3852
3853 return DenseElementsAttr::get(type, values);
3854 }
3855
3856 static Attribute foldConcatenate(ConcatenateOp* op,
3857 ArrayRef<Attribute> operands) {
3858 for (auto operand : operands) {
3859 if (!operand) return {};
3860 }
3861
3862 auto type = op->getResult().getType().cast<ShapedType>();
3863 auto etype = type.getElementType();
3864 if (etype.isa<IntegerType>()) {
3865 return foldConcatenateHelper<APInt>(op, operands);
3866 }
3867
3868 if (etype.isa<FloatType>()) {
3869 return foldConcatenateHelper<APFloat>(op, operands);
3870 }
3871
3872 return {};
3873 }
3874
3875 OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
3876 if (getNumOperands() == 1) return getOperand(0);
3877
3878 ShapedType type = getResult().getType().cast<ShapedType>();
3879 if (!type.hasStaticShape()) return {};
3880
3881 auto axis = dimension();
3882 if (auto attr = foldConcatenate(this, operands)) {
3883 return attr;
3884 }
3885
3886 for (auto operand : getOperands()) {
3887 auto ty = operand.getType().cast<ShapedType>();
3888 if (ty.getDimSize(axis) != 0) {
3889 return {};
3890 }
3891 }
3892
3893 return DenseElementsAttr::get(type, ArrayRef<Attribute>());
3894 }
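// Note on the fold above (a sketch, not from a test): a single-operand
// concatenate folds to that operand, all-constant operands fold to one
// constant, and if every operand is empty along the concat dimension the
// result folds to an empty constant of the static result type.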
3895
3896 LogicalResult ConcatenateOp::verify() {
3897 RankedTensorType firstRankedType;
3898 int firstRankedIndex;
3899 int numOperands = getNumOperands();
3900 int64_t concatDimension = static_cast<int64_t>(dimension());
3901 if (concatDimension < 0) {
3902 return emitOpError(
3903 llvm::formatv("dimension {0} is negative", concatDimension));
3904 }
3905 for (int i = 0; i < numOperands; i++) {
3906 auto secondType = getOperand(i).getType().dyn_cast<ShapedType>();
3907 if (!secondType.hasRank()) {
3908 continue;
3909 }
3910
3911 if (!firstRankedType) {
3912 firstRankedType = secondType.cast<RankedTensorType>();
3913 firstRankedIndex = i;
3914 if (firstRankedType.getRank() == 0)
3915 return emitOpError(
3916 llvm::formatv("rank-0 values cannot be concatenated"));
3917 if (concatDimension >= firstRankedType.getRank()) {
3918 return emitOpError(
3919 llvm::formatv("dimension {0} is out-of-bounds for input rank {1}",
3920 concatDimension, firstRankedType.getRank()));
3921 }
3922 continue;
3923 }
3924
3925 if (firstRankedType.getRank() != secondType.getRank()) {
3926 return emitOpError(llvm::formatv(
3927 "operands ({0}) and ({1}) do not match rank", firstRankedIndex, i));
3928 }
3929
3930 auto firstShape = firstRankedType.getShape();
3931 auto secondShape = secondType.getShape();
3932 for (int d = 0; d < firstRankedType.getRank(); ++d) {
3933 if (!ShapedType::isDynamic(firstShape[d]) &&
3934 !ShapedType::isDynamic(secondShape[d]) &&
3935 firstShape[d] != secondShape[d] && d != concatDimension) {
3936 return emitOpError(llvm::formatv(
3937 "shapes of operand ({0}) and ({1}) do not match at non-concat "
3938 "index: ({2}) != ({3}) at non-concat index {4}",
3939 firstRankedIndex, i,
3940 llvm::make_range(firstShape.begin(), firstShape.end()),
3941 llvm::make_range(secondShape.begin(), secondShape.end()), d));
3942 }
3943 }
3944 }
3945 return success();
3946 }
3947
3948 LogicalResult ConcatenateOp::reifyReturnTypeShapes(
3949 OpBuilder& builder, ValueRange operands,
3950 SmallVectorImpl<Value>& reifiedReturnShapes) {
3951 ConcatenateOp::Adaptor adaptor(operands);
3952 auto inputs = adaptor.val();
3953
3954 auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
3955 // Unranked types are not supported at the moment.
3956 if (!operandType) return failure();
3957
3958 Location loc = this->getLoc();
3959 Type shapeScalarType = builder.getIndexType();
3960 auto toShapeScalarType = [&](Value v) {
3961 return maybeCastTo(builder, loc, v, shapeScalarType);
3962 };
3963
3964 SmallVector<SmallVector<Value, 4>, 4> allShapeValues;
3965 for (size_t inputId = 0; inputId < inputs.size(); ++inputId) {
3966 Value operand = inputs[inputId];
3967 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3968 if (!operandType) return failure();
3969
3970 SmallVector<Value, 4> shapeVals;
3971 for (const auto& element : llvm::enumerate(operandType.getShape())) {
3972 Value valueDim = toShapeScalarType(
3973 builder.create<tensor::DimOp>(loc, operand, element.index()));
3974 shapeVals.push_back(valueDim);
3975 }
3976 allShapeValues.emplace_back(std::move(shapeVals));
3977 }
3978
3979 int axis = this->dimension();
3980 auto& shapeValues = allShapeValues[0];
3981 for (size_t vecId = 1; vecId < allShapeValues.size(); ++vecId) {
3982 auto& otherShapeValues = allShapeValues[vecId];
3983 if (otherShapeValues.size() != shapeValues.size()) {
3984 this->emitOpError()
3985 << "Concatenate expects all operands must be of the same rank";
3986 return failure();
3987 }
3988 shapeValues[axis] = builder.create<arith::AddIOp>(loc, shapeValues[axis],
3989 otherShapeValues[axis]);
3990 }
3991
3992 Value outputShape = builder.create<tensor::FromElementsOp>(
3993 loc,
3994 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
3995 shapeScalarType),
3996 shapeValues);
3997 reifiedReturnShapes.push_back(outputShape);
3998
3999 return success();
4000 }
4001
//===----------------------------------------------------------------------===//
// DynamicReshapeOp
//===----------------------------------------------------------------------===//

LogicalResult DynamicReshapeOp::verify() {
  auto resultType = result().getType().dyn_cast<RankedTensorType>();
  auto outputShapeType = output_shape().getType().dyn_cast<RankedTensorType>();
  if (resultType && outputShapeType && outputShapeType.hasStaticShape() &&
      outputShapeType.getDimSize(0) != resultType.getRank()) {
    return emitError() << "output should have a rank equal to the number of "
                          "elements in output_shape";
  }
  return success();
}

LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicReshapeOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
  return success();
}

namespace {
class DynamicReshapeOpNotActuallyDynamic
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "requires static shape tensor");
    }
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
    return success();
  }
};

// Canonicalizes
// %0 = some_op(%tensor)
// %1 = "mhlo.dynamic_reshape"(%0, %shape)
//      (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
// ... uses of %1.
//
// into
//
// ... uses of %0.
// This canonicalization is only correct if the input is correct!
// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
// errors in input, and still allows us to get rid of redundant reshapes.
class RemoveRedundantRank1DynamicReshape
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || type.getRank() != 1 || type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank-1 result with a dynamic dimension");
    }
    auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!operandType || operandType.getRank() != 1 ||
        operandType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank-1 operand with a dynamic dimension");
    }
    rewriter.replaceOp(op, {op.operand()});
    return success();
  }
};

// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
// ... uses of %2.
//
// into
//
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// ... uses of %1.
class DynamicReshapeOpSameShapeOpResult
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    Operation* defOp = op.operand().getDefiningOp();
    if (!defOp ||
        !defOp->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
      return failure();
    }
    Operation* inputDefOp = defOp->getOperand(0).getDefiningOp();
    if (!inputDefOp) {
      return failure();
    }
    auto reshape = dyn_cast<DynamicReshapeOp>(*inputDefOp);
    if (reshape && reshape.output_shape() == op.output_shape()) {
      rewriter.replaceOp(op, {defOp->getResult(0)});
      return success();
    }
    return failure();
  }
};
}  // namespace

void DynamicReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                   MLIRContext* context) {
  // clang-format off
  results.add<
      DynamicReshapeOpNotActuallyDynamic,
      DynamicReshapeOpSameShapeOpResult,
      RemoveRedundantDynamicBroadcast,
      RemoveRedundantDynamicReshape,
      RemoveRedundantRank1DynamicReshape,
      ShapeOfDynamicReshape
    >(context);
  // clang-format on
}

//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//

namespace {
// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
// This canonicalization is applied in the case when the `start_indices` input
// values are compile-time constants and thus can be made into a tensor.
struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
  using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicSliceOp dynamicSlice,
                                PatternRewriter& rewriter) const override {
    Value input = dynamicSlice.operand();
    auto inputTensor = input.getType().dyn_cast<RankedTensorType>();
    if (!inputTensor || !inputTensor.hasStaticShape()) return failure();

    auto sliceSizes = dynamicSlice.slice_sizes().getValues<int64_t>();
    SmallVector<int64_t, 4> tempStartIndices;
    for (const auto& indexAndSliceStart :
         llvm::enumerate(dynamicSlice.start_indices())) {
      APInt val;
      Value start = indexAndSliceStart.value();
      int64_t index = indexAndSliceStart.index();
      if (!matchPattern(start, m_ConstantInt(&val))) {
        return failure();
      }
      // Clamp the indices within bounds to faithfully mirror dynamic slice
      // semantics.
      int64_t clampedStart =
          clamp(val.getSExtValue(), static_cast<int64_t>(0),
                inputTensor.getDimSize(index) - sliceSizes[index]);
      tempStartIndices.push_back(clampedStart);
    }

    // At this point we've determined that the start indices are all constants;
    // pack them into a single tensor.
    auto loc = dynamicSlice.getLoc();
    int64_t inputRank = inputTensor.getRank();
    auto sliceStartIndices = rewriter.getI64TensorAttr(tempStartIndices);
    DenseIntElementsAttr sliceLimits = buildSliceLimits(
        sliceStartIndices, dynamicSlice.slice_sizes(), &rewriter);
    DenseIntElementsAttr sliceStrides =
        rewriter.getI64TensorAttr(SmallVector<int64_t, 4>(inputRank, 1));
    auto result = rewriter.create<SliceOp>(loc, input, sliceStartIndices,
                                           sliceLimits, sliceStrides);
    rewriter.replaceOp(dynamicSlice, {result});
    return success();
  }
};

}  // namespace

void DynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                 MLIRContext* context) {
  results.add<DynamicSliceToSlice>(context);
}

// Verifies that the number of slice sizes and the number of start indices
// match.
LogicalResult DynamicSliceOp::verify() {
  int numSliceSizes = slice_sizes().getNumElements();
  int numStartIndices = start_indices().size();
  if (numStartIndices != numSliceSizes) {
    return emitOpError() << "has mismatched number of slice sizes ("
                         << numSliceSizes << ") and number of start indices ("
                         << numStartIndices << ")";
  }
  auto operandType = operand().getType().dyn_cast<RankedTensorType>();
  if (!operandType) return failure();

  if (operandType.getRank() != numStartIndices) {
    return emitOpError() << "has mismatched number of start indices ("
                         << numStartIndices << ") and the rank of operand ("
                         << operandType.getRank() << ")";
  }

  for (int i = 0; i < numSliceSizes; ++i) {
    int64_t sliceSize = slice_sizes().getValues<int64_t>()[i];
    if (sliceSize < 0) {
      return emitOpError() << "has negative size index to dynamic slice: "
                           << sliceSize;
    }
    if (!operandType.isDynamicDim(i)) {
      int64_t dimSize = operandType.getDimSize(i);
      if (sliceSize > dimSize) {
        return emitOpError() << "has slice size " << sliceSize
                             << " greater than dimension size " << dimSize
                             << " in dimension " << i << " of operand";
      }
    }
  }
  return success();
}

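// The result shape of dynamic_slice is fully determined by the `slice_sizes`
// attribute, so the return type can be inferred from it together with the
// operand's element type.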
LogicalResult DynamicSliceOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  DynamicSliceOp::Adaptor adaptor(operands, attributes, regions);
  Value operand = adaptor.operand();
  auto operandType = operand.getType().dyn_cast<RankedTensorType>();
  if (!operandType) return failure();

  auto sliceSizes = adaptor.slice_sizes();
  Type elementTy = operandType.getElementType();
  inferredReturnShapes.emplace_back(sliceSizes.getValues<int64_t>(), elementTy);
  return success();
}

//===----------------------------------------------------------------------===//
// RealDynamicSliceOp
//===----------------------------------------------------------------------===//

// Verifies that the operand rank matches the sizes of start_indices,
// limit_indices, and strides.
LogicalResult RealDynamicSliceOp::verify() {
  auto inputType = operand().getType().dyn_cast<RankedTensorType>();
  // If the operand is unranked, there is very little to verify statically.
  if (!inputType) return success();
  int inputRank = inputType.getRank();

  auto startType = start_indices().getType().cast<RankedTensorType>();
  auto limitType = limit_indices().getType().cast<RankedTensorType>();
  auto stridesType = strides().getType().cast<RankedTensorType>();

  if (inputRank != startType.getNumElements()) {
    return emitOpError() << "has mismatched operand rank (" << inputRank
                         << ") and start_indices size ("
                         << startType.getNumElements() << ")";
  }

  if (inputRank != limitType.getNumElements()) {
    return emitOpError() << "has mismatched operand rank (" << inputRank
                         << ") and limit_indices size ("
                         << limitType.getNumElements() << ")";
  }

  if (inputRank != stridesType.getNumElements()) {
    return emitOpError() << "has mismatched operand rank (" << inputRank
                         << ") and strides size ("
                         << stridesType.getNumElements() << ")";
  }

  return success();
}

namespace {
// Canonicalizes RealDynamicSlice ops that can be replaced instead with Slice
// ops. This canonicalization applies when `start_indices`, `limit_indices`,
// and `strides` are compile-time constants and thus can be made into tensors.
struct RealDynamicSliceIsStatic : public OpRewritePattern<RealDynamicSliceOp> {
  using OpRewritePattern<RealDynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RealDynamicSliceOp realDynamicSlice,
                                PatternRewriter& rewriter) const override {
    Location loc = realDynamicSlice.getLoc();
    Value input = realDynamicSlice.operand();
    Value output = realDynamicSlice.result();
    auto inputTy = input.getType().dyn_cast<RankedTensorType>();
    auto outputTy = output.getType().dyn_cast<RankedTensorType>();

    if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
        !outputTy.hasStaticShape()) {
      return failure();
    }

    int64_t inputRank = inputTy.getRank();

    auto startVal = realDynamicSlice.start_indices();
    auto limitVal = realDynamicSlice.limit_indices();
    auto strideVal = realDynamicSlice.strides();
    auto startOp = startVal.getDefiningOp<mlir::arith::ConstantOp>();
    auto limitOp = limitVal.getDefiningOp<mlir::arith::ConstantOp>();
    auto strideOp = strideVal.getDefiningOp<mlir::arith::ConstantOp>();
    if (!startOp || !limitOp || !strideOp) return failure();

    auto startAttr =
        startOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    auto limitAttr =
        limitOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    auto strideAttr =
        strideOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    if (!startAttr || !limitAttr || !strideAttr) return failure();

    SmallVector<int64_t, 4> tempStartIndices;
    SmallVector<int64_t, 4> tempLimitIndices;
    SmallVector<int64_t, 4> tempStride;
    for (int64_t dimIdx = 0; dimIdx < inputRank; dimIdx++) {
      int64_t start = startAttr.getValues<IntegerAttr>()[dimIdx].getInt();
      tempStartIndices.push_back(start);
      int64_t limit = limitAttr.getValues<IntegerAttr>()[dimIdx].getInt();
      tempLimitIndices.push_back(limit);
      int64_t stride = strideAttr.getValues<IntegerAttr>()[dimIdx].getInt();
      tempStride.push_back(stride);
    }

    DenseIntElementsAttr sliceStartIndices =
        rewriter.getI64TensorAttr(tempStartIndices);
    DenseIntElementsAttr sliceLimitIndices =
        rewriter.getI64TensorAttr(tempLimitIndices);
    DenseIntElementsAttr sliceStrides = rewriter.getI64TensorAttr(tempStride);
    auto result = rewriter.create<SliceOp>(loc, input, sliceStartIndices,
                                           sliceLimitIndices, sliceStrides);
    rewriter.replaceOp(realDynamicSlice, {result});
    return success();
  }
};
}  // namespace

void RealDynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                     MLIRContext* context) {
  results.add<RealDynamicSliceIsStatic, RealDSliceToSlice>(context);
}

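// Reifies the result shape by computing, for each dimension of the operand,
//   size = (limit - start + stride - 1) / stride
// from the runtime start_indices/limit_indices/strides tensors.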
LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  RealDynamicSliceOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();
  Value startIndices = adaptor.start_indices();
  Value limitIndices = adaptor.limit_indices();
  Value strides = adaptor.strides();

  auto operandType = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
  if (!operandType) return failure();

  Location loc = this->getLoc();
  SmallVector<Value, 4> shapeValues;
  shapeValues.reserve(operandType.getRank());
  Type shapeScalarType =
      startIndices.getType().cast<ShapedType>().getElementType();
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  one = maybeCastTo(builder, loc, one, shapeScalarType);
  for (const auto& element : llvm::enumerate(operandType.getShape())) {
    Value offset = builder.create<arith::ConstantIndexOp>(loc, element.index());
    Value valueStart =
        builder.create<tensor::ExtractOp>(loc, startIndices, offset);
    Value valueLimit =
        builder.create<tensor::ExtractOp>(loc, limitIndices, offset);
    Value valueStride = builder.create<tensor::ExtractOp>(loc, strides, offset);
    // size = (limit - start + stride - 1) / stride
    shapeValues.push_back(builder.create<arith::DivSIOp>(
        loc,
        builder.create<arith::SubIOp>(
            loc,
            builder.create<arith::AddIOp>(
                loc, valueStride,
                builder.create<arith::SubIOp>(loc, valueLimit, valueStart)),
            one),
        valueStride));
  }

  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
      loc,
      RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
                            shapeScalarType),
      shapeValues));
  return success();
}

//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `zero_or_more_type(s),
// mhlo::token`.
LogicalResult InfeedOp::verify() {
  auto resultTypes = getResultTypes();
  if (resultTypes.empty())
    return emitOpError()
           << "result is expected to be at least of size 1, but got "
           << resultTypes.size();

  if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
    return emitOpError() << "last element of result types is expected to "
                            "be of token type, but got "
                         << resultTypes[resultTypes.size() - 1];

  // Verify the layout attribute.
  constexpr char kLayoutAttr[] = "layout";
  if (!getOperation()->hasAttr(kLayoutAttr)) return success();

  mlir::ArrayAttr layout =
      getOperation()->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
  if (!layout)
    return emitOpError() << "layout-attribute expected to be of array-type.";

  if (layout.size() != resultTypes.size() - 1) {
    return emitOpError() << "layout-attribute size must be "
                         << resultTypes.size() - 1
                         << " (which is the number of "
                            "op-results - 1 (for token result)), but got "
                         << layout.size();
  }

  for (auto childLayout : layout) {
    mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast<mlir::ArrayAttr>();
    if (!childLayoutArr) {
      return emitOpError() << "layout-attribute expected to have "
                              "elements of type array, but got "
                           << childLayout;
    }

    for (auto i : childLayoutArr) {
      mlir::IntegerAttr attr = i.dyn_cast<mlir::IntegerAttr>();
      if (!attr) {
        return emitOpError() << "layout-attribute's leaf elements are "
                                "expected to be of type integer, but got "
                             << i;
      }
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

LogicalResult MapOp::verify() {
  // Checks that the number of `operands` matches the arity of the map
  // `computation` region.
  auto& computationBlock = computation().front();
  auto computationArgs = computationBlock.getArguments();
  if (operands().size() != computationArgs.size())
    return emitOpError() << "expects number of operands to match the arity "
                            "of map computation, but got: "
                         << operands().size() << " and "
                         << computationArgs.size();

  // The parameters of the computation should all be scalars and match the
  // element type of the operands.
  for (const auto& indexedArg : llvm::enumerate(computationArgs)) {
    auto argType = indexedArg.value().getType().dyn_cast<TensorType>();
    if (!argType || argType.getRank() != 0)
      return emitOpError()
             << "computation arguments must be 0-rank tensor, but got: arg #"
             << indexedArg.index() << " of type "
             << indexedArg.value().getType();
    auto operandElemTy = operands()[indexedArg.index()]
                             .getType()
                             .cast<TensorType>()
                             .getElementType();
    if (argType.getElementType() != operandElemTy) {
      return emitOpError()
             << "element type of operands and computation arguments must "
                "match, but got: "
             << operandElemTy << " and " << argType.getElementType();
    }
  }

  // The mapped computation must return a single output.
  auto computationOutputs = computationBlock.getTerminator()->getOperands();
  if (computationOutputs.size() != 1)
    return emitOpError() << "computation must return single output, but got: "
                         << computationOutputs.size();

  // The output of the computation must be scalar and have the same element
  // type as the op result.
  auto computationOutputType =
      computationOutputs[0].getType().dyn_cast<TensorType>();
  if (!computationOutputType || computationOutputType.getRank() != 0)
    return emitOpError() << "computation must return 0-rank tensor, but got: "
                         << computationOutputs[0].getType();

  auto resultType = getType().cast<TensorType>();
  if (computationOutputType.getElementType() != resultType.getElementType())
    return emitOpError() << "element type of result and computation output "
                            "must match, but got: "
                         << resultType.getElementType() << " and "
                         << computationOutputType.getElementType();

  // Checks that the requested map dimension numbers are monotonically
  // increasing.
  DenseIntElementsAttr dimensions = this->dimensions();
  for (const auto& indexedValue :
       llvm::enumerate(dimensions.getValues<int64_t>())) {
    if (indexedValue.value() != static_cast<int64_t>(indexedValue.index()))
      return emitOpError() << "requires monotonically increasing dimension "
                              "numbers, but got: "
                           << dimensions;
  }

  // Checks that the number of dimensions of the operands matches the size of
  // `dimensions`, since we currently only support mapping across all
  // dimensions: i.e., scalar map functions.
  auto operandType = operands()[0].getType().cast<TensorType>();
  if (operandType.hasRank()) {
    if (dimensions.size() !=
        static_cast<int64_t>(operandType.getShape().size()))
      return emitOpError()
             << "applied to a subset of dimensions currently not supported: "
                "operand dimensions = "
             << operandType.getShape().size()
             << ", requested map dimensions size = " << dimensions.size();
  }

  return success();
}

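// Folds a map whose computation simply returns one of its block arguments
// (i.e. an identity map) to the corresponding map operand.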
OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) {
  mlir::Block& bb = computation().front();
  mlir::Operation& frontOp = bb.front();

  auto retOp = mlir::dyn_cast<ReturnOp>(frontOp);
  if (!retOp) return nullptr;
  if (retOp.results().size() != 1) return nullptr;

  for (mlir::BlockArgument barg : bb.getArguments()) {
    if (barg == retOp.results()[0]) return getOperands()[barg.getArgNumber()];
  }
  return nullptr;
}

LogicalResult MapOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return deriveShapeFromOperand(&builder, getOperation(), operands.front(),
                                &reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `zero_or_more_type(s),
// mhlo::token`.
LogicalResult RecvOp::verify() {
  auto resultTypes = getResultTypes();
  if (resultTypes.empty())
    return emitOpError()
           << "result is expected to be at least of size 1, but got "
           << resultTypes.size();
  if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
    return emitOpError() << "last element of result types is expected to "
                            "be of token type, but got "
                         << resultTypes[resultTypes.size() - 1];
  return success();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }

//===----------------------------------------------------------------------===//
// ReduceWindowOp
//===----------------------------------------------------------------------===//

namespace {
// Infer the return-type of ReduceWindowOp.
SmallVector<TensorType> inferReduceWindowOpReturnType(
    ArrayRef<TensorType> inputTypes, ArrayRef<TensorType> initTypes,
    const ArrayRef<WindowDimension> window) {
  SmallVector<TensorType> outputTypes;
  for (size_t i = 0; i < inputTypes.size(); ++i) {
    if (!inputTypes[i].hasRank()) {
      outputTypes.push_back(
          UnrankedTensorType::get(initTypes[i].getElementType()));
      continue;
    }

    outputTypes.push_back(RankedTensorType::get(
        inferWindowOutputShape(inputTypes[i].getShape(), window),
        initTypes[i].getElementType()));
  }

  return outputTypes;
}
}  // namespace

// We intend to verify the following properties:
// P1. The sizes of 'inputs' and 'init_values' must be at least 1.
// P2. All `inputs` need to have compatible shapes.
// P3. size-of(window_dimension) == rank-of(input),
//     where input is an element of 'inputs'.
// P4. Verify and collect the window attributes.
// P5. Verify the inner block defining the reducer function.
// P6. Verify the return type.
LogicalResult ReduceWindowOp::verify() {
  // P1.
  // Note that the ODS ensures that there is an even number of operands; check
  // that this number is not zero.
  if (getOperands().empty())
    return emitOpError() << "expects the size of operands to be >= 2.";

  // Collect the input and init-value operands. Note that the operand-type is
  // enforced as "TensorType" by ODS.
  int64_t numInputs = getNumOperands() / 2;
  auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
      getOperandTypes(),
      [](Type t) -> TensorType { return t.cast<TensorType>(); }));
  ArrayRef<TensorType> inputTypes(operandTensorTypes.begin(),
                                  operandTensorTypes.begin() + numInputs);
  ArrayRef<TensorType> initTypes(operandTensorTypes.begin() + numInputs,
                                 operandTensorTypes.end());

  // P2.
  if (failed(verifyCompatibleShapes(operands().getTypes())))
    return emitOpError() << "requires same shape for all inputs";

  // P3.
  SmallVector<int64_t> windowDims =
      convertDenseIntAttr(this->window_dimensions());
  for (const auto inputType : inputTypes) {
    if (!inputType.hasRank()) continue;
    if (inputType.getRank() != static_cast<int64_t>(windowDims.size()))
      return emitOpError()
             << "expects window-dimensions size == input rank, but got "
                "window-dimensions size: "
             << windowDims.size() << " and input: " << inputType
             << " with rank = " << inputType.getRank() << ".";
  }

  // P4.
  auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
  if (failed(paddingOrErr)) return failure();
  SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;

  auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
      windowDims, convertDenseIntAttr(window_strides()), padding,
      /*lhs_dilation=*/convertDenseIntAttr(base_dilations()),
      /*rhs_dilation=*/convertDenseIntAttr(this->window_dilations()), getLoc());
  if (failed(windowOrErr)) return failure();

  // P5.
  bool allInputsUnranked =
      llvm::all_of(inputTypes, [](TensorType t) { return !t.hasRank(); });

  Block& block = body().front();
  SmallVector<TensorType> accumulatorSubshapes;
  if (failed(verifyReducerShape(this->getLoc(), block, inputTypes, initTypes,
                                numInputs, windowDims, allInputsUnranked,
                                accumulatorSubshapes)))
    return failure();

  // P6.
  if (numInputs != getNumResults())
    return emitOpError() << "expects " << numInputs
                         << " result values, but got " << getNumResults()
                         << ".";

  // The result-type is enforced as "TensorType" by ODS.
  auto resultTensorTypes = llvm::to_vector<4>(llvm::map_range(
      getResultTypes(),
      [](Type t) -> TensorType { return t.cast<TensorType>(); }));

  // Check if the element-type of results match with the ones derived from
  // the reducer-block. Already ensured that |accumulator_subshapes| ==
  // num_inputs == num_of_results.
  for (int64_t shapeIdx = 0;
       shapeIdx < static_cast<int64_t>(accumulatorSubshapes.size());
       shapeIdx++) {
    if (accumulatorSubshapes[shapeIdx].getElementType() !=
        resultTensorTypes[shapeIdx].getElementType()) {
      return emitError()
             << "expects the element-type of reduce-op's return-value at index "
             << shapeIdx
             << " to match the element-type of reducer-block's "
                "corresponding return-value, but got "
             << resultTensorTypes[shapeIdx].getElementType() << " and "
             << accumulatorSubshapes[shapeIdx].getElementType() << " resp.";
    }
  }

  // Check if the shape of results match with the ones derived from
  // the input-types and window-attributes.
  auto inferredReturnTypes = inferReduceWindowOpReturnType(
      inputTypes, accumulatorSubshapes, *windowOrErr);

  for (size_t i = 0; i < getNumResults(); i++) {
    if (failed(verifyCompatibleShape(resultTensorTypes[i],
                                     inferredReturnTypes[i]))) {
      return emitOpError()
             << "expects result at index " << i
             << " to have compatible shape with the corresponding "
                "inferred type, but got "
             << resultTensorTypes[i] << " and " << inferredReturnTypes[i]
             << " resp.";
    }
  }

  return success();
}

// Gets the operation used for the reduction applied to the `result_index`th
// result. It is expected to be a binary operation that consumes the
// `result_index`th and `result_index + operands().size()`th arguments of the
// body.
Operation* ReduceWindowOp::getReductionOp(int resultIndex) {
  auto returnOp = cast<ReturnOp>(body().front().getTerminator());
  Operation* computeOp = returnOp.results()[resultIndex].getDefiningOp();
  if (computeOp->getNumOperands() != 2) return nullptr;
  auto arg0 = computeOp->getOperand(0).dyn_cast<BlockArgument>();
  auto arg1 = computeOp->getOperand(1).dyn_cast<BlockArgument>();
  if (!arg0 || !arg1) return nullptr;
  int64_t arg0Num = arg0.getArgNumber();
  int64_t arg1Num = arg1.getArgNumber();
  int64_t otherArgIndex = resultIndex + operands().size();
  if (arg0Num == resultIndex && arg1Num == otherArgIndex) return computeOp;
  if (arg0Num == otherArgIndex && arg1Num == resultIndex &&
      computeOp->hasTrait<mlir::OpTrait::IsCommutative>())
    return computeOp;
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ReducePrecisionOp
//===----------------------------------------------------------------------===//

// The following properties are already enforced by the ODS:
// P0. operand element type is float
// P1. mantissa_bits >= 0
// We intend to verify the following property:
// P2. exponent_bits >= 1
LogicalResult ReducePrecisionOp::verify() {
  if (exponent_bits() < 1) {
    return emitOpError() << "exponent_bits must be at least 1.";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//

template <typename T>
static Attribute foldReverseHelper(DenseElementsAttr& attr, ShapedType& type,
                                   DenseIntElementsAttr& dims) {
  int64_t numElements = attr.getNumElements();
  // No-op if the tensor has 0 elements.
  // No-op if the result of folding is too large.
  if (numElements == 0 || numElements > kFoldOpEltLimit) return {};

  SmallVector<T> result(attr.getValues<T>().begin(), attr.getValues<T>().end());

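  // stride[d] is the number of elements covered by one "block" at depth d,
  // i.e. the product of the dimension sizes d..rank-1; stride[0] is the total
  // element count and stride[rank] is 1.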
  size_t rank = type.getRank();
  SmallVector<int64_t> stride(rank + 1, numElements);
  for (size_t i = 0; i < rank; i++) {
    if (type.getDimSize(i) == 0) return {};
    stride[i + 1] = stride[i] / type.getDimSize(i);
  }

  for (auto dim : dims.getValues<int64_t>()) {
    // For example, given:
    // * tensor: tensor<2x3x2xi32>
    //   [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
    // * dim: [1]
    //
    // We're going to reverse the tensor with respect to dim as follows:
    // 1) Split the tensor into blocks, i.e. smaller tensors whose type is
    //    derived from the tensor by dropping the first `dim` dimensions, i.e.
    //    tensor<3x2xi32> for the running example.
    // 2) Split each block into windows, i.e. even smaller tensors whose type
    //    is derived from the block by dropping the first dimension of the
    //    block, i.e. tensor<2xi32> for the running example.
    // 3) Within each block, swap windows but don't change the order of
    //    elements within the windows: the 0th window goes to the (N-1)st
    //    spot, the 1st window goes to the (N-2)nd spot, etc.
    //
    // For the running example, the result will be:
    //   [[[5, 6], [3, 4], [1, 2]], [[11, 12], [9, 10], [7, 8]]].
    //
    // Note how elements within windows haven't changed their order with
    // respect to each other and how blocks haven't changed their order with
    // respect to each other.
    int64_t numWindows = type.getDimSize(dim);
    int64_t windowSize = stride[dim] / numWindows;

    for (int64_t index = 0; index < numElements; index++) {
      int64_t blockNumber = index / stride[dim];
      int64_t windowNumber = (index % stride[dim]) / windowSize;
      int64_t reversedWindowNumber = numWindows - windowNumber - 1;
      if (windowNumber >= reversedWindowNumber) continue;
      int64_t reversedIndex = blockNumber * stride[dim] +
                              reversedWindowNumber * windowSize +
                              index % windowSize;
      std::swap(result[index], result[reversedIndex]);
    }
  }
  return DenseElementsAttr::get(type, result);
}

OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
  Value input = operand();

  // No dimensions to reverse.
  DenseIntElementsAttr dims = dimensions();
  if (dims.getNumElements() == 0) return input;

  // If the size of every dimension to reverse equals 1, then the reverse is a
  // no-op. E.g., reversing dimensions {0,1} of a 1x1x2 tensor.
  auto shapedType = input.getType().cast<ShapedType>();
  if (llvm::all_of(dims.getValues<int64_t>(), [&](int64_t dim) {
        return shapedType.getDimSize(dim) == 1;
      }))
    return input;

  // If the operand is a statically shaped tensor of constants, return the
  // reversed tensor.
  DenseElementsAttr inputAttr =
      operands.begin()->dyn_cast_or_null<DenseElementsAttr>();
  if (inputAttr && shapedType.hasStaticShape()) {
    auto etype = shapedType.getElementType();
    if (etype.isa<IntegerType>())
      return foldReverseHelper<APInt>(inputAttr, shapedType, dims);
    if (etype.isa<FloatType>())
      return foldReverseHelper<APFloat>(inputAttr, shapedType, dims);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

// Returns the result type after reducing an operand of the given type across
// the specified dimensions.
static TensorType getReduceResultType(Type operandTy,
                                      DenseIntElementsAttr dimensions,
                                      Builder* builder) {
  Type elementTy = getElementTypeOrSelf(operandTy);

  auto rankedTy = operandTy.dyn_cast<RankedTensorType>();
  if (!rankedTy) return UnrankedTensorType::get(elementTy);

  int64_t rank = rankedTy.getRank();
  llvm::SmallVector<bool, 4> dimsMask(rank, false);
  for (int64_t dim : dimensions.getValues<int64_t>()) dimsMask[dim] = true;

  SmallVector<int64_t, 4> shape;
  for (int64_t i = 0; i < rank; ++i) {
    if (!dimsMask[i]) shape.push_back(rankedTy.getDimSize(i));
  }

  return RankedTensorType::get(shape, elementTy);
}

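// Infers each result type from the corresponding operand type and the
// `dimensions` attribute, then delegates to the auto-generated builder.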
void ReduceOp::build(OpBuilder& builder, OperationState& state,
                     ValueRange operands, ValueRange initValues,
                     DenseIntElementsAttr dimensions) {
  SmallVector<Type, 1> resultTy;
  resultTy.reserve(operands.size());

  for (Value operand : operands) {
    resultTy.push_back(
        getReduceResultType(operand.getType(), dimensions, &builder));
  }
  build(builder, state, resultTy, operands, initValues, dimensions);
}

LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult>& results) {
  // No dimensions to reduce.
  if (dimensions().getNumElements() == 0) {
    for (Value operand : this->operands()) {
      results.push_back(operand);
    }
    return success();
  }

  // If all values returned from the ReduceOp region are defined outside the
  // region, replace the ReduceOp with those values.
  mlir::Block& bb = this->body().front();
  SmallVector<Value> replacedResults;
  if (auto retOp = mlir::dyn_cast<ReturnOp>(bb.back())) {
    for (Value result : retOp.results()) {
      if (result.getParentRegion() == retOp->getParentRegion())
        return failure();
      replacedResults.push_back(result);
    }

    results.insert(results.end(), replacedResults.begin(),
                   replacedResults.end());
    return success();
  }

  return failure();
}

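// Returns true if all operands and all results of `op` share one common type.
// Used below when checking eligibility for the compact printing of
// mhlo.reduce.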
bool hasSameOperandAndResultTypes(Operation& op) {
  Type expected;
  if (op.getNumResults() != 0) expected = op.getResult(0).getType();
  if (op.getNumOperands() != 0) expected = op.getOperand(0).getType();
  if (!expected) return false;

  auto typeMatch = [&](Type actual) { return actual == expected; };
  return llvm::all_of(op.getOperandTypes(), typeMatch) &&
         llvm::all_of(op.getResultTypes(), typeMatch);
}

// Checks the following eligibility criteria for compact printing of
// mhlo.reduce:
// E1. The reduce-op wraps a single inner-op in the associated region.
// E2. The single operation is a commutative binary-op from the mhlo dialect,
//     has zero regions, and produces a single result such that the operands
//     and result all have the same type.
// E3. The reduce-op consists of at least one input-operand; the operand-types
//     of the inner-op should be derived trivially from the element-type of
//     reduce-op's first input-operand.
// E4. The arguments of the region's only basic block are forwarded perfectly
//     to inner-op's operands.
// E5. The reduce-op, inner-op, block arguments, and the return-op all have
//     the same location.
// E6. The single operation result is perfectly forwarded to the reduce op
//     return.
static bool isEligibleForCompactPrint(ReduceOp op) {
  // Check E1.
  auto& block = op.body().front();
  if (!hasSingleElement(block.without_terminator())) return false;

  Operation& innerOp = *block.begin();

  // Check E2.
  if (innerOp.getDialect() != op->getDialect()) return false;

  if (innerOp.getNumOperands() != 2 ||
      !innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
      !hasSameOperandAndResultTypes(innerOp) ||
      !innerOp.hasTrait<mlir::OpTrait::IsCommutative>() ||
      !innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
    return false;

  // Check E3.
  if (op.operands().empty()) return false;

  auto elemType =
      op.operands()[0].getType().cast<TensorType>().getElementType();
  auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
  if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false;

  // Check E4.
  if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false;

  // Check E5.
  auto retOp = dyn_cast<ReturnOp>(block.getTerminator());
  if (!retOp) return false;

  auto blockArgLoc = block.getArgument(0).getLoc();
  if (blockArgLoc != block.getArgument(1).getLoc()) return false;

  if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() ||
      blockArgLoc != op.getLoc())
    return false;

  // Check E6.
  return llvm::equal(innerOp.getResults(), retOp.getOperands());
}

void ReduceOp::print(OpAsmPrinter& p) {
  {
    // Print the pairs of operands under the form:
    //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
    StringRef comma = "";
    int numOperandPairs = getNumOperands() / 2;
    for (int opId : llvm::seq<int>(0, numOperandPairs)) {
      p << comma << "(" << getOperand(opId)
        << " init: " << getOperand(opId + numOperandPairs) << ")";
      comma = ", ";
    }
  }

  // If the reduce-op is eligible for compact printing, we emit the one-liner:
  //   mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
  // Note: We are not printing the function type of the reduction operation.
  // We have some simplifying assumptions (refer to
  // isEligibleForCompactPrint::E3) to derive the type from that of reduce-op.
  if (isEligibleForCompactPrint(*this)) {
    Operation& innerOp = body().front().front();
    p << " applies ";
    printEscapedString(innerOp.getName().getStringRef(), p.getStream());

    p << " across dimensions = [";
    llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
    p << "]";
    p << " : ";
    p.printFunctionalType(*this);
  } else {
    p << " across dimensions = [";
    llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
    p << "]";
    p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"});
    p << " : ";
    p.printFunctionalType(*this);
    p.printNewline();
    p << " reducer";
    {
      // Print the pairs of block operands under the form:
      //   (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc):
      Block& reducer = body().front();
      int numOperandPairs = getNumOperands() / 2;
      for (int opId : llvm::seq<int>(0, numOperandPairs)) {
        p << "(";
        p.printRegionArgument(reducer.getArgument(opId));
        p << ", ";
        p.printRegionArgument(reducer.getArgument(opId + numOperandPairs));
        p << ") ";
      }
    }
    p << ' ';
    p.printRegion(body(), /*printEntryBlockArgs=*/false);
  }
}

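// For reference, the compact form printed above (and accepted by the parser
// below) looks like the following, where the tensor types are illustrative:
//   %0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.add across
//        dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>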
ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Location currLocation = parser.getEncodedSourceLoc(loc);

  // Parse the operands of reduce-op; this is a list of pairs of the form:
  //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
  // Each input to reduce is paired with its init value, even though in memory
  // they are stored with the inputs first and the init values after.
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> initOperands;
  do {
    (void)parser.parseOptionalComma();
    if (parser.parseOptionalLParen()) break;
    OpAsmParser::UnresolvedOperand operand, initOperand;
    if (parser.parseOperand(operand) || parser.parseKeyword("init") ||
        parser.parseColon() || parser.parseOperand(initOperand) ||
        parser.parseRParen())
      return failure();
    operands.push_back(operand);
    initOperands.push_back(initOperand);
  } while (true);
  operands.append(initOperands);

  // Check if we are parsing the compact version of reduce-op:
  //   mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
  // else parse the "region-based" variant.
  if (failed(parser.parseOptionalKeyword("applies"))) {
    // Parse the inner-op dimensions, reduce-op's function-type and
    // optional location.
    SmallVector<int64_t> dimensions;
    auto parseDim = [&]() -> ParseResult {
      if (parser.parseInteger(dimensions.emplace_back())) return failure();
      return success();
    };

    FunctionType reduceOpFntype;
    if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
        parser.parseEqual() ||
        parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                       parseDim) ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColon() || parser.parseType(reduceOpFntype) ||
        parser.parseKeyword("reducer"))
      return failure();
    OpBuilder builder(parser.getBuilder().getContext());
    result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));

    // Parse the "reducer" region now.
    SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerOperands;
    SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerInitOperands;
    SmallVector<Type, 2> reducerTypes;
    SmallVector<Type, 2> reducerInitTypes;
    SmallVector<Optional<Location>, 2> reducerLocs;
    SmallVector<Optional<Location>, 2> reducerInitLocs;
    auto parseBlockOperand =
        [&](SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands,
            SmallVectorImpl<Type>& types,
            SmallVectorImpl<Optional<Location>>& locs) -> ParseResult {
      OpAsmParser::UnresolvedOperand operand;
      Type type;
      Optional<Location> loc;
      if (parser.parseOperand(operand, /*allowResultNumber=*/false) ||
          parser.parseColon() || parser.parseType(type) ||
          parser.parseOptionalLocationSpecifier(loc))
        return failure();
      operands.push_back(operand);
      types.push_back(type);
      locs.push_back(loc);
      return success();
    };
    do {
      if (failed(parser.parseOptionalLParen())) break;
      if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) ||
          parser.parseComma() ||
          parseBlockOperand(reducerInitOperands, reducerInitTypes,
                            reducerInitLocs) ||
          parser.parseRParen())
        return failure();
    } while (true);
    reducerOperands.append(reducerInitOperands);
    reducerTypes.append(reducerInitTypes);
    reducerLocs.append(reducerInitLocs);
    result.addTypes(reduceOpFntype.getResults());
    SmallVector<OpAsmParser::Argument> reducerArgs;
    createArgs(reducerOperands, reducerTypes, reducerArgs);

    // Derive the SSA-values for reduce-op's operands and parse the region, and
    // the optional trailing location.
    Optional<Location> trailingLoc;
    if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
                               result.operands) ||
        parser.parseRegion(*result.addRegion(), reducerArgs))
      return failure();
    // Set the individual block arguments.
    for (auto argAndLoc :
         llvm::zip(result.regions.front()->front().getArguments(), reducerLocs))
      if (std::get<1>(argAndLoc))
        std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value());
    result.location = trailingLoc.value_or(currLocation);
    return success();
  }

  // Parse the inner-op name and check that the contract on the inner-op
  // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met.
  FailureOr<OperationName> innerOpNameInfo = parser.parseCustomOperationName();
  if (failed(innerOpNameInfo)) return failure();

  StringRef innerOpName = innerOpNameInfo->getStringRef();
  Dialect* innerOpDialect = innerOpNameInfo->getDialect();
  if (!innerOpDialect || !innerOpDialect->getNamespace().equals("mhlo") ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::NOperands<2>::Impl>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::OneResult>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::IsCommutative>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::ZeroRegions>()) {
    parser.emitError(loc,
                     "expected the inner-op to be a commutative binary-op from "
                     "the mhlo dialect, with zero regions, producing a single "
                     "result");
    return failure();
  }

  // Parse the inner-op dimensions, reduce-op's function-type and
  // optional location.
  SmallVector<int64_t> dimensions;
  auto parseDim = [&]() -> ParseResult {
    if (parser.parseInteger(dimensions.emplace_back())) return failure();
    return success();
  };

  Optional<Location> explicitLoc;
  FunctionType reduceOpFntype;
  if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
      parser.parseEqual() ||
      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
      parser.parseColon() || parser.parseType(reduceOpFntype) ||
      parser.parseOptionalLocationSpecifier(explicitLoc))
    return failure();

  if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) {
    if (!reduceOpFntype) return parser.emitError(loc, "expected function type");
    return parser.emitError(loc,
                            "input types missing in reduce-op function type");
  }

  // If the location of reduce-op is explicitly provided, then use it; else use
  // the parser's current location.
  Location reduceOpLoc = explicitLoc.value_or(currLocation);

  // Derive the SSA-values for reduce-op's operands.
  if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
                             result.operands))
    return failure();

  // Derive the type of inner-op from that of reduce-op's input operand.
  auto innerOpType = RankedTensorType::get(
      /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0)));

  // Add a region for reduce-op.
  Region& region = *result.addRegion();

  // Create a basic-block inside reduce-op's region.
  Block& block = region.emplaceBlock();
  auto lhs = block.addArgument(innerOpType, reduceOpLoc);
  auto rhs = block.addArgument(innerOpType, reduceOpLoc);

  // Create and insert an "inner-op" operation in the block.
  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToStart(&block);

  OperationState innerOpState(reduceOpLoc, innerOpName);
  innerOpState.operands.push_back(lhs);
  innerOpState.operands.push_back(rhs);
  innerOpState.addTypes(innerOpType);

  Operation* innerOp = builder.create(innerOpState);

  // Insert a return statement in the block returning the inner-op's result.
  builder.create<ReturnOp>(innerOp->getLoc(), innerOp->getResults());

  // Populate the reduce-op operation-state with result-type, location, and
  // dimension attribute.
  result.addTypes(reduceOpFntype.getResults());
  result.location = innerOp->getLoc();
  result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));

  return success();
}

LogicalResult ReduceOp::verify() {
  // Check that there is an even number of operands and >= 2.
  if (getNumOperands() % 2 != 0 || getOperands().empty())
    return emitOpError() << "expects the size of operands to be even and >= 2";

  // Collect the input and init-value operands. Note that the operand-type is
  // enforced as "TensorType" by ODS.
  int64_t numInputs = getNumOperands() / 2;
  auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
      getOperandTypes(),
      [](Type t) -> TensorType { return t.cast<TensorType>(); }));
  ArrayRef<TensorType> inputArgTypes(operandTensorTypes.begin(),
                                     operandTensorTypes.begin() + numInputs);
  ArrayRef<TensorType> initValueTypes(operandTensorTypes.begin() + numInputs,
                                      operandTensorTypes.end());

  // Check for unranked tensors in input operands.
  int64_t rankedInputIdx = -1;
  for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    if (inputArgTypes[inputIdx].hasRank()) {
      rankedInputIdx = inputIdx;
      break;
    }
  }

  bool allInputsUnranked = (rankedInputIdx == -1);

  // Check that all input operands have compatible shapes. The element types
  // may be different.
  if (!allInputsUnranked) {
    for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
      if (failed(mlir::verifyCompatibleShape(inputArgTypes[rankedInputIdx],
                                             inputArgTypes[inputIdx]))) {
        return emitOpError()
               << "expects all inputs to have compatible shapes. Shape at"
               << " input-index " << inputIdx
               << " is not compatible with shape at input-index "
               << rankedInputIdx;
      }
    }
  }

  // Check that
  //   1. the dimensions of reduce-op are in-bounds for the given shape.
  //   2. the dimension-attribute has no duplicate entries.
  DenseSet<int64_t> dimensionsToReduceSet;
  for (int64_t dimension : dimensions().getValues<int64_t>()) {
    if ((!allInputsUnranked &&
         dimension >= inputArgTypes[rankedInputIdx].getRank()) ||
        dimension < 0) {
      return emitError() << "Out-of-bounds dimension " << dimension
                         << " for input-tensor rank: "
                         << inputArgTypes[rankedInputIdx].getRank();
    }

    if (!dimensionsToReduceSet.insert(dimension).second) {
      return emitError() << "Duplicate reduction dimension: " << dimension;
    }
  }

  // Compute the expected result dimensions, i.e. all input dimensions that
  // are not being reduced, and verify the inner block defining the reducer
  // function.
  SmallVector<int64_t> newDimensions;
  if (!allInputsUnranked) {
    for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank();
         ++inputIdx) {
      if (!dimensionsToReduceSet.count(inputIdx)) {
        newDimensions.push_back(
            inputArgTypes[rankedInputIdx].getDimSize(inputIdx));
      }
    }
  }

  Block& block = body().front();
  SmallVector<TensorType> accumulatorSubShapes;
  if (failed(verifyReducerShape(this->getLoc(), block, inputArgTypes,
                                initValueTypes, numInputs, newDimensions,
                                allInputsUnranked, accumulatorSubShapes)))
    return failure();

  // Check if the reduce-op's result-type matches with the one derived from
  // the reducer-block and dimensions attribute.
  if (getResults().size() != accumulatorSubShapes.size())
    return emitError() << "Unexpected number of reduce-op's returned values: "
                       << getResults().size() << " vs "
                       << accumulatorSubShapes.size() << " (expected)";

  for (int64_t shapeIdx = 0;
       shapeIdx < static_cast<int64_t>(accumulatorSubShapes.size());
       shapeIdx++) {
    // The result-type is enforced as "TensorType" by ODS.
    auto opResultType = getResult(shapeIdx).getType().cast<TensorType>();

    // Check element-type.
    if (accumulatorSubShapes[shapeIdx].getElementType() !=
        opResultType.getElementType()) {
      return emitError()
             << "Unexpected element-type for reduce-op's return value at index "
             << shapeIdx << ": " << opResultType.getElementType() << " vs "
             << accumulatorSubShapes[shapeIdx].getElementType()
             << " (expected)";
    }

    // Check shape.
    if (!allInputsUnranked && opResultType.hasRank() &&
        failed(verifyCompatibleShape(newDimensions, opResultType.getShape()))) {
      Type expectedResultType = RankedTensorType::get(
          newDimensions, accumulatorSubShapes[shapeIdx].getElementType());
      return emitError()
             << "Unexpected type for reduce-op's return value at index "
             << shapeIdx << ": " << opResultType << " vs " << expectedResultType
             << " (expected)";
    }
  }

  return success();
}

5336 // Enable constant folding to occur within the region of the ReduceOp
5337 // by replacing block argument uses with constants if:
5338 // 1. All the ReduceOp operands are splat constants.
5339 // 2. The ReduceOp region consists of a single logical AND or logical OR.
5340 // The pattern leverages the idempotent property of the AND and OR operators
5341 // to determine the value of a reduction on splat constants. Other boolean
5342 // operators do not have this property, and need separate patterns to resolve
5343 // reductions of their splat constants.
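// For example (illustrative IR, not taken from a test), a reduction whose
// operands are splat booleans and whose region is a single mhlo.and:
//
//   %input = mhlo.constant dense<true> : tensor<4xi1>
//   %init = mhlo.constant dense<true> : tensor<i1>
//   %0 = mhlo.reduce(%input init: %init) across dimensions = [0]
//       : (tensor<4xi1>, tensor<i1>) -> tensor<i1>
//    reducer(%arg0: tensor<i1>, %arg1: tensor<i1>)  {
//     %1 = mhlo.and %arg0, %arg1 : tensor<i1>
//     mhlo.return %1 : tensor<i1>
//    }
//
// After the rewrite, uses of %arg0/%arg1 inside the region refer to splat
// constants instead, so the region body can subsequently constant-fold.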
5344 struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> {
5345 using OpRewritePattern<ReduceOp>::OpRewritePattern;
5346
  LogicalResult matchAndRewrite(ReduceOp op,
5348 PatternRewriter& rewriter) const override {
5349 mlir::Block& bb = op.body().front();
5350
5351 // Ensure only a compute op and return op exist and the
5352 // compute op is an AND or OR op.
5353 if (bb.getOperations().size() != 2) return failure();
5354 if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure();
5355
5356 // Ensure all operands are splat constants.
5357 SmallVector<DenseElementsAttr, 4> bargCstAttrs;
5358 for (auto inpAndBarg : llvm::zip(op.getOperands(), bb.getArguments())) {
5359 Value inp = std::get<0>(inpAndBarg);
5360 BlockArgument barg = std::get<1>(inpAndBarg);
5361 ConstantOp cst = inp.getDefiningOp<ConstantOp>();
5362 if (!cst) return failure();
5363
      auto cstAttr = cst.value().dyn_cast_or_null<DenseElementsAttr>();
      if (!cstAttr || !cstAttr.isSplat()) {
        return rewriter.notifyMatchFailure(op, "Must be splat constant.");
      }
5368
5369 auto bargShapedType = barg.getType().dyn_cast<ShapedType>();
5370 if (!bargShapedType) return failure();
5371
5372 auto bargCstAttr = DenseElementsAttr::get(
5373 bargShapedType, cstAttr.getSplatValue<mlir::Attribute>());
5374 bargCstAttrs.push_back(bargCstAttr);
5375 }
5376
5377 // Create new splat constants to replace block arguments.
5378 for (BlockArgument barg : bb.getArguments()) {
5379 int argIdx = barg.getArgNumber();
5380 mhlo::ConstantOp newCst = rewriter.create<mhlo::ConstantOp>(
5381 bb.front().getLoc(), barg.getType(), bargCstAttrs[argIdx]);
5382 barg.replaceAllUsesWith(newCst);
5383 }
5384 return success();
5385 }
5386 };
5387
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet& results,
5389 MLIRContext* context) {
5390 results.add<LowerBoolSplatConstantsIntoRegion>(context);
5391 }
5392
LogicalResult ReduceOp::reifyReturnTypeShapes(
5394 OpBuilder& builder, ValueRange operands,
5395 SmallVectorImpl<Value>& reifiedReturnShapes) {
5396 ReduceOp::Adaptor adaptor(operands);
5397 auto inputs = adaptor.operands();
5398
5399 auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
5401 if (!operandType) return failure();
5402
5403 Location loc = this->getLoc();
5404 SmallVector<Value, 4> shapeValues;
5405 SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
5406 shapeValues.reserve(operandType.getRank());
5407 Type shapeScalarType = builder.getIndexType();
5408 auto toShapeScalarType = [&](Value v) {
5409 return maybeCastTo(builder, loc, v, shapeScalarType);
5410 };
5411
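  // The reified shape keeps exactly the non-reduced dimensions, in order.
  // For example (illustrative): reducing a tensor<4x?x6xf32> input over
  // dimensions = [0] reifies the two trailing dimension sizes.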
5412 for (const auto& element : llvm::enumerate(operandType.getShape())) {
5413 int64_t idx = element.index();
5414 auto* it = std::find(dimensions.begin(), dimensions.end(), idx);
5415 if (it != dimensions.end()) {
5416 continue;
5417 }
5418 Value valueDim = toShapeScalarType(
5419 builder.create<tensor::DimOp>(loc, inputs[0], element.index()));
5420 shapeValues.push_back(valueDim);
5421 }
5422
5423 Value outputShape = builder.create<tensor::FromElementsOp>(
5424 loc,
5425 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
5426 shapeScalarType),
5427 shapeValues);
5428 for (size_t i = 0; i < inputs.size(); ++i) {
5429 reifiedReturnShapes.push_back(outputShape);
5430 }
5431
5432 return success();
5433 }
5434
5435 //===----------------------------------------------------------------------===//
5436 // RngBitGeneratorOp
5437 //===----------------------------------------------------------------------===//
5438
// Verify that the input state has the same shape as the output state.
LogicalResult RngBitGeneratorOp::verify() {
  auto initialShape = initial_state().getType().dyn_cast<RankedTensorType>();
  auto outputShape = output_state().getType().dyn_cast<RankedTensorType>();
  // If either state is unranked, there is nothing to verify statically.
  if (!initialShape || !outputShape) return success();
  if (initialShape.getShape() != outputShape.getShape())
5444 return emitOpError()
5445 << "output state shape must match initial state shape. Got: "
5446 << initialShape << " and " << outputShape;
5447 return success();
5448 }
5449
5450 //===----------------------------------------------------------------------===//
5451 // RngOp
5452 //===----------------------------------------------------------------------===//
5453
LogicalResult RngOp::verify() {
5455 auto dist = rng_distribution();
5456 if (dist == RngDistribution::UNIFORM) {
5457 return success();
5458 }
5459 auto muTy = a().getType().cast<TensorType>().getElementType();
5460 auto sigmaTy = b().getType().cast<TensorType>().getElementType();
5461 if (muTy.isa<FloatType>() && sigmaTy.isa<FloatType>()) {
5462 return success();
5463 }
5464 return emitOpError() << "mu and sigma must be floats";
5465 }
5466
LogicalResult RngOp::inferReturnTypeComponents(
5468 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
5469 DictionaryAttr attributes, RegionRange regions,
5470 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5471 return rngInferReturnTypeComponents(context, location, operands, attributes,
5472 regions, inferredReturnShapes);
5473 }
5474
LogicalResult RngOp::reifyReturnTypeShapes(
5476 OpBuilder& builder, ValueRange operands,
5477 SmallVectorImpl<Value>& reifiedReturnShapes) {
5478 RngOp::Adaptor adaptor(operands);
5479 reifiedReturnShapes.push_back(
5480 castToIndexTensor(builder, getLoc(), adaptor.shape()));
5481 return success();
5482 }
5483
5484 //===----------------------------------------------------------------------===//
5485 // XlaRngGetAndUpdateStateOp
5486 //===----------------------------------------------------------------------===//
5487
LogicalResult XlaRngGetAndUpdateStateOp::verify() {
5489 auto resultTy = getType().cast<RankedTensorType>();
5490 if (!resultTy) return emitOpError() << "Output is not ranked.";
5491 if (!resultTy.hasStaticShape())
5492 return emitOpError() << "Output is not statically shaped.";
5493 auto rank = resultTy.getRank();
5494 if (rank != 1)
5495 return emitOpError() << "Output is of rank " << rank << " instead of 1";
5496 auto extent = resultTy.getDimSize(0);
5497 if (extent != 2)
5498 return emitOpError() << "Output size is " << extent << " instead of 2";
5499
5500 return success();
5501 }
5502
LogicalResult XlaRngGetAndUpdateStateOp::inferReturnTypes(
5504 MLIRContext* ctx, Optional<Location>, ValueRange, DictionaryAttr,
5505 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
5506 inferredReturnTypes.push_back(mlir::RankedTensorType::get(
5507 {2}, mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned)));
5508 return success();
5509 }
5510
5511 //===----------------------------------------------------------------------===//
5512 // SelectOp
5513 //===----------------------------------------------------------------------===//
5514
LogicalResult SelectOp::verify() {
5516 // The operands 'on_true' and 'on_false' should have compatible types, i.e.,
5517 // (a) have the same element type, and
5518 // (b) have compatible shapes (i.e. the same shape and/or at least one
5519 // dynamic shape)
5520 if (!compatibleShapeAndElementType(on_true().getType(), on_false().getType()))
5521 return emitOpError()
5522 << "requires compatible types for non-predicate operands";
5523
5524 // The predicate, if not-scalar, should have the same shape as the remaining
5525 // operands.
5526 auto predTy = pred().getType().dyn_cast<RankedTensorType>();
5527 bool predMayBeScalar = !predTy || predTy.getRank() == 0;
5528 if (predMayBeScalar) return success();
5529
5530 if (failed(verifyCompatibleShape(pred().getType(), on_true().getType())))
5531 return emitOpError() << "requires the same shape for all operands";
5532
5533 return success();
5534 }
5535
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
5537 if (on_true() == on_false()) {
5538 return on_true();
5539 }
5540
5541 auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
5542 if (!predicate) {
5543 return {};
5544 }
5545
5546 auto predicateTy = predicate.getType().cast<ShapedType>();
5547 if (!predicateTy.getElementType().isInteger(1)) {
5548 return {};
5549 }
5550
5551 if (predicate.isSplat()) {
5552 return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
5553 : on_false();
5554 }
5555
5556 return {};
5557 }
5558
// Simplify select(not(%pred), true_value, false_value) =>
// select(%pred, false_value, true_value).
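// For example (illustrative IR):
//   %not = mhlo.not %pred : tensor<4xi1>
//   %sel = mhlo.select %not, %a, %b : tensor<4xi1>, tensor<4xf32>
// canonicalizes to
//   %sel = mhlo.select %pred, %b, %a : tensor<4xi1>, tensor<4xf32>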
static LogicalResult selectCanonicalization(SelectOp selectOp,
5562 PatternRewriter& rewriter) {
5563 auto notOp = selectOp.pred().getDefiningOp<NotOp>();
5564 if (!notOp) {
5565 return failure();
5566 }
5567 std::array<Value, 3> newOperands = {notOp.operand(), selectOp.on_false(),
5568 selectOp.on_true()};
5569 rewriter.updateRootInPlace(
5570 selectOp, [&]() { selectOp.getOperation()->setOperands(newOperands); });
5571 return success();
5572 }
5573
void SelectOp::getCanonicalizationPatterns(RewritePatternSet& results,
5575 MLIRContext* /*context*/) {
5576 results.add(&selectCanonicalization);
5577 }
5578
5579 // Makes it such that a SelectOp that is a non-root operation in a DRR infers
5580 // the return type based on operand type.
LogicalResult SelectOp::inferReturnTypeComponents(
5582 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
5583 DictionaryAttr attributes, RegionRange,
5584 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5585 SelectOp::Adaptor op(operands, attributes);
5586 auto trueType = op.on_true().getType().cast<TensorType>();
5587 auto falseType = op.on_false().getType().cast<TensorType>();
5588
5589 // The output shape should be the most general of the operand shapes at each
5590 // dimension.
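  // For example (illustrative): tensor<2x?xf32> vs tensor<2x3xf32> infers
  // tensor<2x?xf32>, while tensor<2x3xf32> vs tensor<4x3xf32> infers
  // tensor<?x3xf32>.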
5591 ShapedTypeComponents& outputType = inferredReturnShapes.emplace_back();
5592 if (trueType == falseType || !trueType.hasRank()) {
5593 outputType = ShapedTypeComponents(trueType.cast<ShapedType>());
5594 } else if (!falseType.hasRank()) {
5595 outputType = ShapedTypeComponents(falseType.cast<ShapedType>());
5596 } else {
5597 assert(trueType.getRank() == falseType.getRank());
5598 llvm::SmallVector<int64_t, 4> dims;
5599 dims.reserve(trueType.getRank());
5600 for (auto dim : llvm::zip(trueType.getShape(), falseType.getShape())) {
5601 dims.push_back(std::get<0>(dim) == std::get<1>(dim)
5602 ? std::get<0>(dim)
5603 : ShapedType::kDynamicSize);
5604 }
5605 outputType = ShapedTypeComponents(dims, trueType.getElementType());
5606 }
5607 return success();
5608 }
5609
LogicalResult SelectOp::reifyReturnTypeShapes(
5611 OpBuilder& builder, ValueRange operands,
5612 SmallVectorImpl<Value>& reifiedReturnShapes) {
5613 // For `hlo.select`, the first operand may be a scalar.
5614 return deriveShapeFromOperand(&builder, getOperation(), operands[1],
5615 &reifiedReturnShapes);
5616 }
5617
5618 //===----------------------------------------------------------------------===//
5619 // SetDimensionSizeOp
5620 //===----------------------------------------------------------------------===//
5621
LogicalResult SetDimensionSizeOp::verify() {
5623 if (auto size = this->size().getType().dyn_cast<RankedTensorType>()) {
5624 if (size.getRank() != 0)
      return emitOpError() << "size operand should be of rank 0";
5626 }
5627
5628 return verifyDimAttr(*this);
5629 }
5630
OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
5632 DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
5633 if (input) return input;
5634
5635 DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
5636 if (!size || !size.isSplat()) return {};
5637
5638 auto ty = getType().dyn_cast<RankedTensorType>();
5639 if (!ty) return {};
5640
5641 int64_t dimSize = ty.getDimSize(dimension());
5642 if (dimSize == size.getSplatValue<IntegerAttr>().getInt()) return operand();
5643 return {};
5644 }
5645
5646 // TODO(b/238903565): Switch to inferReturnTypeComponents after adding support
5647 // for the encoding upstream.
LogicalResult SetDimensionSizeOp::inferReturnTypes(
5649 MLIRContext* context, Optional<Location> location, ValueRange operands,
5650 DictionaryAttr attributes, RegionRange regions,
5651 SmallVectorImpl<Type>& inferredReturnTypes) {
5652 Location loc = location.value_or(UnknownLoc::get(context));
5653
5654 SetDimensionSizeOp::Adaptor adaptor(operands, attributes, regions);
5655 if (failed(adaptor.verify(loc))) return failure();
5656
5657 auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
5658 if (!inputType) {
5659 inferredReturnTypes.push_back(adaptor.operand().getType());
5660 return success();
5661 }
5662
5663 int64_t dim = adaptor.dimension();
5664 int64_t rank = inputType.getRank();
5665 if (dim < 0 || dim >= rank) {
5666 return mlir::emitError(loc) << "expects dimension to be in range [0, "
5667 << rank << "); got: [" << dim << "].";
5668 }
5669
5670 auto shape = llvm::to_vector<4>(inputType.getShape());
5671 llvm::SmallVector<int64_t, 4> bounds(rank, ShapedType::kDynamicSize);
5672 if (auto encoding =
5673 inputType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>())
5674 bounds = llvm::to_vector<4>(encoding.getBounds());
5675
5676 // TODO(hinsu): Handle the case when the size operand is a constant.
5677 if (shape[dim] != ShapedType::kDynamicSize) bounds[dim] = shape[dim];
5678 shape[dim] = ShapedType::kDynamicSize;
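  // For example (illustrative): a tensor<4x2xf32> operand with dimension = 0
  // infers tensor<?x2xf32> with an encoding that records the bound 4 for
  // dimension 0.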
5679
5680 auto extensions = TypeExtensionsAttr::get(context, bounds);
5681 auto resultType =
5682 RankedTensorType::get(shape, inputType.getElementType(), extensions);
5683 inferredReturnTypes.push_back(resultType);
5684 return success();
5685 }
5686
5687 //===----------------------------------------------------------------------===//
5688 // PadOp
5689 //===----------------------------------------------------------------------===//
5690
LogicalResult PadOp::inferReturnTypeComponents(
5692 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
5693 DictionaryAttr attributes, RegionRange regions,
5694 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5695 PadOp::Adaptor adaptor(operands, attributes, regions);
5696 auto inputType = adaptor.operand().getType().cast<RankedTensorType>();
5697 auto padType = adaptor.padding_value().getType().cast<RankedTensorType>();
5698
5699 if (padType.getRank() != 0) {
5700 return emitOptionalError(
5701 location, llvm::formatv("padding value type should be a rank-0 "
5702 "tensor, is rank {0}",
5703 padType.getRank()));
5704 }
5705
5706 const auto& paddingLow = adaptor.edge_padding_low();
5707 if (paddingLow.getType().getNumElements() != inputType.getRank()) {
5708 return emitOptionalError(
5709 location,
5710 llvm::formatv(
5711 "edge_padding_low length ({0}) must match operand rank ({1})",
5712 paddingLow.getType().getNumElements(), inputType.getRank()));
5713 }
5714
5715 const auto& paddingHigh = adaptor.edge_padding_high();
5716 if (paddingHigh.getType().getNumElements() != inputType.getRank()) {
5717 return emitOptionalError(
5718 location,
5719 llvm::formatv(
5720 "edge_padding_high length ({0}) must match operand rank ({1})",
5721 paddingHigh.getType().getNumElements(), inputType.getRank()));
5722 }
5723
5724 const auto& paddingInterior = adaptor.interior_padding();
5725 if (paddingInterior.getType().getNumElements() != inputType.getRank()) {
5726 return emitOptionalError(
5727 location,
5728 llvm::formatv(
5729 "interior_padding length ({0}) must match operand rank ({1})",
5730 paddingInterior.getType().getNumElements(), inputType.getRank()));
5731 }
5732
5733 auto inputShape = inputType.getShape();
5734 SmallVector<int64_t> resultShape;
5735 for (int i = 0, e = inputShape.size(); i < e; i++) {
5736 if (isDynamicDimSize(inputShape[i])) {
5737 resultShape.push_back(ShapedType::kDynamicSize);
5738 continue;
5739 }
5740
5741 int64_t paddingLowVal = paddingLow.getValues<APInt>()[i].getSExtValue();
5742 int64_t paddingHighVal = paddingHigh.getValues<APInt>()[i].getSExtValue();
5743 int64_t paddingInteriorVal =
5744 paddingInterior.getValues<APInt>()[i].getSExtValue();
5745 if (paddingInteriorVal < 0) {
5746 return emitOptionalError(
5747 location, llvm::formatv("Interior padding cannot be negative: {0}",
5748 paddingInteriorVal));
5749 }
5750 int64_t expectedOutput =
5751 inputShape[i] + paddingLowVal + paddingHighVal +
5752 std::max<int64_t>(inputShape[i] - 1, 0LL) * paddingInteriorVal;
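    // For example (illustrative): inputShape[i] = 3, low = 1, high = 2,
    // interior = 2 gives 3 + 1 + 2 + max(3 - 1, 0) * 2 = 10.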
5753 if (expectedOutput < 0) {
5754 return emitOptionalError(
5755 location,
          llvm::formatv("Padding results in a negative size for dimension {0}",
                        i));
5758 }
5759 resultShape.push_back(expectedOutput);
5760 }
5761 inferredReturnShapes.emplace_back(resultShape, inputType.getElementType());
5762
5763 return success();
5764 }
5765
5766 template <typename T>
OpFoldResult padOpFoldHelper(DenseElementsAttr input, DenseElementsAttr padding,
5768 RankedTensorType returnType,
5769 DenseIntElementsAttr edgePaddingLow,
5770 DenseIntElementsAttr /*edgePaddingHigh*/,
5771 DenseIntElementsAttr interiorPadding) {
5772 // Prevent folding if the result is too large.
5773 if (returnType.getNumElements() > kFoldOpEltLimit) return {};
5774
5775 // Fill the full result tensor with the padding value.
5776 llvm::SmallVector<T, 4> result(returnType.getNumElements(),
5777 padding.getValues<T>()[0]);
5778
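  // Advance a multi-dimensional index in row-major order, wrapping trailing
  // dimensions like an odometer; e.g. for shape [2, 3], index [0, 2]
  // advances to [1, 0].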
5779 auto nextIndex = [](llvm::SmallVector<uint64_t, 8>& index,
5780 llvm::ArrayRef<int64_t> shape) {
5781 for (int64_t i = index.size() - 1; i >= 0; --i) {
5782 ++index[i];
5783 if (static_cast<int64_t>(index[i]) < shape[i]) return;
5784 index[i] = 0;
5785 }
5786 };
5787
  // Iterate over all elements of the input tensor and copy each one to its
  // correct location in the output tensor.
5790 llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
5791 uint64_t numElements = input.getNumElements();
5792 for (uint64_t operandIdx = 0; operandIdx < numElements; operandIdx++) {
5793 uint64_t resultIdx = 0;
    uint64_t idxMultiplier = 1;
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      resultIdx += (edgePaddingLow.getValues<int64_t>()[i] +
                    index[i] * (interiorPadding.getValues<int64_t>()[i] + 1)) *
                   idxMultiplier;
      idxMultiplier *= returnType.getDimSize(i);
5800 }
5801 result[resultIdx] = input.getValues<T>()[index];
5802 nextIndex(index, input.getType().getShape());
5803 }
5804 return DenseElementsAttr::get(returnType, result);
5805 }
5806
OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
5808 // If all padding is zero then it is an identity pad.
5809 auto isZero = [](const APInt& i) { return i == 0; };
5810 if (llvm::all_of(edge_padding_low().getValues<APInt>(), isZero) &&
5811 llvm::all_of(edge_padding_high().getValues<APInt>(), isZero) &&
5812 llvm::all_of(interior_padding().getValues<APInt>(), isZero))
5813 return operand();
5814
5815 // If any padding is negative then it isn't supported by the folder (yet).
5816 auto isNegative = [](const APInt& i) { return i.slt(0); };
5817 if (llvm::any_of(edge_padding_low().getValues<APInt>(), isNegative) ||
5818 llvm::any_of(edge_padding_high().getValues<APInt>(), isNegative) ||
5819 llvm::any_of(interior_padding().getValues<APInt>(), isNegative))
5820 return {};
5821
5822 DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
5823 DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
5824 RankedTensorType returnType = getType().dyn_cast_or_null<RankedTensorType>();
5825 if (!input || !input.getType().hasRank() || !padding || !returnType ||
5826 !returnType.hasStaticShape())
5827 return {};
5828
5829 if (returnType.getElementType().isa<IntegerType>())
5830 return padOpFoldHelper<APInt>(input, padding, returnType,
5831 edge_padding_low(), edge_padding_high(),
5832 interior_padding());
5833 if (returnType.getElementType().isa<FloatType>())
5834 return padOpFoldHelper<APFloat>(input, padding, returnType,
5835 edge_padding_low(), edge_padding_high(),
5836 interior_padding());
5837 if (ComplexType complex =
5838 returnType.getElementType().dyn_cast_or_null<ComplexType>()) {
5839 // TODO(atondwal): Allow int types in HLO_complex
5840 if (complex.getElementType().isa<FloatType>())
5841 return padOpFoldHelper<std::complex<APFloat>>(
5842 input, padding, returnType, edge_padding_low(), edge_padding_high(),
5843 interior_padding());
5844 }
5845 return {};
5846 }
5847
LogicalResult PadOp::reifyReturnTypeShapes(
5849 OpBuilder& builder, ValueRange operands,
5850 SmallVectorImpl<Value>& reifiedReturnShapes) {
5851 PadOp::Adaptor adaptor(operands, this->getOperation()->getAttrDictionary());
5852 auto loc = this->getLoc();
5853 Value operand = adaptor.operand();
5854 auto operandTy = operand.getType().cast<RankedTensorType>();
5855
5856 llvm::SmallVector<int32_t> padHigh;
5857 llvm::SmallVector<int32_t> padLow;
5858 llvm::SmallVector<int32_t> padInterior;
5859
5860 auto padHighAttr = adaptor.edge_padding_high();
5861 auto padLowAttr = adaptor.edge_padding_low();
5862 auto padInteriorAttr = adaptor.interior_padding();
5863
5864 padHigh.reserve(padHighAttr.getNumElements());
5865 padLow.reserve(padLowAttr.getNumElements());
5866 padInterior.reserve(padInteriorAttr.getNumElements());
5867
5868 for (const APInt& val : padHighAttr.getValues<APInt>())
5869 padHigh.push_back(val.getSExtValue());
5870
5871 for (const APInt& val : padLowAttr.getValues<APInt>())
5872 padLow.push_back(val.getSExtValue());
5873
5874 for (const APInt& val : padInteriorAttr.getValues<APInt>())
5875 padInterior.push_back(val.getSExtValue());
5876
5877 Value one = builder.create<arith::ConstantIndexOp>(loc, 1).getResult();
5878 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0).getResult();
5879
5880 llvm::SmallVector<Value> dimensions;
5881 dimensions.reserve(operandTy.getRank());
5882 for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
5883 Value padEdge =
5884 builder.create<arith::ConstantIndexOp>(loc, padHigh[i] + padLow[i]);
5885
5886 // First we grab the initial interior size.
5887 Value dim = builder.create<tensor::DimOp>(loc, operand, i).getResult();
5888
5889 // Compute the interior of the tensor and determine padding size.
5890 if (padInterior[i] > 0) {
5891 Value padInter =
5892 builder.create<arith::ConstantIndexOp>(loc, padInterior[i])
5893 .getResult();
5894 Value interior = builder.create<arith::SubIOp>(loc, dim, one).getResult();
5895 interior = builder.create<arith::MaxSIOp>(loc, interior, zero);
5896 interior = builder.create<arith::MulIOp>(loc, interior, padInter);
5897 dim = builder.create<arith::AddIOp>(loc, dim, interior).getResult();
5898 }
5899
5900 // Then we add the padding on the edge of the tensor.
5901 dim = builder.create<arith::AddIOp>(loc, dim, padEdge).getResult();
5902 dimensions.push_back(dim);
5903 }
5904
5905 Value dimensionTensor =
5906 builder.create<tensor::FromElementsOp>(loc, dimensions).getResult();
5907 reifiedReturnShapes.push_back(dimensionTensor);
5908 return success();
5909 }
5910
// If the input tensor has a dimension of size 0, its contents are
// irrelevant; instead, broadcast the pad value to the output shape rather
// than padding the input tensor.
5914 struct PadEmptyTensor : public OpRewritePattern<PadOp> {
5915 using OpRewritePattern<PadOp>::OpRewritePattern;
5916
  LogicalResult matchAndRewrite(PadOp op,
5918 PatternRewriter& rewriter) const override {
5919 auto operand = op.operand();
5920 auto padVal = op.padding_value();
5921
5922 auto operandTy = operand.getType().cast<RankedTensorType>();
5923 auto resultTy = op.getType().cast<RankedTensorType>();
5924
5925 if (llvm::all_of(operandTy.getShape(), [](int64_t d) { return d != 0; })) {
5926 return failure();
5927 }
5928
5929 if (resultTy.hasStaticShape()) {
5930 auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64));
5931 auto dims =
5932 DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
5933 rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(op, resultTy, padVal,
5934 dims);
5935 return success();
5936 }
5937
5938 llvm::SmallVector<Value> reifiedShapes;
5939 if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(),
5940 reifiedShapes)))
5941 return failure();
5942
5943 auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64));
5944 auto broadcastDims =
5945 DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
5946 rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
5947 op, op.getType(), padVal, reifiedShapes.front(), broadcastDims);
5948
    return success();
5950 }
5951 };
5952
void PadOp::getCanonicalizationPatterns(RewritePatternSet& results,
5954 MLIRContext* context) {
5955 results.add<PadEmptyTensor>(context);
5956 }
5957
5958 //===----------------------------------------------------------------------===//
5959 // DynamicPadOp
5960 //===----------------------------------------------------------------------===//
5961
// If the input tensor has a dimension of size 0, its contents are
// irrelevant; instead, broadcast the pad value to the output shape rather
// than padding the input tensor.
5965 struct DynamicPadEmptyTensor : public OpRewritePattern<DynamicPadOp> {
5966 using OpRewritePattern<DynamicPadOp>::OpRewritePattern;
5967
  LogicalResult matchAndRewrite(DynamicPadOp op,
5969 PatternRewriter& rewriter) const override {
5971 auto operand = op.operand();
5972 auto padVal = op.padding_value();
5973
5974 auto operandTy = operand.getType().cast<RankedTensorType>();
5975
5976 if (llvm::all_of(operandTy.getShape(), [](int64_t d) { return d != 0; })) {
5977 return failure();
5978 }
5979
5980 llvm::SmallVector<Value> reifiedShapes;
5981 if (failed(op.reifyReturnTypeShapes(rewriter, op->getOperands(),
5982 reifiedShapes)))
5983 return failure();
5984
5985 auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64));
5986 auto broadcastDims =
5987 DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
5988 rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
5989 op, op.getType(), padVal, reifiedShapes.front(), broadcastDims);
5990
    return success();
5992 }
5993 };
5994
void DynamicPadOp::getCanonicalizationPatterns(RewritePatternSet& results,
5996 MLIRContext* context) {
5997 results.add<DPadToPad, DynamicPadEmptyTensor>(context);
5998 }
5999
LogicalResult DynamicPadOp::verify() {
6001 auto inputType = operand().getType().dyn_cast<RankedTensorType>();
6002 // If operand is unranked, there is very little to verify statically.
6003 if (!inputType) return success();
6004 int inputRank = inputType.getRank();
6005
6006 auto padType = padding_value().getType().cast<RankedTensorType>();
6007 if (padType.getRank() != 0) {
    return emitOpError() << "padding value type should be a rank-0 tensor";
6009 }
6010
6011 auto paddingLowType = edge_padding_low().getType().cast<RankedTensorType>();
6012 if (paddingLowType.getNumElements() != inputRank) {
6013 return emitOpError() << "edge_padding_low length("
6014 << paddingLowType.getNumElements()
6015 << ") must match operand rank(" << inputRank << ").";
6016 }
6017
6018 auto paddingHighType = edge_padding_high().getType().cast<RankedTensorType>();
6019 if (paddingHighType.getNumElements() != inputRank) {
6020 return emitOpError() << "edge_padding_high length("
6021 << paddingHighType.getNumElements()
6022 << ") must match operand rank(" << inputRank << ").";
6023 }
6024
6025 auto interiorPaddingType =
6026 interior_padding().getType().cast<RankedTensorType>();
6027 if (interiorPaddingType.getNumElements() != inputRank) {
    return emitOpError() << "interior_padding length("
6029 << interiorPaddingType.getNumElements()
6030 << ") must match operand rank(" << inputRank << ").";
6031 }
6032
6033 auto outputType = getResult().getType().dyn_cast<RankedTensorType>();
6034 // If result is unranked, there is very little to verify statically.
6035 if (!outputType) return success();
6036 int outputRank = outputType.getRank();
6037 if (inputRank != outputRank) {
    return emitOpError() << "operand rank(" << inputRank
                         << ") must match result rank(" << outputRank << ").";
6040 }
6041
6042 return success();
6043 }
6044
LogicalResult DynamicPadOp::reifyReturnTypeShapes(
6046 OpBuilder& builder, ValueRange operands,
6047 SmallVectorImpl<Value>& reifiedReturnShapes) {
6048 DynamicPadOp::Adaptor adaptor(operands);
6049 Value operand = adaptor.operand();
6050 Value edgePaddingLow = adaptor.edge_padding_low();
6051 Value edgePaddingHigh = adaptor.edge_padding_high();
6052 Value interiorPadding = adaptor.interior_padding();
6053
6054 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
6056 if (!operandType) return failure();
6057
6058 auto loc = this->getLoc();
6059 SmallVector<Value, 4> shapeValues;
6060 shapeValues.reserve(operandType.getRank());
6061 Type shapeScalarType =
6062 edgePaddingLow.getType().cast<ShapedType>().getElementType();
6063
6064 auto toShapeScalarType = [&](Value v) {
6065 return maybeCastTo(builder, loc, v, shapeScalarType);
6066 };
6067
6068 Value zero =
6069 toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 0));
6070 Value one = toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 1));
6071
6072 for (int idx : llvm::seq<int>(0, operandType.getShape().size())) {
6073 Value valueDim =
6074 toShapeScalarType(builder.create<tensor::DimOp>(loc, operand, idx));
6075 Value offset = builder.create<arith::ConstantIndexOp>(loc, idx);
6076 Value valueLow =
6077 builder.create<tensor::ExtractOp>(loc, edgePaddingLow, offset);
6078 Value valueHigh =
6079 builder.create<tensor::ExtractOp>(loc, edgePaddingHigh, offset);
6080 Value valueInterior =
6081 builder.create<tensor::ExtractOp>(loc, interiorPadding, offset);
6082 // output_size = input_size + padding_low + padding_high + interior *
6083 // max(input_size - 1, 0)
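    // For example (illustrative): input_size = 4, low = 1, high = 2,
    // interior = 1 gives 4 + 1 + 2 + 1 * max(4 - 1, 0) = 10.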
6084 Value valueDimLessThanOne = builder.create<arith::CmpIOp>(
6085 loc, arith::CmpIPredicate::slt, valueDim, one);
6086 Value interiorSize = builder.create<arith::MulIOp>(
6087 loc, valueInterior,
6088 builder.create<mlir::arith::SelectOp>(
6089 loc, valueDimLessThanOne, zero,
6090 builder.create<arith::SubIOp>(loc, valueDim, one)));
6091 shapeValues.push_back(builder.create<arith::AddIOp>(
6092 loc,
6093 builder.create<arith::AddIOp>(
6094 loc, builder.create<arith::AddIOp>(loc, interiorSize, valueDim),
6095 valueLow),
6096 valueHigh));
6097 }
6098
6099 reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
6100 loc,
6101 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
6102 shapeScalarType),
6103 shapeValues));
6104
6105 return success();
6106 }
6107
6108 //===----------------------------------------------------------------------===//
6109 // ReshapeOp
6110 //===----------------------------------------------------------------------===//
6111
LogicalResult ReshapeOp::verify() {
6113 // If the operand type is dynamically shaped there is nothing to verify.
6114 auto operandTy = operand().getType().dyn_cast<RankedTensorType>();
6115 if (!operandTy || !operandTy.hasStaticShape()) return success();
6116
  // If the operand type is statically shaped (which is not required), the
  // number of elements must match that of the result type.
6119 auto resultTy = getType().cast<RankedTensorType>();
6120 assert(resultTy && resultTy.hasStaticShape() &&
6121 "result type must be statically shaped");
6122 int64_t numResultElements = resultTy.getNumElements();
6123 int64_t numOperandElements = operandTy.getNumElements();
6124 if (numResultElements != numOperandElements)
6125 return emitOpError() << "number of output elements (" << numResultElements
6126 << ") doesn't match expected number of elements ("
6127 << numOperandElements << ")";
6128
6129 return success();
6130 }
6131
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
6133 if (getOperand().getType() == getType()) {
6134 return getOperand();
6135 }
6136
6137 if (auto prevOp = getOperand().getDefiningOp<ReshapeOp>()) {
6138 setOperand(prevOp.getOperand());
6139 return getResult();
6140 }
6141
6142 if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
6143 return reshape(elements, getResult().getType().cast<ShapedType>());
6144 }
6145
6146 return {};
6147 }
6148
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results,
6150 MLIRContext* context) {
6151 results.add<IdentityBroadcastReshape, IdentityBroadcastInDimReshape,
6152 EliminateRedundantReshape, EliminateIdentityReshape>(context);
6153 }
6154
6155 //===----------------------------------------------------------------------===//
6156 // ReplicaId Op
6157 //===----------------------------------------------------------------------===//
6158
LogicalResult ReplicaIdOp::inferReturnTypes(
6160 MLIRContext* context, Optional<Location>, ValueRange operands,
6161 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
6162 inferredReturnTypes.push_back(RankedTensorType::get(
6163 /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
6164 return success();
6165 }
6166
6167 //===----------------------------------------------------------------------===//
6168 // AddDependency Op
6169 //===----------------------------------------------------------------------===//
6170
LogicalResult AddDependencyOp::inferReturnTypes(
6172 MLIRContext* context, Optional<Location>, ValueRange operands,
6173 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
6174 inferredReturnTypes.push_back(operands.getTypes()[0]);
6175 return success();
6176 }
6177
6178 //===----------------------------------------------------------------------===//
6179 // If Op
6180 //===----------------------------------------------------------------------===//
6181
static LogicalResult verifyConditionalBranch(Operation* op, Region& region,
6183 llvm::Twine branchName) {
6184 if (region.getNumArguments() != 0)
6185 return op->emitOpError()
6186 << branchName << " must have 0 arguments, but found "
6187 << region.getNumArguments();
6188
6189 TypeRange branchReturnTypes =
6190 region.front().getTerminator()->getOperandTypes();
6191 if (branchReturnTypes != op->getResultTypes())
6192 return op->emitOpError()
6193 << branchName << " returned types (" << branchReturnTypes
6194 << ") do not match op result types (" << op->getResultTypes() << ")";
6195
6196 return success();
6197 }
6198
LogicalResult IfOp::verify() {
6200 if (failed(verifyConditionalBranch(*this, true_branch(),
6201 /*branchName=*/"true_branch"))) {
6202 return failure();
6203 }
6204
6205 if (failed(verifyConditionalBranch(*this, false_branch(),
6206 /*branchName=*/"false_branch"))) {
6207 return failure();
6208 }
6209 return success();
6210 }
6211
static LogicalResult inlineIfConstantCondition(IfOp ifOp,
6213 PatternRewriter& rewriter) {
6214 DenseIntElementsAttr predAttr;
6215 if (!matchPattern(ifOp.pred(), m_Constant(&predAttr))) return failure();
6216
6217 if (predAttr.getSplatValue<BoolAttr>().getValue()) {
6218 replaceOpWithRegion(rewriter, ifOp, ifOp.true_branch());
6219 } else {
6220 replaceOpWithRegion(rewriter, ifOp, ifOp.false_branch());
6221 }
6222 return success();
6223 }
6224
void IfOp::getCanonicalizationPatterns(RewritePatternSet& results,
6226 MLIRContext* context) {
6227 results.add(&inlineIfConstantCondition);
6228 }
6229
6230 //===----------------------------------------------------------------------===//
6231 // Case Op
6232 //===----------------------------------------------------------------------===//
6233
LogicalResult CaseOp::verify() {
6235 auto numBranches = branches().size();
6236
6237 for (unsigned i = 0; i < numBranches; ++i)
6238 if (failed(verifyConditionalBranch(*this, branches()[i],
6239 /*branchName=*/"branch " + Twine(i))))
6240 return failure();
6241
6242 return success();
6243 }
6244
static LogicalResult inlineCaseConstantCondition(CaseOp caseOp,
6246 PatternRewriter& rewriter) {
6247 DenseIntElementsAttr indexAttr;
6248 if (!matchPattern(caseOp.index(), m_Constant(&indexAttr))) {
6249 return failure();
6250 }
6251 int64_t index =
6252 indexAttr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
6253 // For an OOB index, the last branch is executed as the default branch:
6254 // https://www.tensorflow.org/xla/operation_semantics#conditional
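  // For example (illustrative): a case op with 3 branches and a constant
  // index of 7 (or -1) inlines branch 2, the default branch.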
6255 if (index < 0 || index >= caseOp.getNumRegions())
6256 index = caseOp.getNumRegions() - 1;
6257
6258 Region& region = caseOp.getRegion(index);
6259 if (!llvm::hasSingleElement(region)) return failure();
6260 replaceOpWithRegion(rewriter, caseOp, region);
6261 return success();
6262 }
6263
void CaseOp::getCanonicalizationPatterns(RewritePatternSet& results,
6265 MLIRContext* context) {
6266 results.add(&inlineCaseConstantCondition);
6267 }
6268
6269 //===----------------------------------------------------------------------===//
6270 // SqrtOp
6271 //===----------------------------------------------------------------------===//
6272
OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
6274 auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6275 if (!val) return {};
6276
6277 auto type = getElementTypeOrSelf(getType());
6278 if (!type.isF32() && !type.isF64()) return {};
6279
6280 auto shapedType = getType().cast<ShapedType>();
6281 if (!shapedType.hasStaticShape()) return {};
6282
6283 // Prevent folding if the result is too large.
6284 if (val.getNumElements() > kFoldOpEltLimit) return {};
6285
6286 int bitWidth = type.getIntOrFloatBitWidth();
6287 llvm::SmallVector<APFloat, 4> values;
6288 values.reserve(val.getNumElements());
6289 for (auto it : val.getValues<APFloat>()) {
6290 double value = bitWidth == 32 ? it.convertToFloat() : it.convertToDouble();
6291 if (value < 0) return {};
6292 value = std::sqrt(value);
6293 if (bitWidth == 32)
6294 values.emplace_back(static_cast<float>(value));
6295 else
6296 values.emplace_back(value);
6297 }
6298 return DenseFPElementsAttr::get(shapedType, values);
6299 }
6300
6301 //===----------------------------------------------------------------------===//
6302 // UnaryOps
6303 //===----------------------------------------------------------------------===//
6304
ParseResult parseUnaryOp(OpAsmParser& parser, OperationState& result) {
6306 SmallVector<OpAsmParser::UnresolvedOperand> operands;
6307 Type type;
  // If the operand is enclosed in parentheses, use the generic form.
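  // For example (illustrative), both forms below are accepted:
  //   %0 = mhlo.abs %arg0 : tensor<4xf32>                      // shorthand
  //   %0 = mhlo.abs(%arg0) : (tensor<4xf32>) -> tensor<4xf32>  // generic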
6309 SMLoc loc = parser.getCurrentLocation();
6310 if (!parser.parseOptionalLParen()) {
6311 if (parser.parseOperandList(operands) || parser.parseRParen() ||
6312 parser.parseOptionalAttrDict(result.attributes) ||
6313 parser.parseColon() || parser.parseType(type))
6314 return failure();
6315 auto fnType = type.dyn_cast<FunctionType>();
6316 if (!fnType) {
6317 parser.emitError(loc, "expected function type");
6318 return failure();
6319 }
6320 if (parser.resolveOperands(operands, fnType.getInputs(), loc,
6321 result.operands))
6322 return failure();
6323 result.addTypes(fnType.getResults());
6324 return success();
6325 }
6326 // Otherwise, use shorthand syntax.
6327 return failure(parser.parseOperandList(operands) ||
6328 parser.parseOptionalAttrDict(result.attributes) ||
6329 parser.parseColonType(type) ||
6330 parser.resolveOperands(operands, type, result.operands) ||
6331 parser.addTypeToList(type, result.types));
6332 }
6333
void printUnaryOp(Operation* op, OpAsmPrinter& p) {
6335 assert(op->getNumResults() == 1 && "op should have one result");
6336 assert(op->getNumOperands() == 1 && "op should have one input");
6337 // If not all types are the same, use generic form.
6338 auto resultType = op->getResult(0).getType();
6339 if (resultType != op->getOperandTypes()[0]) {
6340 p.printGenericOp(op, /*printOpName=*/false);
6341 return;
6342 }
6343 // Otherwise, use the shorthand syntax.
6344 p << ' ';
6345 p.printOperands(op->getOperands());
6346 p.printOptionalAttrDict(op->getAttrs());
6347 p << " : " << resultType;
6348 }
6349
6350 template <typename Op, typename ElementType = Type, typename ValType,
6351 typename Convert>
static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
6353 if (!attrs[0]) return {};
6354
6355 DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>();
6356 if (!val) return {};
6357
6358 ShapedType type = op->getType().template cast<ShapedType>();
6359 if (!type.hasStaticShape()) {
6360 return {};
6361 }
6362
6363 Type etype = type.getElementType();
6364
  // Bail out unless the element type matches the type this folder handles.
6366 if (!etype.isa<ElementType>()) {
6367 return {};
6368 }
6369
6370 // Prevent folding if the result is too large.
6371 if (val.getNumElements() > kFoldOpEltLimit) return {};
6372
6373 SmallVector<ValType, 6> values;
6374 values.reserve(val.getNumElements());
6375 for (const auto v : val.getValues<ValType>()) {
6376 values.push_back(Convert()(v));
6377 }
6378
6379 return DenseElementsAttr::get(type, values);
6380 }
6381
6382 struct Round {
  APFloat operator()(const APFloat& f) {
6384 APFloat r = f;
6385 r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
6386 return r;
6387 }
6388 };
6389
6390 struct RoundNearestEven {
  APFloat operator()(const APFloat& f) {
6392 APFloat r = f;
6393 r.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
6394 return r;
6395 }
6396 };
6397
6398 struct LogicalNot {
  APInt operator()(const APInt& i) {
6400 return APInt(i.getBitWidth(), static_cast<uint64_t>(!i));
6401 }
6402 };
6403
6404 template <typename FloatOrInt>
6405 struct Sign {
  APFloat compute(const APFloat& f) {
6407 if (f.isZero() || f.isNaN()) return f;
6408 double value = f.isNegative() ? -1.0 : 1.0;
6409 APFloat val(value);
6410 bool unused;
6411 val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
6412 return val;
6413 }
6414
  APInt compute(const APInt& i) {
6416 APInt r = i;
6417 if (r == 0) return r;
6418 if (r.isNegative()) {
6419 return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
6420 }
6421 return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
6422 }
6423
  FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
6425 };
6426
6427 #define UNARY_FOLDER(Op, Func) \
6428 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
6429 if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
6430 return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
6431 if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
6432 return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs); \
6433 return {}; \
6434 }
6435
6436 #define UNARY_FOLDER_INT(Op, Func) \
6437 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
6438 if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
6439 return UnaryFolder<Op, IntegerType, APInt, Func>(this, attrs); \
6440 return {}; \
6441 }
6442
6443 #define UNARY_FOLDER_FLOAT(Op, Func) \
6444 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
6445 if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
6446 return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
6447 return {}; \
6448 }
6449
6450 UNARY_FOLDER(NegOp, std::negate);
6451 UNARY_FOLDER(SignOp, Sign);
6452 UNARY_FOLDER_INT(NotOp, LogicalNot);
6453 UNARY_FOLDER_FLOAT(RoundNearestEvenOp, RoundNearestEven);
6454 UNARY_FOLDER_FLOAT(RoundOp, Round);
6455
6456 #undef UNARY_FOLDER
6457 #undef UNARY_FOLDER_INT
6458 #undef UNARY_FOLDER_FLOAT
6459
6460 //===----------------------------------------------------------------------===//
6461 // BinaryOps
6462 //===----------------------------------------------------------------------===//
6463
ParseResult parseBinaryOp(OpAsmParser& parser, OperationState& result) {
6465 SmallVector<OpAsmParser::UnresolvedOperand> operands;
6466 Type type;
  // If the operand list is enclosed in parentheses, use the generic form.
6468 SMLoc loc = parser.getCurrentLocation();
6469 if (!parser.parseOptionalLParen()) {
6470 if (parser.parseOperandList(operands) || parser.parseRParen() ||
6471 parser.parseOptionalAttrDict(result.attributes) ||
6472 parser.parseColon() || parser.parseType(type))
6473 return failure();
6474 auto fnType = type.dyn_cast<FunctionType>();
6475 if (!fnType) {
6476 parser.emitError(loc, "expected function type");
6477 return failure();
6478 }
6479 if (parser.resolveOperands(operands, fnType.getInputs(), loc,
6480 result.operands))
6481 return failure();
6482 result.addTypes(fnType.getResults());
6483 return success();
6484 }
6485 // Otherwise, use shorthand syntax.
6486 return failure(parser.parseOperandList(operands) ||
6487 parser.parseOptionalAttrDict(result.attributes) ||
6488 parser.parseColonType(type) ||
6489 parser.resolveOperands(operands, type, result.operands) ||
6490 parser.addTypeToList(type, result.types));
6491 }
6492
void printBinaryOp(Operation* op, OpAsmPrinter& p) {
6494 assert(op->getNumResults() == 1 && "op should have one result");
6495 // If not all types are the same, use generic form.
6496 auto resultType = op->getResult(0).getType();
6497 if (llvm::any_of(op->getOperandTypes(),
6498 [&](Type type) { return type != resultType; })) {
6499 p.printGenericOp(op, /*printOpName=*/false);
6500 return;
6501 }
6502 // Otherwise, use the shorthand syntax.
6503 p << ' ';
6504 p.printOperands(op->getOperands());
6505 p.printOptionalAttrDict(op->getAttrs());
6506 p << " : " << resultType;
6507 }
6508
static const APFloat& addSign(const APFloat& v, Type) { return v; }
static APSInt addSign(const APInt& v, Type t) {
6511 // Add signedness information to the value, treating signless as signed.
6512 return APSInt(v, t.isUnsignedInteger());
6513 }
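// For example (illustrative): with element type ui8, the bit pattern 0xC8
// folds as unsigned 200 (200 / 3 == 66), while with i8 the same bits fold
// as signed -56 (-56 / 3 == -18).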
6514
6515 template <typename Op, typename ElementType = Type, typename ValType,
6516 typename Convert>
static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
6518 if (!attrs[0] || !attrs[1]) return {};
6519
6520 DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
6521 DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
6522 if (!lhs || !rhs) return {};
6523
6524 ShapedType type = op->getType().template cast<ShapedType>();
6525 if (!type.hasStaticShape()) {
6526 return {};
6527 }
6528
6529 Type etype = type.getElementType();
6530
  // Bail out unless the element type matches the type this folder handles.
6532 if (!etype.isa<ElementType>()) {
6533 return {};
6534 }
6535
6536 // Special case for folding splats no matter how large.
6537 // Only covers the case of both attrs being splats; operation-specific cases
6538 // like adding a zero or multiplying by one are handled elsewhere.
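  // For example (illustrative): adding two dense<3.0> splats of type
  // tensor<1000000xf32> folds to a single dense<6.0> splat without
  // materializing a million elements.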
6539 SplatElementsAttr splatLhs = lhs.dyn_cast<SplatElementsAttr>();
6540 SplatElementsAttr splatRhs = rhs.dyn_cast<SplatElementsAttr>();
6541 if (splatLhs && splatRhs) {
6542 auto signedLhs = addSign(splatLhs.getSplatValue<ValType>(), etype);
6543 auto signedRhs = addSign(splatRhs.getSplatValue<ValType>(), etype);
6544 FailureOr<decltype(signedLhs)> result(Convert()(signedLhs, signedRhs));
6545 return succeeded(result) ? SplatElementsAttr::get(type, *result)
6546 : Attribute();
6547 }
6548
6549 // Prevent folding if the result is too large.
6550 if (lhs.getNumElements() > kFoldOpEltLimit) return {};
6551
6552 SmallVector<ValType, 6> values;
6553 values.reserve(lhs.getNumElements());
6554 for (const auto zip :
6555 llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
6556 auto signedLhs = addSign(std::get<0>(zip), etype);
6557 auto signedRhs = addSign(std::get<1>(zip), etype);
6558 FailureOr<decltype(signedLhs)> result(Convert()(signedLhs, signedRhs));
6559 if (failed(result)) {
6560 return {};
6561 }
6562 values.push_back(std::move(*result));
6563 }
6564
6565 return DenseElementsAttr::get(type, values);
6566 }
6567
6568 template <typename T>
6569 struct Divide : std::divides<T> {};
6570
6571 template <>
6572 struct Divide<APSInt> {
  FailureOr<APSInt> operator()(const APSInt& a, const APSInt& b) const {
6574 if (b.isZero()) return failure();
6575 return a / b;
6576 }
6577 };
6578
6579 template <typename T>
6580 struct Remainder : std::modulus<T> {};
6581
6582 template <>
6583 struct Remainder<APSInt> {
  FailureOr<APSInt> operator()(const APSInt& a, const APSInt& b) const {
6585 if (b.isZero()) return failure();
6586 return a % b;
6587 }
6588 };
6589
6590 template <>
6591 struct Remainder<APFloat> {
  APFloat operator()(const APFloat& a, const APFloat& b) const {
6593 APFloat result(a);
6594 result.remainder(b);
6595 return result;
6596 }
6597 };
6598
6599 template <typename T>
6600 struct Max {
  T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
6602 };
6603
6604 template <typename T>
6605 struct Min {
  T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
6607 };
6608
6609 template <typename T>
6610 struct And {
  T operator()(const T& a, const T& b) const { return a & b; }
6612 };
6613
6614 template <typename T>
6615 struct Or {
  T operator()(const T& a, const T& b) const { return a | b; }
6617 };
6618
6619 template <typename T>
6620 struct Xor {
  T operator()(const T& a, const T& b) const { return a ^ b; }
6622 };
6623
6624 #define BINARY_FOLDER_INTERNAL(Op, Func) \
6625 if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
6626 return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
6627 if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
6628 return BinaryFolder<Op, IntegerType, APInt, Func<APSInt>>(this, attrs); \
6629 return {};
6630
6631 #define BINARY_FOLDER(Op, Func) \
6632 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
6633 BINARY_FOLDER_INTERNAL(Op, Func) \
6634 }
6635
// Addition, subtraction and multiplication use the std:: versions of the
// ops, which produce the same bit patterns for signed and unsigned integers.
// Ops whose behavior differs between signed and unsigned integers (division,
// remainder, max, min) use the APSInt-aware functors defined above.
6640 BINARY_FOLDER(SubtractOp, std::minus);
6641 BINARY_FOLDER(DivOp, Divide);
6642 BINARY_FOLDER(RemOp, Remainder);
6643 BINARY_FOLDER(MaxOp, Max);
6644 BINARY_FOLDER(MinOp, Min);
6645
bool isSplatZero(SplatElementsAttr attr) {
6647 if (!attr) return false;
6648 if (attr.getElementType().isa<FloatType>()) {
6649 return attr.getSplatValue<APFloat>().isZero();
6650 }
6651 if (attr.getElementType().isa<IntegerType>()) {
6652 return attr.getSplatValue<APInt>().isZero();
6653 }
6654 return false;
6655 }
6656
OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {
6658 // Handle special case where one operand is 0: x + 0 => x
6659 if (attrs[0] || attrs[1]) {
6660 SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null<SplatElementsAttr>();
6661 SplatElementsAttr splatRhs = attrs[1].dyn_cast_or_null<SplatElementsAttr>();
6662 if (isSplatZero(splatLhs)) return splatRhs ? (OpFoldResult)splatRhs : rhs();
6663 if (isSplatZero(splatRhs)) return splatLhs ? (OpFoldResult)splatLhs : lhs();
6664 }
6665 if (attrs[0] && attrs[1]) {
6666 BINARY_FOLDER_INTERNAL(AddOp, std::plus)
6667 }
6668 return {};
6669 }
6670
bool isSplatOne(SplatElementsAttr attr) {
6672 if (!attr) return false;
6673 if (attr.getElementType().isa<FloatType>()) {
6674 return attr.getSplatValue<APFloat>().convertToDouble() == 1.0;
6675 }
6676 if (attr.getElementType().isa<IntegerType>()) {
6677 return attr.getSplatValue<APInt>().getSExtValue() == 1;
6678 }
6679 return false;
6680 }
6681
OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {
6683 // Handle special case where one operand is 1: x * 1 => x
6684 if (attrs[0] || attrs[1]) {
6685 SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null<SplatElementsAttr>();
6686 SplatElementsAttr splatRhs = attrs[1].dyn_cast_or_null<SplatElementsAttr>();
6687 if (isSplatOne(splatLhs)) return splatRhs ? (OpFoldResult)splatRhs : rhs();
6688 if (isSplatOne(splatRhs)) return splatLhs ? (OpFoldResult)splatLhs : lhs();
6689 }
6690 if (attrs[0] && attrs[1]) {
6691 BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);
6692 }
6693 return {};
6694 }
6695
6696 //===----------------------------------------------------------------------===//
6697 // Logical Ops
6698 //===----------------------------------------------------------------------===//
6699
6700 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
6701 if (lhs() == rhs()) return lhs();
6702
6703 auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6704 auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
6705
6706 if (lhsVal && lhsVal.isSplat()) {
6707 if (lhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6708 return rhs();
6709 }
6710
6711 if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6712 return lhsVal;
6713 }
6714 }
6715
6716 if (rhsVal && rhsVal.isSplat()) {
6717 if (rhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6718 return lhs();
6719 }
6720
6721 if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6722 return rhsVal;
6723 }
6724 }
6725
6726 if (!rhsVal || !lhsVal) return {};
6727 return BinaryFolder<AndOp, IntegerType, APInt, And<APSInt>>(this, operands);
6728 }
6729
6730 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
6731 if (lhs() == rhs()) return lhs();
6732
6733 auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6734 auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
6735
6736 if (lhsVal && lhsVal.isSplat()) {
6737 if (lhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6738 return lhsVal;
6739 }
6740
6741 if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6742 return rhs();
6743 }
6744 }
6745
6746 if (rhsVal && rhsVal.isSplat()) {
6747 if (rhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6748 return rhsVal;
6749 }
6750
6751 if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6752 return lhs();
6753 }
6754 }
6755
6756 if (!rhsVal || !lhsVal) return {};
6757 return BinaryFolder<OrOp, IntegerType, APInt, Or<APSInt>>(this, operands);
6758 }
6759
6760 OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
6761 // Fold x^x to 0. Attributes only support static shapes.
6762 auto rType = getType().cast<ShapedType>();
6763 if (lhs() == rhs() && rType.hasStaticShape()) {
6764 Builder builder(getContext());
6765 return builder.getZeroAttr(rType);
6766 }
6767
6768 auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6769 auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
6770
6771 if (lhsVal && lhsVal.isSplat()) {
6772 if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6773 return rhs();
6774 }
6775 }
6776
6777 if (rhsVal && rhsVal.isSplat()) {
6778 if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6779 return lhs();
6780 }
6781 }
6782
6783 if (!rhsVal || !lhsVal) return {};
6784 return BinaryFolder<XorOp, IntegerType, APInt, Xor<APSInt>>(this, operands);
6785 }
6786
6787 #undef BINARY_FOLDER_INTERNAL
6788 #undef BINARY_FOLDER
6789
6790 //===----------------------------------------------------------------------===//
6791 // SliceOp
6792 //===----------------------------------------------------------------------===//
6793
6794 // Returns output dimension size for slice result for the given arguments.
6795 // Returns -1 if arguments are illegal.
6796 static int64_t inferSliceDim(int64_t inputDim, int64_t start, int64_t end,
6797 int64_t stride) {
6798 if (inputDim == -1 || start < 0 || start > end || end > inputDim ||
6799 stride == 0)
6800 return -1;
6801
6802 return llvm::divideCeil(end - start, stride);
6803 }
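// Worked example (illustrative numbers): inputDim = 10, start = 2, end = 8,
// stride = 3 yields ceil((8 - 2) / 3) = 2 output elements, taken from input
// indices 2 and 5.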
6804
6805 // The following properties are already enforced by the ODS:
6806 // type(start_indices) == type(limit_indices) == type(strides).
6807 // Verify the following properties:
6808 // P1. Verify rank(start_indices) == 1.
6809 // P2. Verify size(start_indices) == rank(operand).
6810 // P3~5. Verify 0 <= start_indices[i] <= limit_indices[i] <= shape(operand)[i].
6811 // P6. Verify stride[i] > 0.
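// Schematic example satisfying P1-P6 (shapes assumed for illustration):
//   slice(%operand : tensor<3x4xf32>,
//         start_indices = [1, 0], limit_indices = [2, 4], strides = [1, 2])
// infers the result type tensor<1x2xf32>, since each result dimension is
// ceil((limit - start) / stride).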
6812 LogicalResult SliceOp::inferReturnTypes(
6813 MLIRContext* context, Optional<Location> location, ValueRange operands,
6814 DictionaryAttr attributes, RegionRange regions,
6815 SmallVectorImpl<Type>& inferredReturnTypes) {
6816 SliceOpAdaptor slice(operands, attributes);
6817 Type ty = slice.operand().getType();
6818 RankedTensorType rankedTy = ty.dyn_cast<RankedTensorType>();
6819 if (!rankedTy) {
6820 // The operand type is unranked, so the best we can infer for the result
6821 // type is an unranked tensor with the same element type as the operand
6822 // type.
6823 inferredReturnTypes.assign({ty});
6824 return success();
6825 }
6826
6827 ShapedType attrTy = slice.start_indices().getType();
6828 // P1.
6829 // Note: ODS has type(start_indices) == type(limit_indices) == type(strides)
6830 // So this implies rank(limit_indices) == rank(strides) == 1 also.
6831 if (attrTy.getRank() != 1) {
6832 return emitOptionalError(location, "start_indices has rank ",
6833 attrTy.getRank(), " instead of required rank 1");
6834 }
6835
6836 // P2.
6837 int64_t rank = rankedTy.getRank();
6838 if (attrTy.getNumElements() != rank) {
6839 return emitOptionalError(
6840 location, "the number of elements in start_indices (",
6841 attrTy.getNumElements(), ") does not match the rank of the operand (",
6842 rank, ")");
6843 }
6844
6845 SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
6846 SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
6847 SmallVector<int64_t, 4> strideVals(slice.strides().getValues<int64_t>());
6848
6849 SmallVector<int64_t, 4> shape;
6850 shape.reserve(rank);
6851 for (int64_t i = 0, e = rank; i != e; i++) {
6852 if (isDynamicDimSize(rankedTy.getDimSize(i))) {
6853 shape.push_back(ShapedType::kDynamicSize);
6854 continue;
6855 }
6856 // P3.
6857 if (start[i] < 0)
6858 return emitOptionalError(location, "negative start index ", start[i],
6859 " in dimension ", i);
6860 // P4.
6861 if (limit[i] > rankedTy.getDimSize(i))
6862 return emitOptionalError(location, "limit index ", limit[i],
6863 " is larger than dimension size ",
6864 rankedTy.getDimSize(i), " in dimension ", i);
6865 // P5.
6866 if (start[i] > limit[i])
6867 return emitOptionalError(location, "start index ", start[i],
6868 " is larger than limit index ", limit[i],
6869 " in dimension ", i);
6870 // P6.
6871 if (strideVals[i] <= 0)
6872 return emitOptionalError(location, "stride must be positive but got ",
6873 strideVals[i], " in dimension ", i);
6874
6875 shape.push_back(inferSliceDim(rankedTy.getDimSize(i), start[i], limit[i],
6876 strideVals[i]));
6877 }
6878 inferredReturnTypes.assign(
6879 {RankedTensorType::get(shape, rankedTy.getElementType())});
6880 return success();
6881 }
6882
6883 template <typename I, typename E>
6884 static void sliceElements(I values, ArrayRef<int64_t> sizes,
6885 ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
6886 ArrayRef<int64_t> strides,
6887 llvm::SmallVectorImpl<E>* outValues) {
6888 assert(starts.size() == limits.size());
6889 assert(starts.size() == strides.size());
6890 if (starts.empty()) return;
6891
6892 int64_t start = starts.front();
6893 int64_t limit = limits.front();
6894 int64_t stride = strides.front();
6895 if (starts.size() == 1) {
6896 for (int i = start; i < limit; i += stride) {
6897 outValues->push_back(*(values + i));
6898 }
6899 return;
6900 }
6901
6902 for (; start < limit; start += stride) {
6903 auto begin = values + start * sizes.front();
6904 sliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
6905 limits.drop_front(), strides.drop_front(), outValues);
6906 }
6907 }
6908
6909 template <typename I, typename E>
6910 static Attribute foldSlice(SliceOp* op, I values) {
6911 auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
6912 auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
6913 auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());
6914
6915 // TODO(b/235903849): This should be op->getType().cast<ShapedType>().
6916 auto resultType = op->operand().getType().cast<ShapedType>();
6917 if (!resultType.hasStaticShape()) return {};
6918
6919 auto shape = resultType.getShape();
6920 int64_t count = resultType.getNumElements();
6921 if (count == 0) {
6922 return DenseElementsAttr::get<E>(
6923 op->getResult().getType().cast<ShapedType>(),
6924 /*list=*/{});
6925 }
6926
6927 // Compute the striding for each dimension.
6928 llvm::SmallVector<int64_t, 6> sizes;
6929 sizes.reserve(shape.size());
6930 for (auto v : shape) {
6931 count = count / v;
6932 sizes.push_back(count);
6933 }
6934
6935 // Prevent folding if the result is too large.
6936 if (resultType.getNumElements() > kFoldOpEltLimit) return {};
6937
6938 llvm::SmallVector<E, 6> outValues;
6939 outValues.reserve(resultType.getNumElements());
6940 sliceElements<I, E>(values, sizes, start, limit, stride, &outValues);
6941
6942 return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
6943 outValues);
6944 }
6945
6946 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
6947 // Check if the SliceOp is a NoOp operation.
6948 auto operandType = getOperand().getType().cast<ShapedType>();
6949 auto resultType = getResult().getType().cast<ShapedType>();
6950
6951 if (operandType.hasStaticShape() && resultType.hasStaticShape() &&
6952 (operandType.getShape() == resultType.getShape())) {
6953 return getOperand();
6954 }
6955
6956 if (operands.empty() || !operands.front()) return {};
6957
6958 // Evaluate for statically valued inputs.
6959 DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
6960 if (!elements) return {};
6961
6962 auto etype = elements.getType().getElementType();
6963 if (etype.isa<IntegerType>()) {
6964 return foldSlice<DenseElementsAttr::IntElementIterator, APInt>(
6965 this, elements.value_begin<APInt>());
6966 }
6967 if (etype.isa<FloatType>()) {
6968 return foldSlice<DenseElementsAttr::FloatElementIterator, APFloat>(
6969 this, elements.value_begin<APFloat>());
6970 }
6971
6972 return {};
6973 }
6974
6975 namespace {
6976 // In cases where a concat is fed into a slice, it is possible that the
6977 // concat can be simplified or bypassed. This checks which inputs to the
6978 // concat are used by the slice, either reducing the number of concatenated
6979 // values or removing the concat entirely.
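// Illustrative sketch (shapes assumed): given
//   %c = concatenate(%a : tensor<2xf32>, %b : tensor<3xf32>) {dimension = 0}
//   slice(%c, start = [3], limit = [5], strides = [1])
// only %b is read, so the slice is rewritten against a concat of just %b,
// with bounds shifted by the front offset of 2: start = [1], limit = [3].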
6980 struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
6981 using OpRewritePattern<SliceOp>::OpRewritePattern;
6982
6983 LogicalResult matchAndRewrite(SliceOp slice,
6984 PatternRewriter& rewriter) const override {
6985 auto resultTy = slice.getType().cast<ShapedType>();
6986 if (!resultTy.hasStaticShape()) {
6987 return failure();
6988 }
6989
6990 auto sliceInput = slice.operand();
6991 auto sliceInputTy = sliceInput.getType().cast<ShapedType>();
6992 auto concat = sliceInput.getDefiningOp<ConcatenateOp>();
6993 if (!concat) {
6994 return failure();
6995 }
6996
6997 auto dimension = concat.dimension();
6998
6999 auto start = slice.start_indices().getValues<APInt>();
7000 auto limit = slice.limit_indices().getValues<APInt>();
7001
7002 auto sliceStart = (*(start.begin() + dimension)).getSExtValue();
7003 auto sliceLimit = (*(limit.begin() + dimension)).getSExtValue();
7004
7005 // We need to determine what inputs from the concat affect the slice, and
7006 // how the bounds of the slice need to be updated for the minimally required
7007 // inputs.
7008 int64_t runningSize = 0;
7009 int64_t frontOffset = sliceInputTy.getShape()[dimension];
7010
7011 auto subsetStart = concat.operand_end();
7012 auto subsetEnd = concat.operand_end();
7013 for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
7014 auto input = *it;
7015 ShapedType inputTy = input.getType().cast<ShapedType>();
7016 if (inputTy.isDynamicDim(dimension)) {
7017 return failure();
7018 }
7019 auto dimSize = inputTy.getShape()[dimension];
7020
7021 // If this position is in the slice, it's the start of the subset, and we
7022 // need to update the start and limit values.
7023 if (runningSize + dimSize > sliceStart &&
7024 subsetStart == concat.operand_end()) {
7025 subsetStart = it;
7026 frontOffset = runningSize;
7027 }
7028
7029 // Determine the last required offset.
7030 if (runningSize < sliceLimit) {
7031 subsetEnd = it + 1;
7032 }
7033
7034 runningSize += dimSize;
7035 }
7036
7037 auto subsetSize = subsetEnd - subsetStart;
7038 // We need all inputs, so there is nothing to optimize.
7039 if (subsetSize == concat.getNumOperands()) {
7040 return failure();
7041 }
7042
7043 // If there's nothing to slice, the output is an empty tensor and this is
7044 // dead code. We do nothing here and rely on other passes to clean this
7045 // up.
7046 if (subsetSize == 0) {
7047 return failure();
7048 }
7049
7050 if (subsetSize > 1 && !concat.getResult().hasOneUse()) {
7051 return failure();
7052 }
7053
7054 auto concatRange = OperandRange(subsetStart, subsetEnd);
7055 auto newConcat = rewriter.create<ConcatenateOp>(
7056 concat.getLoc(), concatRange, concat.dimension());
7057
7058 llvm::SmallVector<APInt, 6> newStart(start);
7059 llvm::SmallVector<APInt, 6> newLimit(limit);
7060 newStart[dimension] -= frontOffset;
7061 newLimit[dimension] -= frontOffset;
7062
7063 auto attrType = slice.start_indices().getType().cast<ShapedType>();
7064 auto create = rewriter.create<SliceOp>(
7065 slice.getLoc(), newConcat,
7066 DenseIntElementsAttr::get(attrType, newStart),
7067 DenseIntElementsAttr::get(attrType, newLimit), slice.strides());
7068 rewriter.replaceOp(slice, create.getResult());
7069 return success();
7070 }
7071 };
7072 } // namespace
7073
7074 void SliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
7075 MLIRContext* context) {
7076 results.add<SimplifyConcatSlice>(context);
7077 }
7078
7079 //===----------------------------------------------------------------------===//
7080 // SortOp
7081 //===----------------------------------------------------------------------===//
7082
7083 void SortOp::build(OpBuilder& builder, OperationState& state,
7084 ValueRange operands, int64_t dimension, bool isStable) {
7085 state.addOperands(operands);
7086 state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
7087 state.addAttribute("is_stable", builder.getBoolAttr(isStable));
7088
7089 for (Value operand : operands) state.addTypes(operand.getType());
7090
7091 state.addRegion();
7092 }
7093
7094 LogicalResult SortOp::verify() {
7095 Operation::operand_range operands = this->operands();
7096 if (operands.empty()) return emitOpError("requires at least one input");
7097
7098 // TODO(antiagainst): verify partially dynamic shapes
7099 if (llvm::all_of(operands, [](Value operand) {
7100 return operand.getType().cast<ShapedType>().hasRank();
7101 })) {
7102 ArrayRef<int64_t> inputShape =
7103 (*operands.begin()).getType().cast<ShapedType>().getShape();
7104
7105 if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
7106 return operand.getType().cast<ShapedType>().getShape() != inputShape;
7107 }))
7108 return emitOpError("requires all inputs to have the same dimensions");
7109
7110 int64_t rank = inputShape.size();
7111 int64_t cmpDim = dimension();
7112 if (cmpDim < -rank || cmpDim >= rank)
7113 return emitOpError("dimension attribute value must be in range [-")
7114 << rank << ", " << rank << "), but found " << cmpDim;
7115 }
7116
7117 Block& block = comparator().front();
7118 size_t numOperands = getOperation()->getNumOperands();
7119 if (block.getNumArguments() != 2 * numOperands)
7120 return emitOpError("comparator block should have ")
7121 << 2 * numOperands << " arguments";
7122
7123 for (const auto& indexedOperand : llvm::enumerate(operands)) {
7124 int index = indexedOperand.index();
7125 Type elementType =
7126 indexedOperand.value().getType().cast<ShapedType>().getElementType();
7127 Type tensorType = RankedTensorType::get({}, elementType);
7128 for (int i : {2 * index, 2 * index + 1}) {
7129 Type argType = block.getArgument(i).getType();
7130 if (argType != tensorType)
7131 return emitOpError("comparator block argument #")
7132 << i << " should be of type " << tensorType << " but got "
7133 << argType;
7134 }
7135 }
7136
7137 // The mapped computation must return a single output.
7138 auto comparatorResult = block.getTerminator()->getOperands();
7139 if (comparatorResult.size() != 1)
7140 return emitOpError() << "comparator must return single output, but got: "
7141 << comparatorResult.size();
7142
7143 // The output of the computation must be a rank-0 tensor with element-type i1.
7144 auto comparatorResultType =
7145 comparatorResult[0].getType().dyn_cast<RankedTensorType>();
7146 if (!comparatorResultType || comparatorResultType.getRank() != 0 ||
7147 !comparatorResultType.getElementType().isInteger(1))
7148 return emitOpError() << "comparator must return tensor<i1>, but got: "
7149 << comparatorResult[0].getType();
7150
7151 // Check the number of return values and their element types.
7152 auto resultTypes = getResultTypes();
7153 if (resultTypes.size() != numOperands)
7154 return emitOpError() << "expects the number of results to be same as "
7155 "number of operands. Got number of results = "
7156 << resultTypes.size()
7157 << " and number of operands = " << numOperands;
7158
7159 for (auto it : llvm::zip(operands, getResultTypes()))
7160 if (std::get<0>(it).getType().cast<TensorType>().getElementType() !=
7161 std::get<1>(it).cast<TensorType>().getElementType())
7162 return emitOpError()
7163 << "expects the operands and results to have pairwize equal "
7164 "element-types, but got "
7165 << std::get<0>(it).getType().cast<TensorType>().getElementType()
7166 << " vs " << std::get<1>(it).cast<TensorType>().getElementType();
7167
7168 return success();
7169 }
7170
7171 /// Drops operands whose results are unused and whose corresponding block
7172 /// arguments in op.comparator() are unused.
7173 static LogicalResult sortDropEmptyUseArgs(SortOp op,
7174 PatternRewriter& rewriter) {
7175 DenseSet<unsigned> erasedArgs;
7176 unsigned numOperands = op.getNumOperands();
7177 for (unsigned i = 0; i < numOperands; ++i) {
7178 if (!op.getResult(i).use_empty()) continue;
7179 Block& block = op.comparator().front();
7180 if (!block.getArgument(i * 2).use_empty()) continue;
7181 if (!block.getArgument(i * 2 + 1).use_empty()) continue;
7182 erasedArgs.insert(i);
7183 }
7184 if (erasedArgs.empty()) return failure();
7185
7186 SmallVector<Value> newOperands;
7187 SmallVector<unsigned> erasedBlockArgs;
7188 for (const auto& en : llvm::enumerate(op.operands())) {
7189 if (erasedArgs.contains(en.index())) {
7190 erasedBlockArgs.push_back(en.index() * 2);
7191 erasedBlockArgs.push_back(en.index() * 2 + 1);
7192 } else {
7193 newOperands.push_back(en.value());
7194 }
7195 }
7196
7197 auto newOp = rewriter.create<SortOp>(op.getLoc(), newOperands, op.dimension(),
7198 op.is_stable());
7199 Region& region = newOp.comparator();
7200 rewriter.inlineRegionBefore(op.comparator(), region, region.end());
7201 region.front().eraseArguments(erasedBlockArgs);
7202
7203 SmallVector<Value> results;
7204 for (unsigned i = 0, j = 0; i < numOperands; ++i) {
7205 if (erasedArgs.contains(i)) {
7206 results.push_back({});
7207 } else {
7208 results.push_back(newOp.getResult(j++));
7209 }
7210 }
7211 rewriter.replaceOp(op, results);
7212
7213 return success();
7214 }
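// Sketch of the rewrite above (hypothetical IR): in
//   %0, %1 = mhlo.sort(%a, %b)
// if %1 is unused and the comparator ignores block arguments 2 and 3, the
// sort is rebuilt with only %a (and block arguments 0 and 1); the erased
// result is replaced by a null value that nothing consumes.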
7215
7216 /// Set the sorting dimension to the last dimension if it's not set and the rank
7217 /// is known.
7218 static LogicalResult sortOpInferDefaultDimension(SortOp op,
7219 PatternRewriter& rewriter) {
7220 auto ty = op.getResultTypes()[0].dyn_cast<ShapedType>();
7221 if (!ty) {
7222 return failure();
7223 }
7224 if (static_cast<int64_t>(op.dimension()) != -1) {
7225 return failure();
7226 }
7227
7228 IntegerAttr dim = rewriter.getI64IntegerAttr(ty.getRank() - 1);
7229 auto newOp = rewriter.create<SortOp>(op.getLoc(), op.getResultTypes(),
7230 op.operands(), dim, op.is_stableAttr());
7231 Region& region = newOp.comparator();
7232 rewriter.inlineRegionBefore(op.comparator(), region, region.end());
7233 rewriter.replaceOp(op, newOp.getResults());
7234
7235 return success();
7236 }
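// E.g. (sketch, shapes assumed): for results of type tensor<4x5xf32> with
// dimension = -1, the canonical form sorts along dimension 1, the last
// dimension.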
7237
7238 void SortOp::getCanonicalizationPatterns(RewritePatternSet& results,
7239 MLIRContext* /*context*/) {
7240 results.add(sortDropEmptyUseArgs);
7241 results.add(sortOpInferDefaultDimension);
7242 }
7243
7244 //===----------------------------------------------------------------------===//
7245 // TransposeOp
7246 //===----------------------------------------------------------------------===//
7247
7248 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
7249 if (auto elements = operands.front().dyn_cast_or_null<SplatElementsAttr>()) {
7250 return reshape(elements, getResult().getType().cast<ShapedType>());
7251 }
7252 for (const auto& it : llvm::enumerate(permutation().getValues<APInt>())) {
7253 if (it.index() != it.value()) {
7254 return {};
7255 }
7256 }
7257 return getOperand();
7258 }
7259
7260 // transpose(transpose(X)) => transpose(X)
7261 static LogicalResult eliminateRedundantTranspose(TransposeOp op,
7262 PatternRewriter& rewriter) {
7263 auto transposeOperand = op.operand().getDefiningOp<TransposeOp>();
7264 if (!transposeOperand) {
7265 return failure();
7266 }
7267 auto operandPermutation = transposeOperand.permutation().getValues<APInt>();
7268 auto newPermutation =
7269 op.permutation()
7270 .mapValues(op.permutation().getElementType(),
7271 [&operandPermutation](const APInt& index) -> APInt {
7272 return operandPermutation[index.getSExtValue()];
7273 })
7274 .cast<DenseIntElementsAttr>();
7275 rewriter.replaceOpWithNewOp<TransposeOp>(
7276 op, op.getResult().getType(), transposeOperand.operand(), newPermutation);
7277 return success();
7278 }
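// Worked sketch of the composition (permutations assumed for illustration):
//   transpose(transpose(%x, perm = [1, 2, 0]), perm = [2, 0, 1])
// composes as newPerm[i] = innerPerm[outerPerm[i]], giving perm = [0, 1, 2];
// the resulting identity transpose is then folded away by TransposeOp::fold.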
7279
7280 // transpose(broadcast_in_dim(X)) => broadcast_in_dim(X)
7281 static LogicalResult eliminateBroadcastInDimTranspose(
7282 TransposeOp op, PatternRewriter& rewriter) {
7283 auto broadcastInDimOp = op.operand().getDefiningOp<BroadcastInDimOp>();
7284 if (!broadcastInDimOp) {
7285 return failure();
7286 }
7287 DenseIntElementsAttr broadcastDimensions =
7288 broadcastInDimOp.broadcast_dimensions();
7289 DenseIntElementsAttr permutation = op.permutation();
7290 SmallVector<int64_t> newBroadcastDimensions;
7291 for (auto dimension : broadcastDimensions.getValues<int64_t>()) {
7292 int64_t index = 0;
7293 for (auto p : permutation.getValues<int64_t>()) {
7294 if (p == dimension) {
7295 newBroadcastDimensions.push_back(index);
7296 break;
7297 }
7298 index++;
7299 }
7300 }
7301 rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
7302 op, op->getResultTypes(), broadcastInDimOp.operand(),
7303 rewriter.getI64TensorAttr(newBroadcastDimensions));
7304 return success();
7305 }
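// Sketch (shapes assumed): with
//   %b = broadcast_in_dim(%x : tensor<2xf32>) {broadcast_dimensions = [0]}
//        : tensor<2x3xf32>
//   transpose(%b) {permutation = [1, 0]} : tensor<3x2xf32>
// the pair collapses to broadcast_in_dim(%x) {broadcast_dimensions = [1]}
// : tensor<3x2xf32>, since dimension 0 of %b moves to position 1.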
7306
7307 // Simplify Transpose: replace Transpose with Reshape if they are equivalent.
7308 static LogicalResult simplifyTranspose(TransposeOp op,
7309 PatternRewriter& rewriter) {
7310 auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
7311 auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
7312 if (!operandType || !resultType) {
7313 return failure();
7314 }
7315 // Dynamic shapes are not supported at the moment. When the shape is
7316 // dynamic, Transpose should perhaps be replaced by DynamicReshape instead.
7317 if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) {
7318 return failure();
7319 }
7320 auto permutation = op.permutation().getValues<int64_t>();
7321 llvm::SmallVector<int64_t> sortedPermutation;
7322 for (int64_t i = 0, e = resultType.getRank(); i < e; i++) {
7323 if (resultType.getDimSize(i) != 1) {
7324 sortedPermutation.push_back(permutation[i]);
7325 }
7326 }
7327 if (llvm::is_sorted(sortedPermutation)) {
7328 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
7329 return success();
7330 }
7331 return failure();
7332 }
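// Sketch (shapes assumed): transpose(%x : tensor<1x4xf32>) {permutation =
// [1, 0]} : tensor<4x1xf32> only moves a unit dimension, so the permutation
// entries of the non-unit dimensions are already sorted and the op is
// replaced by reshape(%x) : tensor<4x1xf32>.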
7333
7334 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet& results,
7335 MLIRContext* /*context*/) {
7336 results.add(eliminateRedundantTranspose);
7337 results.add(eliminateBroadcastInDimTranspose);
7338 results.add(simplifyTranspose);
7339 }
7340
7341 LogicalResult TransposeOp::reifyReturnTypeShapes(
7342 OpBuilder& builder, ValueRange operands,
7343 SmallVectorImpl<Value>& reifiedReturnShapes) {
7344 TransposeOp::Adaptor adaptor(operands);
7345 Value operand = adaptor.operand();
7346
7347 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
7348 // Unranked types are not supported at the moment.
7349 if (!operandType) return failure();
7350
7351 Location loc = this->getLoc();
7352 SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
7353 SmallVector<Value, 4> shapeValues(permutation.size());
7354
7355 Type shapeScalarType = builder.getIndexType();
7356 auto toShapeScalarType = [&](Value v) {
7357 return maybeCastTo(builder, loc, v, shapeScalarType);
7358 };
7359
7360 for (const auto& element : llvm::enumerate(operandType.getShape())) {
7361 int64_t idx = element.index();
7362 auto* it = std::find(permutation.begin(), permutation.end(), idx);
7363 Value valueDim = toShapeScalarType(
7364 builder.createOrFold<tensor::DimOp>(loc, operand, element.index()));
7365 shapeValues[std::distance(permutation.begin(), it)] = valueDim;
7366 }
7367
7368 Value outputShape = builder.create<tensor::FromElementsOp>(
7369 loc,
7370 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
7371 shapeScalarType),
7372 shapeValues);
7373 reifiedReturnShapes.push_back(outputShape);
7374
7375 return success();
7376 }
7377
7378 // Method for InferTypeOpInterface: infer the return type from the operand type
7379 // and the permutation.
7380 LogicalResult TransposeOp::inferReturnTypes(
7381 MLIRContext* /*context*/, Optional<Location> loc, ValueRange operands,
7382 DictionaryAttr attributes, RegionRange,
7383 SmallVectorImpl<Type>& inferredReturnTypes) {
7384 auto type = operands[0].getType();
7385 auto rankedTy = type.dyn_cast<RankedTensorType>();
7386 if (!rankedTy) {
7387 auto shapedTy = type.dyn_cast<ShapedType>();
7388 inferredReturnTypes.emplace_back(shapedTy);
7389 return success();
7390 }
7391 auto permutation = attributes.getAs<DenseIntElementsAttr>("permutation");
7392 int64_t rank = rankedTy.getRank();
7393 if (permutation.getType().getRank() != 1)
7394 return emitOptionalError(loc, "TransposeOp permutation has rank ",
7395 permutation.getType().getRank(),
7396 " instead of rank 1");
7397
7398 if (permutation.size() != rank)
7399 return emitOptionalError(loc, "TransposeOp operand rank ", rank,
7400 " does not match permutation size ",
7401 permutation.size());
7402
7403 std::vector<int64_t> range(rank);
7404 std::iota(range.begin(), range.end(), 0);
7405 if (!std::is_permutation(range.begin(), range.end(), permutation.begin()))
7406 return emitOptionalError(loc,
7407 "attribute permutation must be a permutation"
7408 " of [",
7409 range, "] but got ", permutation);
7410
7411 SmallVector<int64_t> resultShape;
7412 ArrayRef<int64_t> inputShape = rankedTy.getShape();
7413 for (int64_t dim : permutation.getValues<int64_t>()) {
7414 resultShape.push_back(inputShape[dim]);
7415 }
7416 inferredReturnTypes.emplace_back(RankedTensorType::get(
7417 resultShape, rankedTy.getElementType(), rankedTy.getEncoding()));
7418 return success();
7419 }
7420
7421 //===----------------------------------------------------------------------===//
7422 // TriangularSolveOp
7423 //===----------------------------------------------------------------------===//
7424
7425 LogicalResult TriangularSolveOp::verify() {
7426 auto aType = a().getType().dyn_cast<RankedTensorType>();
7427
7428 // Skip verifier if a is unranked tensor.
7429 if (!aType) return success();
7430
7431 // Check that a should have rank >= 2
7432 auto aRank = aType.getRank();
7433 if (aRank < 2)
7434 return emitOpError() << "operand 'a' must have rank >= 2, but got "
7435 << aType;
7436
7437 // The two minor dimensions of a must have same size.
7438 if (aType.getDimSize(aRank - 2) != aType.getDimSize(aRank - 1))
7439 return emitOpError() << "two minor dimensions of operand 'a' must have "
7440 "equal size, but got "
7441 << aType;
7442
7443 auto bType = b().getType().dyn_cast<RankedTensorType>();
7444 // If b is unranked skip remaining checks.
7445 if (!bType) return success();
7446
7447 // Check that a and b have same rank.
7448 auto bRank = bType.getRank();
7449 if (aRank != bRank)
7450 return emitOpError() << "operands must have equal rank, but got " << aType
7451 << " and " << bType;
7452
7453 // The shared dimension of a and b should match.
7454 if (aType.getDimSize(aRank - 1) !=
7455 bType.getDimSize(bRank - (left_side() ? 2 : 1)))
7456 return emitOpError() << "shared dimension of operands 'a' and 'b' does "
7457 "not match, but got "
7458 << aType << " and " << bType;
7459
7460 // The leading batch dimensions of a and b must be equal.
7461 auto aBatchDims = aType.getShape().drop_back(2);
7462 auto bBatchDims = bType.getShape().drop_back(2);
7463 if (aBatchDims != bBatchDims)
7464 return emitOpError()
7465 << "leading batch dimensions of the operands must be same, but got "
7466 << aType << " and " << bType;
7467
7468 // Result and argument b must have same shape.
7469 auto resultType = getType().dyn_cast<RankedTensorType>();
7470 if (!resultType) return success();
7471 if (resultType != bType)
7472 return emitOpError()
7473 << "result and operand 'b' must have same shape, but got "
7474 << resultType << " and " << bType;
7475 return success();
7476 }
7477
7478 //===----------------------------------------------------------------------===//
7479 // GetTupleElementOp
7480 //===----------------------------------------------------------------------===//
7481
7482 LogicalResult GetTupleElementOp::inferReturnTypes(
7483 MLIRContext*, Optional<Location>, ValueRange operands,
7484 DictionaryAttr attributes, RegionRange,
7485 SmallVectorImpl<Type>& inferredReturnTypes) {
7486 auto tupleType = operands[0].getType().dyn_cast<TupleType>();
7487 if (!tupleType) return failure();
7488
7489 auto indexAttr = attributes.get("index").cast<IntegerAttr>();
7490 auto index = indexAttr.getInt();
7491 if (index < 0 || index >= static_cast<int64_t>(tupleType.size()))
7492 return failure();
7493
7494 inferredReturnTypes.push_back(tupleType.getType(index));
7495 return success();
7496 }
7497
7498 //===----------------------------------------------------------------------===//
7499 // TupleOp
7500 //===----------------------------------------------------------------------===//
7501
7502 LogicalResult TupleOp::inferReturnTypes(
7503 MLIRContext* context, Optional<Location>, ValueRange operands,
7504 DictionaryAttr attributes, RegionRange,
7505 SmallVectorImpl<Type>& inferredReturnTypes) {
7506 inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands)));
7507 return success();
7508 }
7509
7510 //===----------------------------------------------------------------------===//
7511 // UnaryEinsumOp
7512 //===----------------------------------------------------------------------===//
7513
7514 void UnaryEinsumOp::getCanonicalizationPatterns(RewritePatternSet& results,
7515 MLIRContext* context) {
7516 results.add<UnaryEinsumToEinsum>(context);
7517 }
7518
7519 //===----------------------------------------------------------------------===//
7520 // CompareOp
7521 //===----------------------------------------------------------------------===//
7522
7523 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
7524 Value rhs, ComparisonDirection comparisonDirection,
7525 ComparisonType compareType) {
7526 build(builder, result, lhs, rhs,
7527 ComparisonDirectionAttr::get(builder.getContext(), comparisonDirection),
7528 ComparisonTypeAttr::get(builder.getContext(), compareType));
7529 }
7530
7531 LogicalResult CompareOp::inferReturnTypeComponents(
7532 mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>,
7533 ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange,
7534 llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) {
7535 ShapedTypeComponents& components =
7536 inferredReturnTypes.emplace_back(IntegerType::get(ctx, /*width=*/1));
7537 auto argTy = operands.front().getType().cast<TensorType>();
7538 if (argTy.hasRank()) {
7539 components =
7540 ShapedTypeComponents(argTy.getShape(), components.getElementType());
7541 }
7542 return success();
7543 }
7544
7545 LogicalResult CompareOp::reifyReturnTypeShapes(
7546 OpBuilder& builder, ValueRange operands,
7547 SmallVectorImpl<Value>& reifiedReturnShapes) {
7548 return deriveShapeFromOperand(&builder, getOperation(), operands.front(),
7549 &reifiedReturnShapes);
7550 }
7551
7552 template <typename T>
7553 struct Less : std::less<T> {};
7554
7555 template <>
7556 struct Less<APInt> {
7557 bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
7558 };
7559
7560 template <typename T>
7561 struct LessEqual : std::less_equal<T> {};
7562
7563 template <>
7564 struct LessEqual<APInt> {
7565 bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
7566 };
7567
7568 template <typename T>
7569 struct Greater : std::greater<T> {};
7570
7571 template <>
7572 struct Greater<APInt> {
7573 bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
7574 };
7575
7576 template <typename T>
7577 struct GreaterEqual : std::greater_equal<T> {};
7578
7579 template <>
7580 struct GreaterEqual<APInt> {
7581 bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
7582 };
7583
7584 template <typename Op, typename ElementType, typename SrcType, typename Convert>
7585 static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
7586 if (!attrs[0] || !attrs[1]) return {};
7587
7588 DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
7589 DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
7590 if (!lhs || !rhs) return {};
7591
7592 ShapedType operandType =
7593 op.getOperand(0).getType().template cast<ShapedType>();
7594 if (!operandType.hasStaticShape()) {
7595 return {};
7596 }
7597
7598 if (!operandType.getElementType().isa<ElementType>()) {
7599 return {};
7600 }
7601
7602 // Prevent folding if the result is too large.
7603 if (lhs.getNumElements() > kFoldOpEltLimit) return {};
7604
7605 SmallVector<bool, 6> values;
7606 values.reserve(lhs.getNumElements());
7607 for (const auto zip :
7608 llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
7609 values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
7610 }
7611
7612 auto resultTy = op.getType().cast<ShapedType>();
7613 return DenseElementsAttr::get(resultTy, values);
7614 }
7615
7616 OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
7617 auto resultTy = getType().cast<ShapedType>();
7618 if (!resultTy.hasStaticShape()) return {};
7619
7620 auto direction = comparison_direction();
7621 auto lhsTy = getElementTypeOrSelf(lhs());
7622 if (lhs() == rhs() && !lhsTy.isa<FloatType>() &&
7623 (!lhsTy.isa<ComplexType>() ||
7624 !lhsTy.cast<ComplexType>().getElementType().isa<FloatType>())) {
7625 if (direction == ComparisonDirection::LE ||
7626 direction == ComparisonDirection::EQ ||
7627 direction == ComparisonDirection::GE) {
7628 return DenseIntElementsAttr::get(resultTy, {true});
7629 }
7630 return DenseIntElementsAttr::get(resultTy, {false});
7631 }
7632
7633 auto opElType = lhs().getType().cast<ShapedType>().getElementType();
7634 // Fold tensor<*xi1> != false to just return tensor<*xi1>
7635 if (direction == ComparisonDirection::NE && opElType.isInteger(1)) {
7636 DenseIntElementsAttr cstAttr;
7637 if (matchPattern(lhs(), m_Constant(&cstAttr))) {
7638 if (cstAttr.isSplat() && !cstAttr.getSplatValue<bool>()) {
7639 return rhs();
7640 }
7641 }
7642
7643 if (matchPattern(rhs(), m_Constant(&cstAttr))) {
7644 if (cstAttr.isSplat() && !cstAttr.getSplatValue<bool>()) {
7645 return lhs();
7646 }
7647 }
7648 }
7649
7650 // Fold tensor<*xi1> == true to just return tensor<*xi1>
7651 if (direction == ComparisonDirection::EQ && opElType.isInteger(1)) {
7652 DenseIntElementsAttr cstAttr;
7653 if (matchPattern(lhs(), m_Constant(&cstAttr))) {
7654 if (cstAttr.isSplat() && cstAttr.getSplatValue<bool>()) {
7655 return rhs();
7656 }
7657 }
7658
7659 if (matchPattern(rhs(), m_Constant(&cstAttr))) {
7660 if (cstAttr.isSplat() && cstAttr.getSplatValue<bool>()) {
7661 return lhs();
7662 }
7663 }
7664 }
7665
7666 if (!operands[0] || !operands[1]) {
7667 return {};
7668 }
7669
7670 #define COMPARE_FOLDER(Op, comparison, Func) \
7671 if (direction == comparison) { \
7672 if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
7673 *this, operands)) \
7674 return folded; \
7675 if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>( \
7676 *this, operands)) \
7677 return folded; \
7678 }
7679
7680 COMPARE_FOLDER(CompareOp, ComparisonDirection::EQ, std::equal_to);
7681 COMPARE_FOLDER(CompareOp, ComparisonDirection::NE, std::not_equal_to);
7682 COMPARE_FOLDER(CompareOp, ComparisonDirection::LT, Less);
7683 COMPARE_FOLDER(CompareOp, ComparisonDirection::LE, LessEqual);
7684 COMPARE_FOLDER(CompareOp, ComparisonDirection::GT, Greater);
7685 COMPARE_FOLDER(CompareOp, ComparisonDirection::GE, GreaterEqual);
7686 #undef COMPARE_FOLDER
7687
7688 return {};
7689 }
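// Illustrative folds handled above (schematic): compare(%x, %x) {GE} on
// integer operands folds to a splat of true; compare(%p : tensor<2xi1>,
// dense<false>) {NE} folds to %p; and fully constant operands are folded
// elementwise through CompareFolder.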
7690
7691 //===----------------------------------------------------------------------===//
7692 // SelectAndScatterOp
7693 //===----------------------------------------------------------------------===//
7694
7695 namespace {
7696 // Infer the return-type of SelectAndScatterOp.
7697 TensorType inferSelectAndScatterOpReturnType(
7698 TensorType operandType, const ArrayRef<WindowDimension> window) {
7699 if (!operandType.hasRank())
7700 return UnrankedTensorType::get(operandType.getElementType());
7701
7702 return RankedTensorType::get(
7703 inferWindowOutputShape(operandType.getShape(), window),
7704 operandType.getElementType());
7705 }
7706 } // namespace
7707
7708 // We intend to verify the following properties:
7709 // P1. Check if the select function has a proper shape of (T,T) -> PRED, where
7710 // T is a 0-D tensor with element-type same as 'operand' element-type.
7711 // P2. Verify scatter-computation type.
7712 // P3. size-of(window_dimension) == rank-of(input),
7713 // where input is an element of 'inputs'.
7714 // P4. Verify and collect the window attributes.
7715 // P5. Verify the return type matches the operand-type.
7716 // P6. Check if the result type of window operation matches the source type.
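// Schematic example of a well-formed op (shapes assumed for illustration):
// for %operand : tensor<4x4xf32> with window_dimensions = [2, 2] and
// window_strides = [2, 2], the select region must have type
// (tensor<f32>, tensor<f32>) -> tensor<i1> (P1), the inferred window result
// type is tensor<2x2xf32>, and the source type must match it (P6).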
7717 LogicalResult SelectAndScatterOp::verify() {
7718 auto operandType = operand().getType().cast<TensorType>();
7719 auto initValueType = init_value().getType().cast<TensorType>();
7720 auto sourceType = source().getType().cast<TensorType>();
7721 auto resultType = getResult().getType().cast<TensorType>();
7722
7723 // P1.
7724 Block& selectBlock = select().front();
7725
7726 if (selectBlock.getArguments().size() != 2)
7727 return emitOpError()
7728 << "expects the select-region to take 2 parameters, but takes "
7729 << selectBlock.getArguments().size();
7730
7731 Type expectedSelectArgType =
7732 RankedTensorType::get({}, operandType.getElementType());
7733 for (const auto& selectArgIt : llvm::enumerate(selectBlock.getArguments()))
7734 if (!compatibleShapeAndElementType(expectedSelectArgType,
7735 selectArgIt.value().getType(),
7736 /*ignoreFpPrecision=*/true))
7737 return emitOpError()
7738 << "expects the type of select-region's parameter at index "
7739 << selectArgIt.index() << " to be " << expectedSelectArgType
7740 << ", but got " << selectArgIt.value().getType();
7741
7742 auto selectResult = selectBlock.getTerminator()->getOperands();
7743 if (selectResult.size() != 1)
7744 return emitOpError()
7745 << "expects select-region to return single value, but got: "
7746 << selectResult.size();
7747
7748 auto selectResultType = selectResult[0].getType().dyn_cast<TensorType>();
7749 if (!selectResultType || !selectResultType.getElementType().isInteger(1) ||
7750 (selectResultType.hasRank() &&
7751 selectResultType.cast<RankedTensorType>().getRank() != 0))
7752 return emitOpError() << "expects the return-type of select-region to be "
7753 "tensor<i1>, but got: "
7754 << selectResult[0].getType();
7755
7756 // P2.
7757 Block& scatterBlock = scatter().front();
7758 SmallVector<TensorType> accumulatorSubshapes;
7759 if (failed(verifyReducerShape(
7760 this->getLoc(), scatterBlock,
7761 {RankedTensorType::get({}, sourceType.getElementType())},
7762 {initValueType},
7763 /*numInputs=*/1, /*allowedDimensions=*/{},
7764 /*allInputsUnranked=*/false, accumulatorSubshapes)))
7765 return failure();
7766
7767 // P3.
7768 SmallVector<int64_t> windowDims =
7769 convertDenseIntAttr(this->window_dimensions());
7770 if (operandType.hasRank()) {
7771 if (operandType.getRank() != static_cast<int64_t>(windowDims.size()))
7772 return emitOpError()
7773 << "expects window-dimensions size == operand rank, but got "
7774 "window-dimensions size: "
7775 << windowDims.size() << " and operand-type: " << operandType
7776 << " with rank = " << operandType.getRank() << ".";
7777 }
7778
7779 // P4.
7780 auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
7781 if (failed(paddingOrErr)) return failure();
7782 SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;
7783
7784 auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
7785 windowDims, convertDenseIntAttr(window_strides()), padding,
7786 /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, getLoc());
7787 if (failed(windowOrErr)) return failure();
7788
7789 // P5.
7790 if (!compatibleShapeAndElementType(operandType, resultType))
7791 return emitOpError()
7792 << "expects the return-type to match the operand-type, but got "
7793 << resultType << " and " << operandType << " resp.";
7794
7795 // P6.
7796 auto windowResultType =
7797 inferSelectAndScatterOpReturnType(operandType, *windowOrErr);
7798
7799 if (!compatibleShapeAndElementType(windowResultType, sourceType,
7800 /*ignoreFpPrecision=*/true))
7801 return emitOpError() << "expects source-type to be " << windowResultType
7802 << ", but got" << sourceType;
7803
7804 return success();
7805 }
7806
7807 //===----------------------------------------------------------------------===//
7808 // ScatterOp
7809 //===----------------------------------------------------------------------===//
7810
7811 /*
7812 * We intend to verify the following properties:
7813 * P1. The 'update_window_dims' must be valid indices of 'updates' tensor.
7814 * P2. The 'inserted_window_dims' must be valid indices of 'operand' tensor.
7815 * P3. Check if the rank-of('operand') == size-of('update_window_dims') +
7816 * size-of('inserted_window_dims')
7817 * P4. size-of('scatter_dims_to_operand_dims') =
7818 * 'scatter_indices'['index_vector_dim'] &
7819 * 'scatter_dims_to_operand_dims' must be valid indices of 'operand' tensor.
7820 */
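/*
 * Worked sketch satisfying P1-P4 (shapes assumed for illustration): with
 * operand : tensor<5x4xf32>, scatter_indices : tensor<2x1xi64>,
 * index_vector_dim = 1, and updates : tensor<2x4xf32>, a consistent choice
 * is update_window_dims = [1], inserted_window_dims = [0], and
 * scatter_dims_to_operand_dims = [0]; then rank(operand) = 2 equals
 * size(update_window_dims) + size(inserted_window_dims) (P3), and
 * size(scatter_dims_to_operand_dims) = 1 matches the bound of dimension
 * index_vector_dim of scatter_indices (P4).
 */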
7821 LogicalResult validateScatterDimensionNumbers(
7822 ShapedType operandType, ArrayRef<int64_t> scatterIndicesShape,
7823 ShapedType updateType, bool operandTypeRanked,
7824 bool scatterIndicesTypeRanked, bool updatesTypeRanked,
7825 ScatterDimensionNumbersAttr dimNumbers, Location loc) {
7826 // P1.
7827 auto updateWindowDims = to_vector(dimNumbers.getUpdateWindowDims());
7828 if (!llvm::is_sorted(updateWindowDims))
7829 return mlir::emitError(loc)
7830 << "Expects update_window_dims to be sorted; got: ["
7831 << updateWindowDims << "].";
7832
7833 if (hasDuplicates(updateWindowDims))
7834 return mlir::emitError(loc)
7835 << "Expects update_window_dims to not repeat; got: ["
7836 << updateWindowDims << "].";
7837
7838 if (updatesTypeRanked) {
7839 for (int64_t windowDim : updateWindowDims) {
7840 if (windowDim < 0 || windowDim >= updateType.getRank()) {
7841 return mlir::emitError(loc)
7842 << "Expects each element of update_window_dims to be in range "
7843 "[0, "
7844 "rank-of('updates') i.e. [0, "
7845 << updateType.getRank() << "). got: " << windowDim << ".";
7846 }
7847 }
7848 }
7849
7850 // P2.
7851 auto insertedWindowDims = to_vector(dimNumbers.getInsertedWindowDims());
7852 if (!llvm::is_sorted(insertedWindowDims))
7853 return mlir::emitError(loc)
7854 << "Expects inserted_window_dims to be sorted; got: ["
7855 << insertedWindowDims << "].";
7856
7857 if (hasDuplicates(insertedWindowDims))
7858 return mlir::emitError(loc)
7859 << "Expects inserted_window_dims to not repeat; got: ["
7860 << insertedWindowDims << "].";
7861
7862 if (operandTypeRanked) {
7863 for (int64_t insertedDim : insertedWindowDims) {
7864 if (insertedDim < 0 || insertedDim >= operandType.getRank()) {
7865 return mlir::emitError(loc)
7866 << "Expects each element of inserted_window_dims to be in range "
7867 "[0, rank-of('operand') i.e. [0, "
7868 << operandType.getRank() << "). got: " << insertedDim << ".";
7869 }
7870 }
7871 }
7872
7873 // P3.
7874 if (operandTypeRanked) {
7875 auto windowSize = updateWindowDims.size() + insertedWindowDims.size();
7876 if (operandType.getRank() != static_cast<int64_t>(windowSize))
7877 return mlir::emitError(loc)
7878 << "Expects rank-of operand to match "
7879 "size-of('update_window_dims') + "
7880 "size-of('inserted_window_dims') i.e. "
7881 << windowSize << " but got " << operandType.getRank() << ".";
7882 }
7883
7884 // P4.
7885 auto scatterDimsToOperandDims =
7886 to_vector(dimNumbers.getScatterDimsToOperandDims());
7887 auto indexVectorDim = dimNumbers.getIndexVectorDim();
7888 if (scatterIndicesTypeRanked) {
7889 if (!isDynamicDimSize(scatterIndicesShape[indexVectorDim]) &&
7890 static_cast<int64_t>(scatterDimsToOperandDims.size()) !=
7891 scatterIndicesShape[dimNumbers.getIndexVectorDim()])
7892 return mlir::emitError(loc)
7893 << "Scatter op has " << scatterDimsToOperandDims.size()
7894 << " elements in scatter_dims_to_operand_dims and the bound of "
7895 "dimension index_vector_dim="
7896 << dimNumbers.getIndexVectorDim() << " of scatter_indices is "
7897 << scatterIndicesShape[dimNumbers.getIndexVectorDim()]
7898 << ". These two numbers must be equal.";
7899 }
7900
7901 if (operandTypeRanked) {
7902 for (int64_t i = 0;
7903 i < static_cast<int64_t>(scatterDimsToOperandDims.size()); ++i) {
7904 int64_t scatterDimToOperandDim = scatterDimsToOperandDims[i];
7905 if (scatterDimToOperandDim < 0 ||
7906 scatterDimToOperandDim >= operandType.getRank())
7907 return mlir::emitError(loc)
7908 << "Invalid scatter_dims_to_operand_dims mapping; domain is [0, "
7909 << operandType.getRank() << "), got: " << i << "->"
7910 << scatterDimToOperandDim << ".";
7911 }
7912 }
7913
7914 if (hasDuplicates(scatterDimsToOperandDims))
7915 return mlir::emitError(loc)
7916 << "Expects scatter_dims_to_operand_dims to not repeat; got: ["
7917 << scatterDimsToOperandDims << "].";
7918
7919 return success();
7920 }
7921 /*
7922 * We intend to verify the following properties:
7923 * P0. scatter_indices argument must be an integral tensor. Enforced by ODS.
7924 * P1. Scatter index leaf dimension must be within [0,
7925 * rank(scatter_indices) + 1).
7926 * P2. Verify reducer shape.
7927 * P3. rank-of('updates[i]') == size-of('update_window_dims') +
7928 * rank-of('scatter_indices') - 1, where 'scatter_indices' is expanded by a
7929 * trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices')
7930 * for all values of `i`.
7931 * P4. Validate the scatter dimension numbers.
7932 * P5. Validate the bounds of each of the 'updates' w.r.t the operands.
7933 * P6. Validate the bounds of each of the 'updates' w.r.t the
7934 * 'scatter_indices'.
7935 * P7. Check return types.
7936 */
7937 LogicalResult ScatterOp::verify() {
7938 // Get the first operand and update, since variadic Scatter is not yet
7939 // implemented
7940 auto numOperands = operands().size();
7941 auto scatterIndicesType = scatter_indices().getType().dyn_cast<TensorType>();
7942
7943 SmallVector<TensorType, 1> operandTypes =
7944 llvm::to_vector(llvm::map_range(operands().getTypes(), [](Type type) {
7945 return type.cast<TensorType>();
7946 }));
7947 SmallVector<TensorType, 1> updatesTypes = llvm::to_vector(llvm::map_range(
7948 updates().getTypes(), [](Type type) { return type.cast<TensorType>(); }));
7949 bool allOperandTypesRanked =
7950 llvm::all_of(operands().getTypes(),
7951 [](Type type) { return type.isa<RankedTensorType>(); });
7952 bool scatterIndicesTypeRanked = scatterIndicesType.isa<RankedTensorType>();
7953
7954 // P1.
7955 int64_t indexVectorDim = scatter_dimension_numbers().getIndexVectorDim();
7956 if (scatterIndicesTypeRanked) {
7957 if (indexVectorDim > scatterIndicesType.getRank() || indexVectorDim < 0)
7958 return emitOpError()
7959 << "expects scatter index leaf dimension to be within [0, "
7960 "rank(scatter_indices) + 1."
7961 " rank(scatter_indices) is "
7962 << scatterIndicesType.getRank()
7963 << " and scatter index leaf dimension is " << indexVectorDim
7964 << ".";
7965 }
7966
7967 // P2.
7968 Block& block = update_computation().front();
7969 SmallVector<TensorType> accumulatorSubshapes;
7970 SmallVector<TensorType> inputTypes, initValueTypes;
7971 for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
7972 inputTypes.push_back(operandTypes[i]);
7973 initValueTypes.push_back(
7974 RankedTensorType::get({}, updatesTypes[i].getElementType()));
7975 }
7976 if (failed(verifyReducerShape(
7977 this->getLoc(), block, inputTypes, initValueTypes, numOperands,
7978 /*allowedDimensions=*/{},
7979 /*allInputsUnranked=*/!allOperandTypesRanked, accumulatorSubshapes)))
7980 return failure();
7981
7982 // P3.
7983 auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims();
7984 SmallVector<int64_t> expandedScatterIndicesShape;
7985 if (scatterIndicesTypeRanked) {
7986 expandedScatterIndicesShape =
7987 llvm::to_vector(scatterIndicesType.getShape());
7988 if (static_cast<int64_t>(expandedScatterIndicesShape.size()) ==
7989 indexVectorDim)
7990 expandedScatterIndicesShape.push_back(1);
7991 }
7992
7993 for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
7994 if (scatterIndicesTypeRanked && updatesTypes[i].isa<RankedTensorType>()) {
7995 int64_t expectedUpdatesRank =
7996 expandedScatterIndicesShape.size() - 1 + updateWindowDims.size();
7997 if (updatesTypes[i].getRank() != expectedUpdatesRank)
7998 return emitOpError()
7999 << "expects updates tensor must be of rank "
8000 << expectedUpdatesRank
8001 << " ( == rank-of('scatter_indices') - 1 + "
8002 "size-of('update_window_dims'), where 'scatter_indices' is "
8003 "expanded by a trailing 1 dimension if 'index_vector_dim' == "
8004 "rank-of('scatter_indices')), but got "
8005 << updatesTypes[i].getRank() << ".";
8006 }
8007 }
8008
8009 // P4.
8010 for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
8011 if (failed(validateScatterDimensionNumbers(
8012 operandTypes[i], expandedScatterIndicesShape, updatesTypes[i],
8013 operandTypes[i].isa<RankedTensorType>(), scatterIndicesTypeRanked,
8014 updatesTypes[i].isa<RankedTensorType>(),
8015 scatter_dimension_numbers(), getLoc())))
8016 return failure();
8017 }
8018
8019 // P5.
8020 for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
8021 if (updatesTypes[i].isa<RankedTensorType>()) {
8022 auto updatesShape = updatesTypes[i].getShape();
8023 if (operandTypes[i].isa<RankedTensorType>()) {
8024 auto operandShape = operandTypes[i].getShape();
8025 auto insertedWindowDims =
8026 scatter_dimension_numbers().getInsertedWindowDims();
8027
8028 int64_t insertedDimsSeen = 0;
8029 SmallVector<int64_t> maxUpdateSliceSizes;
8030 const auto dimensionsSize = operandTypes[i].getRank();
8031 maxUpdateSliceSizes.reserve(dimensionsSize);
8032 for (int i = 0; i < dimensionsSize; ++i) {
8033 if (insertedDimsSeen <
8034 static_cast<int64_t>(insertedWindowDims.size()) &&
8035 insertedWindowDims[insertedDimsSeen] == i) {
8036 ++insertedDimsSeen;
8037 } else {
8038 maxUpdateSliceSizes.push_back(operandShape[i]);
8039 }
8040 }
8041
8042 for (int64_t i = 0; i < static_cast<int64_t>(updateWindowDims.size());
8043 ++i) {
8044 auto updateWindowDim = updateWindowDims[i];
8045
8046 if (isDynamicDimSize(updatesShape[updateWindowDim]) ||
8047 isDynamicDimSize(maxUpdateSliceSizes[i]))
8048 continue;
8049
8050 if (updatesShape[updateWindowDim] > maxUpdateSliceSizes[i]) {
8051 return emitOpError()
8052 << "expects bounds of the window dimensions of "
8053 "updates to not exceed the "
8054 "bounds of the corresponding dimensions of "
8055 "operand. For dimension "
8056 << updateWindowDim << ", updates bound is "
8057 << updatesShape[updateWindowDim] << ", operand bound is "
8058 << maxUpdateSliceSizes[i] << ".";
8059 }
8060 }
8061 }
8062
8063 // P6.
8064 if (scatterIndicesTypeRanked) {
8065 int64_t scatterDimsSeen = 0;
8066 for (int64_t i = 0; i < static_cast<int64_t>(updatesShape.size());
8067 ++i) {
8068 bool isUpdateWindowDim = std::binary_search(
8069 updateWindowDims.begin(), updateWindowDims.end(), i);
8070
8071 if (isUpdateWindowDim) continue;
8072 if (scatterDimsSeen == indexVectorDim) ++scatterDimsSeen;
8073
8074 if (!isDynamicDimSize(updatesShape[i]) &&
8075 !isDynamicDimSize(expandedScatterIndicesShape[scatterDimsSeen]) &&
8076 (updatesShape[i] !=
8077 expandedScatterIndicesShape[scatterDimsSeen])) {
8078             return emitOpError()
8079                    << "expects bounds of the scatter dimensions of "
8080                       "updates to be the same as the "
8081                       "bounds of the corresponding dimensions of "
8082                       "scatter indices. For "
8083                       "scatter dimension "
8084                    << i << ", updates bound is " << updatesShape[i]
8085                    << ", scatter_indices "
8086                       "bound is "
8087                    << expandedScatterIndicesShape[scatterDimsSeen] << ".";
8088 }
8089 ++scatterDimsSeen;
8090 }
8091 }
8092 }
8093 }
8094
8095 // P7.
8096 for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
8097 if (!compatibleShapeAndElementType(operandTypes[i], getResult(i).getType()))
8098 return emitOpError()
8099              << "expects the return type to be the same as the operand type: "
8100 << operandTypes[i] << ", but got " << getResult(i).getType()
8101 << ".";
8102 }
8103
8104 return success();
8105 }
8106
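// Evaluates a single-block `region` with the given constant `inputs` by
// folding each operation in turn, and returns the operands of the terminating
// ReturnOp as attributes. Returns an empty vector if any operation fails to
// fold to a constant, or if the argument count does not match.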
evaluateMhloRegion(Region & region,ArrayRef<Attribute> inputs)8107 llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
8108 ArrayRef<Attribute> inputs) {
8109 if (region.getNumArguments() != inputs.size()) return {};
8110
8111 llvm::DenseMap<Value, Attribute> values;
8112 values.reserve(region.getNumArguments());
8113 for (auto it : llvm::zip(region.getArguments(), inputs)) {
8114 values.try_emplace(std::get<0>(it), std::get<1>(it));
8115 }
8116
8117 for (auto& op : region.getOps()) {
8118 llvm::SmallVector<Attribute, 4> inputs;
8119 for (auto& operand : op.getOpOperands()) {
8120 inputs.push_back(values.lookup(operand.get()));
8121 }
8122 if (isa<ReturnOp>(op)) return inputs;
8123
8124 llvm::SmallVector<OpFoldResult, 4> results;
8125 if (failed(op.fold(inputs, results))) return {};
8126 for (auto it : llvm::zip(op.getResults(), results)) {
8127 if (!std::get<1>(it).is<Attribute>()) return {};
8128 values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
8129 }
8130 }
8131 return {};
8132 }
8133
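// Folds ScatterOp in two cases: (1) a trivial full overwrite of base by
// updates (matching static shapes, all-zero indices, and an update
// computation consisting of just a return), and (2) full constant folding,
// where base, indices, and updates are all constants and the update
// computation is evaluated element by element (bounded by kFoldOpEltLimit).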
fold(ArrayRef<Attribute> args,llvm::SmallVectorImpl<OpFoldResult> & foldResults)8134 LogicalResult ScatterOp::fold(
8135 ArrayRef<Attribute> args,
8136 llvm::SmallVectorImpl<OpFoldResult>& foldResults) {
8137   // Variadic scatter is not yet implemented.
8138 if (operands().size() != 1 || updates().size() != 1) return failure();
8139 auto index = args[1].dyn_cast_or_null<DenseIntElementsAttr>();
8140 if (!index) return failure();
8141
8142 auto baseType = operands().getTypes()[0].dyn_cast<RankedTensorType>();
8143 auto updateType = updates().getTypes()[0].dyn_cast<RankedTensorType>();
8144 auto indexType = index.getType().cast<RankedTensorType>();
8145 if (!baseType || !indexType || !updateType) return failure();
8146
8147 // TODO(b/228310289): Work around canonicalization crash for complex types.
8148 // Remove after upstream MLIR has been fixed.
8149 if (baseType.getElementType().isa<ComplexType>()) return failure();
8150
8151   // Catch a trivial full replacement of base with update; this does not
8152   // require them to be constant, just that we know the type.
8153 if (updateType == baseType && updateType.hasStaticShape() &&
8154 baseType.hasStaticShape() && index.isSplat() &&
8155 index.getSplatValue<uint32_t>() == 0 &&
8156 llvm::hasSingleElement(update_computation().front())) {
8157 foldResults.push_back(updates()[0]);
8158 return success();
8159 }
8160 auto base = args[0].dyn_cast_or_null<DenseElementsAttr>();
8161 auto update = args[2].dyn_cast_or_null<DenseElementsAttr>();
8162 if (!base || !update) return failure();
8163
8164   // Add the virtual trailing dimension of size 1 if indexVectorDim equals
8165   // indexType.rank.
8166 const int64_t indexVectorDim =
8167 scatter_dimension_numbers().getIndexVectorDim();
8168 if (indexVectorDim == indexType.getRank()) {
8169 auto indexShape = indexType.getShape().vec();
8170 indexShape.push_back(1);
8171 indexType = RankedTensorType::get(indexShape, indexType.getElementType());
8172 index = reshape(index, indexType).cast<DenseIntElementsAttr>();
8173 }
8174
8175   // Increments the multi-dimensional index vector based on the limits for
8176   // each dimension specified by shape. Returns false if the index rolled
8177   // over, and true otherwise.
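  // For example, with shape [2, 3] the index advances
  // (0,0) -> (0,1) -> (0,2) -> (1,0) -> (1,1) -> (1,2) and then wraps back to
  // (0,0), returning false.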
8178 auto nextIndex = [](llvm::SmallVector<uint64_t, 8>& index,
8179 llvm::ArrayRef<int64_t> shape) {
8180 for (int64_t i = index.size() - 1; i >= 0; --i) {
8181 ++index[i];
8182       if (index[i] < static_cast<uint64_t>(shape[i])) return true;
8183 index[i] = 0;
8184 }
8185 return false;
8186 };
8187
8188 // Prevent folding if the result is too large.
8189 if (base.getNumElements() > kFoldOpEltLimit) return failure();
8190
8191 // Iterate over all elements of the update tensor, then find the corresponding
8192 // value in the indices tensor to determine which location we have to update
8193 // in the base/result tensor.
8194 llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
8195 llvm::SmallVector<uint64_t, 8> updateIndex(updateType.getRank(), 0);
8196 llvm::SmallVector<uint64_t, 8> indexIndex;
8197 indexIndex.reserve(indexType.getRank());
8198 llvm::SmallVector<int64_t, 8> baseIndex;
8199 baseIndex.reserve(baseType.getRank());
8200 do {
8201 // Compute the index for the slice of the indices tensor for this update
8202 // value.
8203 indexIndex.clear();
8204 if (indexVectorDim == 0) indexIndex.push_back(0);
8205 for (int64_t i = 0; i < static_cast<int64_t>(updateIndex.size()); ++i) {
8206 if (llvm::count(scatter_dimension_numbers().getUpdateWindowDims(), i) ==
8207 0)
8208 indexIndex.push_back(updateIndex[i]);
8209 if (static_cast<int64_t>(indexIndex.size()) == indexVectorDim)
8210 indexIndex.push_back(0);
8211 }
8212
8213 // Compute the index for the given update value in the base tensor.
8214 baseIndex.assign(baseType.getRank(), 0);
8215 uint64_t indexCount = indexType.getShape()[indexVectorDim];
8216 for (uint64_t i = 0; i < indexCount; ++i) {
8217 uint64_t operandDim =
8218 scatter_dimension_numbers().getScatterDimsToOperandDims()[i];
8219 indexIndex[indexVectorDim] = i;
8220 baseIndex[operandDim] +=
8221 index.getValues<APInt>()[indexIndex].getSExtValue();
8222 }
8223 uint64_t updateWindowDimIndex = 0;
8224 auto insertedWindowDims =
8225 scatter_dimension_numbers().getInsertedWindowDims();
8226 auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims();
8227 for (uint64_t i = 0; i < baseIndex.size(); ++i) {
8228 if (llvm::count(insertedWindowDims, i)) continue;
8229 baseIndex[i] += updateIndex[updateWindowDims[updateWindowDimIndex]];
8230 updateWindowDimIndex++;
8231 }
8232
8233 // Compute the linear index for the index into the base tensor.
8234 int64_t linearBaseIndex = 0;
8235     int64_t linearBaseIndexMultiplier = 1;
8236     for (int64_t i = baseIndex.size() - 1; i >= 0; --i) {
8237       // Out-of-bounds indices have backend-specific behavior; avoid folding.
8238       if (baseIndex[i] < 0 || baseIndex[i] >= baseType.getShape()[i])
8239         return failure();
8240       linearBaseIndex += baseIndex[i] * linearBaseIndexMultiplier;
8241       linearBaseIndexMultiplier *= baseType.getShape()[i];
8242 }
8243
8244 // Evaluate update computation and update the value with the newly computed
8245 // attribute in the base tensor.
8246 auto lhs = DenseElementsAttr::get(
8247 RankedTensorType::get({}, baseType.getElementType()),
8248 results[linearBaseIndex]);
8249 auto rhs = DenseElementsAttr::get(
8250 RankedTensorType::get({}, baseType.getElementType()),
8251 update.getValues<Attribute>()[updateIndex]);
8252 auto newValue = evaluateMhloRegion(update_computation(), {lhs, rhs});
8253 if (newValue.size() != 1 || !newValue[0]) return failure();
8254 results[linearBaseIndex] =
8255 newValue[0].cast<DenseElementsAttr>().getValues<Attribute>()[0];
8256 } while (nextIndex(updateIndex, updateType.getShape()));
8257
8258 foldResults.push_back(DenseElementsAttr::get(baseType, results));
8259 return success();
8260 }
8261
8262 //===----------------------------------------------------------------------===//
8263 // WhileOp
8264 //===----------------------------------------------------------------------===//
8265
verify()8266 LogicalResult WhileOp::verify() {
8267 if (getNumOperands() != cond().front().getNumArguments())
8268 return emitOpError() << "mismatch in operand count (" << getNumOperands()
8269 << ") vs the condition block argument count ("
8270 << cond().front().getNumArguments() << ")";
8271 if (getNumOperands() != body().front().getNumArguments())
8272 return emitOpError() << "mismatch in operand count (" << getNumOperands()
8273 << ") vs the body block argument count ("
8274 << body().front().getNumArguments() << ")";
8275 for (const auto& enumeratedOperands : llvm::enumerate(
8276 llvm::zip(getOperandTypes(), cond().front().getArgumentTypes(),
8277 body().front().getArgumentTypes()))) {
8278     int argIndex = enumeratedOperands.index();
8279 const auto& operands = enumeratedOperands.value();
8280 Type operandType = std::get<0>(operands);
8281 Type condType = std::get<1>(operands);
8282 Type bodyType = std::get<2>(operands);
8283 if (operandType != condType)
8284       return emitOpError() << "type mismatch between operand #" << argIndex
8285 << " and the matching condition block argument: "
8286 << operandType << " vs " << condType;
8287 if (operandType != bodyType)
8288       return emitOpError() << "type mismatch between operand #" << argIndex
8289 << " and the matching body block argument: "
8290 << operandType << " vs " << bodyType;
8291 }
8292 // Check the return type for the condition block.
8293 {
8294 auto condReturnOp = cast<ReturnOp>(cond().front().back());
8295 if (condReturnOp->getNumOperands() != 1)
8296 return condReturnOp.emitOpError()
8297 << "expects a single operand for while condition body return, got "
8298 << condReturnOp->getNumOperands();
8299 auto operandType =
8300 condReturnOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
8301 if (!operandType || operandType.getRank() != 0 ||
8302 !operandType.getElementType().isInteger(1))
8303 return condReturnOp.emitOpError()
8304 << "expects a zero-ranked tensor of i1, got "
8305 << condReturnOp->getOperand(0).getType();
8306 }
8307 // Check the return type for the body block.
8308 {
8309 auto bodyReturnOp = cast<ReturnOp>(body().front().back());
8310 if (bodyReturnOp->getNumOperands() != getNumOperands())
8311 return bodyReturnOp.emitOpError()
8312              << "expects body to return as many values as the operands ("
8313 << getNumOperands() << "), got " << bodyReturnOp->getNumOperands();
8314 for (const auto& enumeratedOperandTypes : llvm::enumerate(
8315 llvm::zip(bodyReturnOp->getOperandTypes(), getOperandTypes()))) {
8316 Type operandType = std::get<0>(enumeratedOperandTypes.value());
8317 Type returnType = std::get<1>(enumeratedOperandTypes.value());
8318 if (operandType != returnType)
8319 return bodyReturnOp.emitOpError()
8320 << "type mismatch between operand #"
8321 << enumeratedOperandTypes.index()
8322 << " and the enclosing WhileOp returned value: " << operandType
8323 << " vs " << returnType;
8324 }
8325 }
8326 return success();
8327 }
8328
8329 /// Print a `while` op.
8330 ///
8331 /// op ::= `mhlo.while` `(` assignment-list `)` `:` types attribute-dict
8332 /// `cond` region
8333 /// `do` region
8334 /// assignment-list ::= assignment | assignment `,` assignment-list
8335 /// assignment ::= ssa-value `=` ssa-value
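///
/// For example (illustrative), a single-operand while prints as:
///
///   %result = mhlo.while(%iterArg = %init) : tensor<i32>
///    cond {
///      ...
///    } do {
///      ...
///    }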
print(OpAsmPrinter & p)8336 void WhileOp::print(OpAsmPrinter& p) {
8337 p << '(';
8338 llvm::interleaveComma(llvm::zip(getBody()->getArguments(), getOperands()), p,
8339 [&](auto zip) {
8340 p.printOperand(std::get<0>(zip));
8341 p << " = ";
8342 p.printOperand(std::get<1>(zip));
8343 });
8344 p << ")";
8345 if (getNumOperands()) {
8346 p << " : ";
8347 llvm::interleaveComma(getOperandTypes(), p);
8348 }
8349 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
8350 p.printNewline();
8351 p << " cond ";
8352 p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false);
8353 p << " do ";
8354 p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false);
8355 }
8356
parse(OpAsmParser & parser,OperationState & result)8357 ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) {
8358 llvm::SMLoc loc = parser.getCurrentLocation();
8359 // Parse the operands of the while: these are of the form:
8360 // %iter_arg = %init_val
8361 // where %iter_arg is the name of the block argument in the cond/body blocks
8362 // and %init_val is the actual operand.
8363 SmallVector<OpAsmParser::UnresolvedOperand> operands;
8364 SmallVector<OpAsmParser::UnresolvedOperand> iterArgs;
8365 if (parser.parseLParen()) return failure();
8366 do {
8367 if (succeeded(parser.parseOptionalRParen())) break;
8368 OpAsmParser::UnresolvedOperand operand, iterArg;
8369 if (parser.parseOperand(iterArg) || parser.parseEqual() ||
8370 parser.parseOperand(operand))
8371 return failure();
8372 iterArgs.push_back(iterArg);
8373 operands.push_back(operand);
8374 if (succeeded(parser.parseOptionalRParen())) break;
8375 if (failed(parser.parseComma())) return failure();
8376 } while (true);
8377 if (!operands.empty()) {
8378 if (parser.parseColon() || parser.parseTypeList(result.types))
8379 return failure();
8380 }
8381
8382 SmallVector<OpAsmParser::Argument> args;
8383 createArgs(iterArgs, result.types, args);
8384 if (parser.resolveOperands(operands, result.types, loc, result.operands) ||
8385 parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
8386 parser.parseKeyword("cond") ||
8387 parser.parseRegion(*result.addRegion(), args) ||
8388 parser.parseKeyword("do") ||
8389 parser.parseRegion(*result.addRegion(), args))
8390 return failure();
8391 return success();
8392 }
8393
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> & results)8394 LogicalResult WhileOp::fold(ArrayRef<Attribute> /*operands*/,
8395 SmallVectorImpl<OpFoldResult>& results) {
8396 DenseIntElementsAttr condValue;
8397 auto condReturnOp = cast<ReturnOp>(cond().front().back());
8398 if (!matchPattern(condReturnOp.getOperand(0), m_Constant(&condValue)))
8399 return failure();
8400 if (condValue.getSplatValue<BoolAttr>().getValue())
8401 return failure(); // TODO(mhlo): this is an infinite loop, should we fold?
8402
8403 results.append(getOperands().begin(), getOperands().end());
8404 return success();
8405 }
8406
whileCanonicalization(WhileOp whileOp,PatternRewriter & rewriter)8407 static LogicalResult whileCanonicalization(WhileOp whileOp,
8408 PatternRewriter& rewriter) {
8409 // Turn loop invariant values into implicit capture.
8410   // Check whether at least one value is forwarded from one iteration to the
8411   // next, or one of the yielded values is already an implicit capture.
8412   // Otherwise there is nothing to do here.
8413 Block* cond = whileOp.getBody(0);
8414 Block* body = whileOp.getBody(1);
8415 auto bodyReturnOp = cast<ReturnOp>(body->getTerminator());
8416 if (!llvm::any_of(llvm::zip(whileOp->getOperands(), body->getArguments(),
8417 bodyReturnOp->getOperands()),
8418 [&](auto zip) {
8419 return (std::get<0>(zip) == std::get<2>(zip) ||
8420 std::get<1>(zip) == std::get<2>(zip));
8421 }))
8422 return rewriter.notifyMatchFailure(whileOp, "no loop invariant found");
8423
8424 SmallVector<Value> newOperands, resultsToReplace;
8425 SmallVector<unsigned> invariantArgIdxs;
8426 for (const auto& enumeratedOperands : llvm::enumerate(llvm::zip(
8427 whileOp.getOperands(), cond->getArguments(), body->getArguments(),
8428 bodyReturnOp->getOperands(), whileOp->getResults()))) {
8429 const auto& operands = enumeratedOperands.value();
8430 Value whileOperand = std::get<0>(operands);
8431 BlockArgument condBlockArg = std::get<1>(operands);
8432 BlockArgument bodyBlockArg = std::get<2>(operands);
8433 Value bodyReturnOperand = std::get<3>(operands);
8434 Value whileResult = std::get<4>(operands);
8435
8436 bool forwarded = (whileOperand == bodyReturnOperand ||
8437 bodyBlockArg == bodyReturnOperand);
8438 if (forwarded) {
8439 invariantArgIdxs.push_back(enumeratedOperands.index());
8440 condBlockArg.replaceAllUsesWith(whileOperand);
8441 bodyBlockArg.replaceAllUsesWith(whileOperand);
8442 whileResult.replaceAllUsesWith(whileOperand);
8443 continue;
8444 }
8445 newOperands.push_back(whileOperand);
8446 resultsToReplace.push_back(whileResult);
8447 }
8448 cond->eraseArguments(invariantArgIdxs);
8449 body->eraseArguments(invariantArgIdxs);
8450 for (int idx : llvm::reverse(invariantArgIdxs))
8451 bodyReturnOp->eraseOperand(idx);
8452
8453 WhileOp newWhileOp = rewriter.create<WhileOp>(
8454 whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands);
8455 newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0));
8456 newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1));
8457 for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults()))
8458 std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
8459 rewriter.eraseOp(whileOp);
8460 return success();
8461 }
8462
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)8463 void WhileOp::getCanonicalizationPatterns(RewritePatternSet& results,
8464 MLIRContext* context) {
8465 results.add(&whileCanonicalization);
8466 }
8467
inferReturnTypeComponents(MLIRContext *,Optional<Location>,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)8468 LogicalResult UniformDequantizeOp::inferReturnTypeComponents(
8469 MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
8470 DictionaryAttr attributes, RegionRange regions,
8471 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
8472 UniformDequantizeOp::Adaptor adaptor(operands, attributes, regions);
8473 auto operandType = (*operands.begin()).getType().cast<ShapedType>();
8474   // Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType.
8475   auto quantType = operandType.getElementType().cast<quant::QuantizedType>();
8476   auto shape = operandType.getShape();
8477 inferredReturnShapes.emplace_back(shape, quantType.getExpressedType());
8478 return success();
8479 }
8480
8481 using mlir::hlo::parseWindowAttributes;
8482 using mlir::hlo::printWindowAttributes;
8483
8484 } // namespace mhlo
8485 } // namespace mlir
8486
8487 #define GET_OP_CLASSES
8488 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
8489
8490 namespace mlir {
8491 namespace mhlo {
8492
8493 //===----------------------------------------------------------------------===//
8494 // mhlo Dialect Interfaces
8495 //===----------------------------------------------------------------------===//
8496
8497 namespace {
8498 struct HLOInlinerInterface : public DialectInlinerInterface {
8499 using DialectInlinerInterface::DialectInlinerInterface;
8500
8501 // Allow all call operations to be inlined.
isLegalToInlinemlir::mhlo::__anon00baf10a4511::HLOInlinerInterface8502 bool isLegalToInline(Operation* call, Operation* callable,
8503 bool wouldBeCloned) const final {
8504 return true;
8505 }
8506 // We don't have any special restrictions on what can be inlined into
8507 // destination regions (e.g. while/conditional bodies). Always allow it.
isLegalToInlinemlir::mhlo::__anon00baf10a4511::HLOInlinerInterface8508 bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
8509 BlockAndValueMapping& valueMapping) const final {
8510 return true;
8511 }
8512 // Operations in mhlo dialect are always legal to inline since they are
8513 // pure.
isLegalToInlinemlir::mhlo::__anon00baf10a4511::HLOInlinerInterface8514 bool isLegalToInline(Operation*, Region*, bool,
8515 BlockAndValueMapping&) const final {
8516 return true;
8517 }
8518 };
8519 } // end anonymous namespace
8520
8521 //===----------------------------------------------------------------------===//
8522 // mhlo Dialect Constructor
8523 //===----------------------------------------------------------------------===//
8524
MhloDialect(MLIRContext * context)8525 MhloDialect::MhloDialect(MLIRContext* context)
8526 : Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
8527 addOperations<
8528 #define GET_OP_LIST
8529 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
8530 >();
8531 addInterfaces<HLOInlinerInterface>();
8532 addTypes<TokenType>();
8533 addAttributes<
8534 #define GET_ATTRDEF_LIST
8535 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.cc.inc"
8536 >();
8537 context->loadDialect<tensor::TensorDialect>();
8538 }
8539
parseType(DialectAsmParser & parser) const8540 Type MhloDialect::parseType(DialectAsmParser& parser) const {
8541 StringRef dataType;
8542 if (parser.parseKeyword(&dataType)) return Type();
8543
8544 if (dataType == "token") return TokenType::get(getContext());
8545 parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << dataType;
8546 return nullptr;
8547 }
8548
printType(Type type,DialectAsmPrinter & os) const8549 void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
8550 if (type.isa<TokenType>()) {
8551 os << "token";
8552 return;
8553 }
8554 os << "<unknown mhlo type>";
8555 }
8556
8557 // Entry point for attribute parsing; TableGen-generated code handles the
8558 // dispatch to the individual classes.
parseAttribute(DialectAsmParser & parser,Type type) const8559 Attribute MhloDialect::parseAttribute(DialectAsmParser& parser,
8560 Type type) const {
8561 StringRef attrTag;
8562 Attribute attr;
8563 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
8564 if (parseResult.hasValue()) return attr;
8565 parser.emitError(parser.getNameLoc(), "unknown mhlo attribute");
8566 return Attribute();
8567 }
8568
8569 // Entry point for attribute printing; TableGen-generated code handles the
8570 // dispatch to the individual classes.
printAttribute(Attribute attr,DialectAsmPrinter & os) const8571 void MhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
8572 LogicalResult result = generatedAttributePrinter(attr, os);
8573 (void)result;
8574 assert(succeeded(result));
8575 }
8576
8577 /// Helpers for attribute parsing.
8578
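// Parses a list of integers enclosed in square brackets, e.g. `[1, 2, 3]`.
// As written, the separating commas are optional and a trailing comma is
// accepted.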
parseDims(AsmParser & parser,SmallVector<int64_t> & dims)8579 static ParseResult parseDims(AsmParser& parser, SmallVector<int64_t>& dims) {
8580 dims.clear();
8581 if (parser.parseLSquare()) return failure();
8582 while (failed(parser.parseOptionalRSquare())) {
8583 dims.emplace_back();
8584 if (parser.parseInteger(dims.back())) return failure();
8585 (void)parser.parseOptionalComma();
8586 }
8587 return success();
8588 }
8589
parseDimsWithMinimumElements(AsmParser & parser,SmallVector<int64_t> & dims,int minElements)8590 static ParseResult parseDimsWithMinimumElements(AsmParser& parser,
8591 SmallVector<int64_t>& dims,
8592 int minElements) {
8593 if (failed(parseDims(parser, dims))) return failure();
8594 if (static_cast<int64_t>(dims.size()) < minElements)
8595 return parser.emitError(parser.getCurrentLocation())
8596 << "expected at least " << minElements << " element(s), found "
8597 << dims.size();
8598 return success();
8599 }
8600
8601 /// Parse a custom attribute that resembles a struct of the form
8602 /// <
8603 /// foo = something_parsed_by_custom_parser,
8604 /// bar = something_parsed_by_different_custom_parser,
8605 /// baz something_parsed_by_another_custom_parser
8606 /// >
8607 /// The optional `parseEqual` array can be used to denote whether a '='
8608 /// follows the keyword for a given field (see `baz` in the example above).
8609 /// If not provided, all fields must be followed by a '='.
parseStruct(AsmParser & parser,ArrayRef<StringRef> keywords,ArrayRef<llvm::function_ref<ParseResult ()>> parseFuncs,ArrayRef<bool> parseEqual={})8610 static ParseResult parseStruct(
8611 AsmParser& parser, ArrayRef<StringRef> keywords,
8612 ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs,
8613 ArrayRef<bool> parseEqual = {}) {
8614 assert(keywords.size() == parseFuncs.size());
8615 assert(parseEqual.empty() || parseEqual.size() == keywords.size());
8616 SmallVector<bool> seen(keywords.size(), false);
8617 while (failed(parser.parseOptionalGreater())) {
8618 bool foundOne = false;
8619 for (const auto& it : llvm::enumerate(keywords)) {
8620 size_t index = it.index();
8621 StringRef keyword = it.value();
8622 if (succeeded(parser.parseOptionalKeyword(keyword))) {
8623 if (seen[index]) {
8624 return parser.emitError(parser.getCurrentLocation())
8625 << "duplicated `" << keyword << "` entry";
8626 }
8627 if (parseEqual.empty() || parseEqual[index]) {
8628 if (failed(parser.parseEqual())) return failure();
8629 }
8630 if (failed(parseFuncs[index]())) return failure();
8631 if (failed(parser.parseOptionalComma())) return parser.parseGreater();
8632 seen[index] = true;
8633 foundOne = true;
8634 }
8635 }
8636 if (!foundOne) {
8637 auto parseError = parser.emitError(parser.getCurrentLocation())
8638 << "expected one of: ";
__anon00baf10a4602(StringRef kw) 8639 llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
8640 parseError << '`' << kw << '`';
8641 });
8642 return parseError;
8643 }
8644 }
8645 return success();
8646 }
8647
8648 // Helpers to print an optional array or integer field, to simplify writing
8649 // attribute printers.
8650 template <typename T>
printField(AsmPrinter & printer,StringRef name,T field,StringRef & separator)8651 static void printField(AsmPrinter& printer, StringRef name, T field,
8652 StringRef& separator) {
8653 if (field != 0) {
8654 printer << separator << name << " = " << field;
8655 separator = ", ";
8656 }
8657 }
8658 template <typename T>
printField(AsmPrinter & printer,StringRef name,ArrayRef<T> field,StringRef & separator)8659 static void printField(AsmPrinter& printer, StringRef name, ArrayRef<T> field,
8660 StringRef& separator) {
8661 if (!field.empty()) {
8662 printer << separator << name << " = [";
8663 llvm::interleaveComma(field, printer);
8664 printer << "]";
8665 separator = ", ";
8666 }
8667 }
8668 template <typename... Ts>
printStruct(AsmPrinter & printer,StringRef name,Ts...printFields)8669 static void printStruct(AsmPrinter& printer, StringRef name,
8670 Ts... printFields) {
8671 printer << "<";
8672 StringRef separator = "";
8673 // Fold expression to print each entry in the parameter pack.
8674 // TODO(mhlo-team): this can be simplified when TF moves to C++17.
8675 using unused = int[];
8676 (void)unused{0, (printField(printer, std::get<0>(printFields),
8677 std::get<1>(printFields), separator),
8678 0)...};
8679 printer << ">";
8680 }
8681
8682 // Custom printer and parser for ScatterDimensionNumbersAttr.
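// The printed form looks like, e.g.:
//   #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0],
//     scatter_dims_to_operand_dims = [0], index_vector_dim = 1>
// Fields whose value is zero or an empty list are elided by printField.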
print(AsmPrinter & printer) const8683 void ScatterDimensionNumbersAttr::print(AsmPrinter& printer) const {
8684 printStruct(printer, "scatter",
8685 std::make_pair("update_window_dims", getUpdateWindowDims()),
8686 std::make_pair("inserted_window_dims", getInsertedWindowDims()),
8687 std::make_pair("scatter_dims_to_operand_dims",
8688 getScatterDimsToOperandDims()),
8689 std::make_pair("index_vector_dim", getIndexVectorDim()));
8690 }
parse(AsmParser & parser,Type type)8691 Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
8692 if (failed(parser.parseLess())) return {};
8693 SmallVector<int64_t> updateWindowDims;
8694 SmallVector<int64_t> insertedWindowDims;
8695 SmallVector<int64_t> scatterDimsToOperandDims;
8696 int64_t indexVectorDim = 0;
8697
8698 if (failed(parseStruct(
8699 parser,
8700 {"update_window_dims", "inserted_window_dims",
8701 "scatter_dims_to_operand_dims", "index_vector_dim"},
8702 {[&]() { return parseDims(parser, updateWindowDims); },
8703 [&]() { return parseDims(parser, insertedWindowDims); },
8704 [&]() { return parseDims(parser, scatterDimsToOperandDims); },
8705 [&]() { return parser.parseInteger(indexVectorDim); }}))) {
8706 parser.emitError(parser.getCurrentLocation())
8707 << "failed parsing scatter dimension numbers attribute";
8708 return {};
8709 }
8710
8711 return ScatterDimensionNumbersAttr::get(
8712 parser.getContext(), updateWindowDims, insertedWindowDims,
8713 scatterDimsToOperandDims, indexVectorDim);
8714 }
8715
8716 // Custom printer and parser for GatherDimensionNumbersAttr.
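// The printed form looks like, e.g.:
//   #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0],
//     start_index_map = [0], index_vector_dim = 1>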
print(AsmPrinter & printer) const8717 void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const {
8718 printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()),
8719 std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()),
8720 std::make_pair("start_index_map", getStartIndexMap()),
8721 std::make_pair("index_vector_dim", getIndexVectorDim()));
8722 }
8723
parse(AsmParser & parser,Type type)8724 Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
8725 if (failed(parser.parseLess())) return {};
8726
8727 SmallVector<int64_t> offsetDims;
8728 SmallVector<int64_t> collapsedSliceDims;
8729 SmallVector<int64_t> startIndexMap;
8730 int64_t indexVectorDim = 0;
8731
8732 if (failed(parseStruct(
8733 parser,
8734 {"offset_dims", "collapsed_slice_dims", "start_index_map",
8735 "index_vector_dim"},
8736 {[&]() { return parseDims(parser, offsetDims); },
8737 [&]() { return parseDims(parser, collapsedSliceDims); },
8738 [&]() { return parseDims(parser, startIndexMap); },
8739 [&]() { return parser.parseInteger(indexVectorDim); }}))) {
8740 parser.emitError(parser.getCurrentLocation())
8741 << "failed parsing gather dimension numbers attribute";
8742 return {};
8743 }
8744
8745 return GatherDimensionNumbersAttr::get(parser.getContext(), offsetDims,
8746 collapsedSliceDims, startIndexMap,
8747 indexVectorDim);
8748 }
8749
8750 // Custom printer and parser for DotDimensionNumbersAttr.
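// The printed form looks like, e.g.:
//   #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0],
//     lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>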
print(AsmPrinter & printer) const8751 void DotDimensionNumbersAttr::print(AsmPrinter& printer) const {
8752 printStruct(
8753 printer, "dot",
8754 std::make_pair("lhs_batching_dimensions", getLhsBatchingDimensions()),
8755 std::make_pair("rhs_batching_dimensions", getRhsBatchingDimensions()),
8756 std::make_pair("lhs_contracting_dimensions",
8757 getLhsContractingDimensions()),
8758 std::make_pair("rhs_contracting_dimensions",
8759 getRhsContractingDimensions()));
8760 }
8761
parse(AsmParser & parser,Type type)8762 Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
8763 if (failed(parser.parseLess())) return {};
8764
8765 SmallVector<int64_t> lhsBatchingDimensions;
8766 SmallVector<int64_t> rhsBatchingDimensions;
8767 SmallVector<int64_t> lhsContractingDimensions;
8768 SmallVector<int64_t> rhsContractingDimensions;
8769
8770 if (failed(parseStruct(
8771 parser,
8772 {"lhs_batching_dimensions", "rhs_batching_dimensions",
8773 "lhs_contracting_dimensions", "rhs_contracting_dimensions"},
8774 {[&]() { return parseDims(parser, lhsBatchingDimensions); },
8775 [&]() { return parseDims(parser, rhsBatchingDimensions); },
8776 [&]() { return parseDims(parser, lhsContractingDimensions); },
8777 [&]() { return parseDims(parser, rhsContractingDimensions); }}))) {
8778 parser.emitError(parser.getCurrentLocation())
8779 << "failed parsing dot dimension numbers attribute";
8780 return {};
8781 }
8782 return DotDimensionNumbersAttr::get(
8783 parser.getContext(), lhsBatchingDimensions, rhsBatchingDimensions,
8784 lhsContractingDimensions, rhsContractingDimensions);
8785 }
8786
8787 namespace {
8788 enum NonSpatialDim : int64_t {
8789 IOBatch = -1, // Input or output batch dimension
8790 IOFeature = -2, // Input or output feature dimension
8791 KIFeature = -3, // Kernel input feature dimension
8792   KOFeature = -4,  // Kernel output feature dimension.
8793 };
8794
8795 struct DenseMapInfoNonSpatialDim {
getEmptyKeymlir::mhlo::__anon00baf10a5311::DenseMapInfoNonSpatialDim8796 static inline NonSpatialDim getEmptyKey() {
8797 return NonSpatialDim(DenseMapInfo<int64_t>::getEmptyKey());
8798 }
8799
getTombstoneKeymlir::mhlo::__anon00baf10a5311::DenseMapInfoNonSpatialDim8800 static inline NonSpatialDim getTombstoneKey() {
8801 return NonSpatialDim(DenseMapInfo<int64_t>::getTombstoneKey());
8802 }
8803
getHashValuemlir::mhlo::__anon00baf10a5311::DenseMapInfoNonSpatialDim8804 static unsigned getHashValue(const NonSpatialDim& key) {
8805 return DenseMapInfo<int64_t>::getHashValue(key);
8806 }
8807
isEqualmlir::mhlo::__anon00baf10a5311::DenseMapInfoNonSpatialDim8808 static bool isEqual(const NonSpatialDim& lhs, const NonSpatialDim& rhs) {
8809 return lhs == rhs;
8810 }
8811 };
8812
nonSpatialDimToString(NonSpatialDim dim)8813 char nonSpatialDimToString(NonSpatialDim dim) {
8814 switch (dim) {
8815 case IOBatch:
8816 return 'b';
8817 case IOFeature:
8818 return 'f';
8819 case KIFeature:
8820 return 'i';
8821 case KOFeature:
8822 return 'o';
8823 }
8824 llvm_unreachable("Unknown NonSpatialDim");
8825 }
8826 } // namespace
8827
8828 // Custom printer and parser for convolution attribute.
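// The dimension numbers are printed in a compressed form; for example, a 2-D
// NHWC convolution with an HWIO kernel prints as:
//   [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]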
printConvolutionDimensions(AsmPrinter & p,ConvDimensionNumbersAttr dnums)8829 void printConvolutionDimensions(AsmPrinter& p, ConvDimensionNumbersAttr dnums) {
8830 // TODO(b/202040055): we should check the attribute invariant and print the
8831 // "raw" form if they are violated, otherwise we'll crash here.
8832 constexpr int64_t kUnknownDim = std::numeric_limits<int64_t>::min();
8833 auto printDim =
8834 [&](ArrayRef<int64_t> spatialDims,
8835 ArrayRef<std::pair<int64_t, NonSpatialDim>> nonSpatialDims) {
8836 int64_t numDims = 0;
8837 if (!spatialDims.empty()) {
8838 numDims =
8839 *std::max_element(spatialDims.begin(), spatialDims.end()) + 1;
8840 }
8841 for (const auto& dim : nonSpatialDims) {
8842 numDims = std::max(numDims, dim.first + 1);
8843 }
8844
8845 llvm::SmallVector<int64_t> dims(numDims, kUnknownDim);
8846 // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
8847 // spatial dimension index.
8848 for (const std::pair<int64_t, NonSpatialDim>& nonSpatialDim :
8849 nonSpatialDims) {
8850 dims[nonSpatialDim.first] = nonSpatialDim.second;
8851 }
8852 for (const auto& spatialDim : llvm::enumerate(spatialDims)) {
8853 dims[spatialDim.value()] = static_cast<int64_t>(spatialDim.index());
8854 }
8855
8856           // The dimension numbers are printed as a comma-separated list
8857           // surrounded by square brackets, e.g., [b, 0, 1, 2, f].
8858 p << '[';
8859 llvm::interleaveComma(dims, p, [&](int64_t dim) {
8860 if (dim == kUnknownDim) {
8861 p << "?";
8862 } else if (dim >= 0) {
8863 p << dim;
8864 } else {
8865 p << nonSpatialDimToString(static_cast<NonSpatialDim>(dim));
8866 }
8867 });
8868 p << ']';
8869 };
8870
8871 printDim(dnums.getInputSpatialDimensions(),
8872 {{dnums.getInputBatchDimension(), IOBatch},
8873 {dnums.getInputFeatureDimension(), IOFeature}});
8874 p << "x";
8875 printDim(dnums.getKernelSpatialDimensions(),
8876 {{dnums.getKernelInputFeatureDimension(), KIFeature},
8877 {dnums.getKernelOutputFeatureDimension(), KOFeature}});
8878 p << "->";
8879 printDim(dnums.getOutputSpatialDimensions(),
8880 {{dnums.getOutputBatchDimension(), IOBatch},
8881 {dnums.getOutputFeatureDimension(), IOFeature}});
8882 }
8883
printConvolutionDimensions(AsmPrinter & p,Operation *,ConvDimensionNumbersAttr dnums)8884 void printConvolutionDimensions(AsmPrinter& p, Operation*,
8885 ConvDimensionNumbersAttr dnums) {
8886 printConvolutionDimensions(p, dnums);
8887 }
8888
8889 // Custom printer and parser for ConvDimensionNumbersAttr.
print(AsmPrinter & printer) const8890 void ConvDimensionNumbersAttr::print(AsmPrinter& printer) const {
8891 printer << "<";
8892 printConvolutionDimensions(printer, *this);
8893 printer << ">";
8894 }
8895
8896 // If the attribute is written with `#mhlo.conv raw<`, we parse it as a struct
8897 // instead of the compressed format. This enables writing tests covering
8898 // impossible/invalid internal representations of the attribute.
parseConvolutionDimensionsRaw(AsmParser & parser,ConvDimensionNumbersAttr & dnums)8899 static ParseResult parseConvolutionDimensionsRaw(
8900 AsmParser& parser, ConvDimensionNumbersAttr& dnums) {
8901 int64_t inputBatchDimension = 0;
8902 int64_t inputFeatureDimension = 0;
8903 SmallVector<int64_t> inputSpatialDimensions;
8904 int64_t kernelInputFeatureDimension = 0;
8905 int64_t kernelOutputFeatureDimension = 0;
8906 SmallVector<int64_t> kernelSpatialDimensions;
8907 int64_t outBatchDimension = 0;
8908 int64_t outputFeatureDimension = 0;
8909 SmallVector<int64_t> outputSpatialDimensions;
8910 if (failed(parseStruct(
8911 parser,
8912 {"input_batch_dimension", "input_feature_dimension",
8913 "input_spatial_dimensions", "kernel_input_feature_dimension",
8914 "kernel_output_feature_dimension", "kernel_spatial_dimensions",
8915 "output_batch_dimension", "output_feature_dimension",
8916 "output_spatial_dimensions"},
8917 {
8918 [&]() { return parser.parseInteger(inputBatchDimension); },
8919 [&]() { return parser.parseInteger(inputFeatureDimension); },
8920 [&]() { return parseDims(parser, inputSpatialDimensions); },
8921 [&]() {
8922 return parser.parseInteger(kernelInputFeatureDimension);
8923 },
8924 [&]() {
8925 return parser.parseInteger(kernelOutputFeatureDimension);
8926 },
8927 [&]() { return parseDims(parser, kernelSpatialDimensions); },
8928 [&]() { return parser.parseInteger(outBatchDimension); },
8929 [&]() { return parser.parseInteger(outputFeatureDimension); },
8930 [&]() { return parseDims(parser, outputSpatialDimensions); },
8931 }))) {
8932 parser.emitError(parser.getCurrentLocation())
8933         << "failed parsing convolution dimension numbers attribute";
8934 return failure();
8935 }
8936 dnums = ConvDimensionNumbersAttr::get(
8937 parser.getBuilder().getContext(), inputBatchDimension,
8938 inputFeatureDimension, inputSpatialDimensions,
8939 kernelInputFeatureDimension, kernelOutputFeatureDimension,
8940 kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
8941 outputSpatialDimensions);
8942 return success();
8943 }
8944
parseConvolutionDimensions(AsmParser & parser,ConvDimensionNumbersAttr & dnums)8945 ParseResult parseConvolutionDimensions(AsmParser& parser,
8946 ConvDimensionNumbersAttr& dnums) {
8947 // Parsing a single set of dim numbers gives the spatial dimensions as a
8948 // single ArrayRef<int64_t> and a list of non-spatial dimensions as
8949 // IntegerAttrs (indexed by the NonSpatialDim enum).
8950 using parse_dim_result_t =
8951 std::pair<llvm::SmallVector<int64_t>,
8952 llvm::SmallDenseMap<NonSpatialDim, int64_t, 4,
8953 DenseMapInfoNonSpatialDim>>;
8954
8955   // Note that allowedNonSpatialDims is an ordered set (as opposed to an
8956   // unordered one) because it is used to print the list of allowed non-spatial
8957   // dims in error messages, and ordering keeps those messages deterministic.
8958 auto parseDims =
8959 [&](std::set<NonSpatialDim, std::greater<>> allowedNonSpatialDims,
8960 parse_dim_result_t& parsedDims) -> ParseResult {
8961 auto& spatialDims = std::get<0>(parsedDims);
8962 auto& nonSpatialDims = std::get<1>(parsedDims);
8963 spatialDims.clear();
8964 nonSpatialDims.clear();
8965
8966 // Parse the starting [
8967 if (parser.parseLSquare()) {
8968 return failure();
8969 }
8970
8971 llvm::SmallDenseMap<int64_t, int64_t> spatialDimsMap;
8972 constexpr int64_t kInvalidDimension = -1;
8973     // Keep track of the maximum spatial dimension parsed; we expect to see
8974     // every dimension from 0 up to that maximum.
8975 int64_t maxParsedSpatialDim = kInvalidDimension;
8976
8977 int64_t index = 0;
8978 do {
8979 int64_t spatialDim;
8980 auto dimLocation = parser.getCurrentLocation();
8981 OptionalParseResult parseResult = parser.parseOptionalInteger(spatialDim);
8982 if (parseResult.hasValue()) {
8983 if (parseResult.getValue().failed()) {
8984 return failure();
8985 }
8986 // We were successful in parsing an integer. Check if it is a valid
8987 // dimension (non-negative and no duplicate) and add its index to the
8988 // spatial dims map.
8989 if (spatialDim < 0)
8990 return parser.emitError(dimLocation)
8991 << "Unexpected dimension " << spatialDim;
8992 if (!spatialDimsMap
8993 .insert(std::pair<int64_t, int64_t>(spatialDim, index))
8994 .second)
8995 return parser.emitError(dimLocation)
8996 << "Duplicate entries for spatial dimension " << spatialDim;
8997 maxParsedSpatialDim = std::max(spatialDim, maxParsedSpatialDim);
8998 } else if (!parser.parseOptionalQuestion()) {
8999       } else if (parser.parseOptionalQuestion().succeeded()) {
9000 // '?' means "unknown dimension", and it's not represented in the
9001 // return value of this function.
9002 } else {
9003 // We did not parse an integer or question mark. We expect a keyword
9004 // token.
9005 StringRef keyword;
9006 if (parser.parseKeyword(&keyword)) {
9007 return failure();
9008 }
9009 if (keyword.size() != 1 || allowedNonSpatialDims.empty()) {
9010 return parser.emitError(dimLocation, "Unexpected keyword ")
9011 << keyword;
9012 }
9013         // Check if the keyword matches one of the allowed non-spatial dims.
9014         // If so, add it to the non-spatial dims and remove it from the
9015         // allowed set so that it won't be allowed again.
9016 bool isAllowed = false;
9017 for (NonSpatialDim allowed : allowedNonSpatialDims) {
9018 if (keyword[0] == nonSpatialDimToString(allowed)) {
9019 nonSpatialDims.insert({allowed, index});
9020 allowedNonSpatialDims.erase(allowed);
9021 isAllowed = true;
9022 break;
9023 }
9024 }
9025
9026 if (!isAllowed) {
9027 mlir::InFlightDiagnostic diag =
9028 parser.emitError(dimLocation, "Unexpected dimension ");
9029 diag << keyword << ", expecting ";
9030 llvm::interleaveComma(
9031 allowedNonSpatialDims, diag,
9032 [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
9033 return diag;
9034 }
9035 }
9036 index++;
9037 } while (parser.parseOptionalComma().succeeded());
9038
9039 // Make sure all expected non-spatial dimensions are parsed.
9040 if (!allowedNonSpatialDims.empty()) {
9041 mlir::InFlightDiagnostic diag =
9042 parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
9043 llvm::interleaveComma(
9044 allowedNonSpatialDims, diag,
9045 [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
9046 diag << " not specified";
9047 return diag;
9048 }
9049
9050 // parse ending ]
9051 if (parser.parseRSquare()) {
9052 return failure();
9053 }
9054
9055 // Number of expected spatial dimensions is one more than the maximum parsed
9056 // spatial dimension. For example, if we parse [0, 3, 2, b, i, 1], then the
9057 // maximum parsed spatial dimension is 3 and the number of expected spatial
9058 // dimensions is 4.
9059 int64_t numSpatialDimensions = maxParsedSpatialDim + 1;
9060 spatialDims.resize(numSpatialDimensions);
9061 // Store spatial dimensions in a vector which maps spatial dim (vector
9062 // index) -> index in the tensor dimensions. For example, for parsed
9063 // dimension numbers [0, 3, 2, b, i, 1] the spatial dimension vector would
9064 // be [0, 5, 2, 1].
9065 //
9066 // Get all the unspecified spatial dimensions to throw a more descriptive
9067 // error later.
9068 llvm::SmallVector<int64_t> unspecifiedSpatialDims;
9069 constexpr int kPrintUnspecifiedDimsMax = 10;
9070 for (int dim = 0; dim < numSpatialDimensions; ++dim) {
9071 auto it = spatialDimsMap.find(dim);
9072 if (it == spatialDimsMap.end()) {
9073 // Have an upper bound on the number of unspecified dimensions to print
9074 // in the error message.
9075 if (unspecifiedSpatialDims.size() < kPrintUnspecifiedDimsMax)
9076 unspecifiedSpatialDims.push_back(dim);
9077 continue;
9078 }
9079 spatialDims[dim] = it->second;
9080 }
9081
9082 // Verify that we got all spatial dimensions between 0 and maximum parsed
9083 // spatial dimension.
9084 if (!unspecifiedSpatialDims.empty()) {
9085 mlir::InFlightDiagnostic diag = parser.emitError(
9086 parser.getCurrentLocation(), "Expected spatial dimensions ");
9087 llvm::interleaveComma(unspecifiedSpatialDims, diag);
9088 diag << " not specified";
9089 return diag;
9090 }
9091
9092 return success();
9093 };
9094
9095 parse_dim_result_t parsedDims;
9096 if (parseDims({IOBatch, IOFeature}, parsedDims)) {
9097 return failure();
9098 }
9099 llvm::SmallVector<int64_t> inputSpatialDimensions = parsedDims.first;
9100 int64_t inputBatchDimension = parsedDims.second[IOBatch];
9101 int64_t inputFeatureDimension = parsedDims.second[IOFeature];
9102 if (parser.parseKeyword("x")) return failure();
9103 if (parseDims({KIFeature, KOFeature}, parsedDims)) {
9104 return failure();
9105 }
9106 llvm::SmallVector<int64_t> kernelSpatialDimensions = parsedDims.first;
9107 int64_t kernelInputFeatureDimension = parsedDims.second[KIFeature];
9108 int64_t kernelOutputFeatureDimension = parsedDims.second[KOFeature];
9109 if (parser.parseArrow()) {
9110 return failure();
9111 }
9112 if (parseDims({IOBatch, IOFeature}, parsedDims)) {
9113 return failure();
9114 }
9115 llvm::SmallVector<int64_t> outputSpatialDimensions = parsedDims.first;
9116 const int64_t outBatchDimension = parsedDims.second[IOBatch];
9117 const int64_t outputFeatureDimension = parsedDims.second[IOFeature];
9118 dnums = ConvDimensionNumbersAttr::get(
9119 parser.getBuilder().getContext(), inputBatchDimension,
9120 inputFeatureDimension, inputSpatialDimensions,
9121 kernelInputFeatureDimension, kernelOutputFeatureDimension,
9122 kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
9123 outputSpatialDimensions);
9124
9125 return success();
9126 }
9127
parse(AsmParser & parser,Type type)9128 Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
9129 if (failed(parser.parseLess())) return {};
9130 ConvDimensionNumbersAttr dnums;
9131 if (succeeded(parser.parseOptionalKeyword("raw"))) {
9132 if (failed(parseConvolutionDimensionsRaw(parser, dnums))) return {};
9133 return dnums;
9134 }
9135 if (failed(parseConvolutionDimensions(parser, dnums))) return {};
9136 if (failed(parser.parseGreater())) return {};
9137 return dnums;
9138 }
9139
9140 // Custom printer and parser for ArgResultAliasAttr.
9141 constexpr char kMustAlias[] = "must_alias";
9142 constexpr char kResult[] = "result_index";
9143 constexpr char kArgTupleIndices[] = "tuple_indices";
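// The printed form of the attribute body looks like, e.g.:
//   <tuple_indices = [0], result_index = [1, 0], must_alias>
// where `tuple_indices` and `must_alias` are only printed when present.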
9144
print(AsmPrinter & printer) const9145 void ArgResultAliasAttr::print(AsmPrinter& printer) const {
9146 printer << "<";
9147
9148 // The attribute can have empty tuple indices. Only print argument tuple
9149 // indices if they are non-empty.
9150 if (!getArgTupleIndices().empty())
9151 printer << kArgTupleIndices << " = [" << getArgTupleIndices() << "], ";
9152
9153 // Print the result index followed by any result tuple indices if present.
9154 printer << kResult << " = [";
9155 printer << getResultIndex();
9156 if (!getResultTupleIndices().empty()) {
9157 printer << ", " << getResultTupleIndices();
9158 }
9159 printer << "]";
9160
9161 // Print the "must_alias" keyword if this is a must alias, otherwise skip.
9162 if (getIsMustAlias()) printer << ", " << kMustAlias;
9163
9164 printer << ">";
9165 }
9166
parse(AsmParser & parser,Type type)9167 Attribute ArgResultAliasAttr::parse(AsmParser& parser, Type type) {
9168 if (failed(parser.parseLess())) return {};
9169 llvm::SmallVector<int64_t> argTupleIndices;
9170 // The first element of result indices holds the aliased result index and the
9171 // remaining elements are the result tuple indices.
9172 llvm::SmallVector<int64_t> resultIndices;
9173 bool isMustAlias = false;
9174
9175 // This conveys to parseStruct that keyword "must_alias" (3rd field) is not
9176 // followed by a "=", but other fields are.
9177 llvm::SmallVector<bool, 3> parseEqual = {true, true, false};
9178
9179 if (failed(parseStruct(parser, {kArgTupleIndices, kResult, kMustAlias},
9180 {[&]() { return parseDims(parser, argTupleIndices); },
9181 [&]() {
9182 // Since the first element is the index of result,
9183 // at least one element is expected.
9184 return parseDimsWithMinimumElements(
9185 parser, resultIndices, /*minElements=*/1);
9186 },
9187 [&]() {
9188 // always succeeds if the keyword "must_alias" was
9189 // parsed
9190 isMustAlias = true;
9191 return success();
9192 }},
9193 parseEqual))) {
9194 parser.emitError(parser.getCurrentLocation())
9195 << "failed parsing argument-result alias attribute";
9196 return {};
9197 }
9198
  // The `result_index` field is mandatory: guard against it being absent,
  // which would otherwise make the access to resultIndices[0] below undefined
  // behavior.
  if (resultIndices.empty()) {
    parser.emitError(parser.getCurrentLocation())
        << "expected `result_index` to be specified";
    return {};
  }
9199   int64_t resultIndex = resultIndices[0];
9200   auto resultTupleIndices =
9201       ArrayRef<int64_t>{resultIndices.begin() + 1, resultIndices.end()};
9202
9203 return ArgResultAliasAttr::get(parser.getContext(), argTupleIndices,
9204 resultIndex, resultTupleIndices, isMustAlias);
9205 }
9206
9207 // Returns the element type pointed to by `indices` in the tuple type `type`.
9208 // If the indices are invalid, returns a null Type.
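// For example, indices [1, 0] applied to
// tuple<tensor<f32>, tuple<tensor<i32>, tensor<f32>>> yield tensor<i32>.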
getTypeFromTupleIndices(Type type,ArrayRef<int64_t> indices)9209 static Type getTypeFromTupleIndices(Type type, ArrayRef<int64_t> indices) {
9210 Type current = type;
9211 for (auto index : indices) {
9212 TupleType tupleType = current.dyn_cast<TupleType>();
9213 if (!tupleType || index >= static_cast<int64_t>(tupleType.size()))
9214 return {};
9215 current = tupleType.getType(index);
9216 }
9217 return current;
9218 }
9219
verifyArgResultAliasAttr(StringAttr attrName,ArgResultAliasAttr aliasAttr,unsigned argIndex,Operation * op)9220 static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
9221 ArgResultAliasAttr aliasAttr,
9222 unsigned argIndex,
9223 Operation* op) {
9224 // The attribute can only be applied to function-like operations.
9225 if (!isa<mlir::FunctionOpInterface>(op))
9226 return op->emitOpError() << "attribute " << attrName
9227 << " can only be used on function-like operations";
9228
9229 // Verify there are no negative indices.
9230 auto tupleIndices = llvm::concat<const int64_t>(
9231 aliasAttr.getArgTupleIndices(), aliasAttr.getResultTupleIndices());
9232 if (llvm::any_of(tupleIndices, [](const int64_t val) { return val < 0; }) ||
9233 aliasAttr.getResultIndex() < 0)
9234 return op->emitOpError()
9235 << "attribute " << attrName
9236 << " expects all argument and result indices to be >= 0";
9237
9238 // Verify that the result index is not out of range. Since the attribute is a
9239 // function argument attribute, the argument index is always correct when this
9240 // verifier is called.
9241 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
9242 ArrayRef<Type> argTypes = funcOp.getArgumentTypes();
9243 ArrayRef<Type> resultTypes = funcOp.getResultTypes();
9244 if (aliasAttr.getResultIndex() >= static_cast<int64_t>(resultTypes.size()))
9245 return op->emitOpError()
9246 << "attribute " << attrName
9247 << " result index is out of range, must be <" << resultTypes.size();
9248
9249 // Verify that argument and result types pointed to by the indices are valid
9250 // and compatible.
9251 Type argType = getTypeFromTupleIndices(argTypes[argIndex],
9252 aliasAttr.getArgTupleIndices());
9253 if (!argType)
9254 return op->emitOpError()
9255 << "attribute " << attrName << " argument tuple indices are invalid";
9256 Type resultType =
9257 getTypeFromTupleIndices(resultTypes[aliasAttr.getResultIndex()],
9258 aliasAttr.getResultTupleIndices());
9259 if (!resultType)
9260 return op->emitOpError()
9261 << "attribute " << attrName << " result tuple indices are invalid";
9262
9263 if (failed(mlir::verifyCompatibleShape(argType, resultType)) ||
9264 getElementTypeOrSelf(argType) != getElementTypeOrSelf(resultType))
9265 return op->emitOpError() << "attribute " << attrName
9266 << " aliases do not have compatible types, "
9267 << argType << " vs. " << resultType;
9268 return success();
9269 }
9270
9271 //===----------------------------------------------------------------------===//
9272 // Type utilities
9273 //===----------------------------------------------------------------------===//
9274
getExpressedTypeOrSelf(Type type)9275 Type getExpressedTypeOrSelf(Type type) {
9276 auto quantType = type.dyn_cast<quant::QuantizedType>();
9277 return quantType ? quantType.getExpressedType() : type;
9278 }
9279
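// Returns success if the two types have compatible shapes, additionally
// taking TypeExtensionsAttr bounds into account: a static dimension of one
// type is incompatible with the other type's bound when it exceeds that
// bound. For example, tensor<3xf32> is compatible with a tensor<?xf32>
// bounded by [4], but tensor<5xf32> is not.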
verifyCompatibleShapeWithBounds(Type type1,Type type2)9280 LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
9281 if (failed(verifyCompatibleShape(type1, type2))) return failure();
9282
9283 // Verify shapes against bounds
9284 auto isCompatible = [](ArrayRef<int64_t> shape,
9285 TypeExtensionsAttr extensionAttr) {
9286 if (shape.empty() || !extensionAttr) return true;
9287 auto bounds = extensionAttr.getBounds();
9288 for (auto [dim_size, bound] : llvm::zip(shape, bounds)) // NOLINT
9289 if (bound != ShapedType::kDynamicSize && bound < dim_size) return false;
9290 return true;
9291 };
9292
9293 RankedTensorType rankedType1 = type1.dyn_cast<RankedTensorType>();
9294 RankedTensorType rankedType2 = type2.dyn_cast<RankedTensorType>();
9295 if (rankedType1 && rankedType2) {
9296 TypeExtensionsAttr extensionAttr1 =
9297 rankedType1.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>();
9298 TypeExtensionsAttr extensionAttr2 =
9299 rankedType2.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>();
9300 return LogicalResult::success(
9301 isCompatible(rankedType1.getShape(), extensionAttr2) &&
9302 isCompatible(rankedType2.getShape(), extensionAttr1));
9303 }
9304 return success();
9305 }
9306
isCompatibleForMhloTypeInference(Type tp1,Type tp2)9307 bool isCompatibleForMhloTypeInference(Type tp1, Type tp2) {
9308 // Dynamism: We don't require shapes to be the same, we only require them
9309 // to be compatible, which means that:
9310 // 1) At least one of the shapes is unranked.
9311 // 2) Or both shapes have the same rank and their dimensions are compatible,
9312 // i.e. for each pair of corresponding dimensions:
9313 // 2.1) At least one of the dimensions is dynamic,
9314 // 2.2) Or both dimensions are equal.
9315 // These relaxed rules simplify the implementation of type inference, allowing
9316 // ops with partially inferred types to pass verification.
9317 auto stp1 = tp1.dyn_cast<ShapedType>();
9318 auto stp2 = tp2.dyn_cast<ShapedType>();
9319 if (stp1 && stp2) {
9320 return succeeded(verifyCompatibleShapeWithBounds(stp1, stp2)) &&
9321 isCompatibleForMhloTypeInference(stp1.getElementType(),
9322 stp2.getElementType());
9323 }
9324
9325 // Quantization: In the most general case, we allow any combination of
9326 // quantized/non-quantized across any combination of operands/results,
9327 // and some differences in quantization parameters across operands/results.
9328 // Individual ops may introduce additional constraints.
9329 auto qtp1 = tp1.dyn_cast<quant::QuantizedType>();
9330 auto qtp2 = tp2.dyn_cast<quant::QuantizedType>();
9331 if (qtp1 && qtp2) {
9332 if (qtp1.getStorageType() != qtp2.getStorageType() ||
9333 qtp1.getStorageTypeMin() != qtp2.getStorageTypeMin() ||
9334 qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax())
9335 return false;
9336 }
9337 auto etp1 = getExpressedTypeOrSelf(tp1);
9338 auto etp2 = getExpressedTypeOrSelf(tp2);
9339
9340 // Sparsity: In the most general case, we allow any combination of
9341 // sparsity/denseness across any combination of operands/results, as well as
9342 // differences in sparsity encodings for operands and results.
9343 // Individual ops may introduce additional constraints.
9344 // No additional code is needed to check this because of how sparsity is
9345 // currently implemented.
9346
9347 // Default case: Unless dynamism, quantization and/or sparsity are involved,
9348 // the types are required to be exactly equal.
9349 return etp1 == etp2;
9350 }
9351
9352 //===----------------------------------------------------------------------===//
9353 // Builder utilities
9354 //===----------------------------------------------------------------------===//
9355
9356 // Builds the region `body` for mhlo.sort's comparator: for each type in
9357 // `element_types`, creates two block arguments (one for lhs, one for rhs) and
9358 // generates an mhlo.compare op comparing them with the given `direction`.
9359 //
9360 // Note that this currently only compares the first pair of block
9361 // arguments.
buildSortComparisonBody(llvm::ArrayRef<Type> elementTypes,ComparisonDirection direction,llvm::Optional<StringRef> compareType,Region * body,OpBuilder * builder)9362 static void buildSortComparisonBody(llvm::ArrayRef<Type> elementTypes,
9363 ComparisonDirection direction,
9364 llvm::Optional<StringRef> compareType,
9365 Region* body, OpBuilder* builder) {
9366   OpBuilder::InsertionGuard insertionPointGuard(*builder);
9367
9368 Location loc = body->getLoc();
9369 Block* block = builder->createBlock(body);
9370 // Add two arguments for each element type.
9371 for (Type elementType : elementTypes) {
9372 TensorType tensorType = RankedTensorType::get({}, elementType);
9373 block->addArguments({tensorType, tensorType},
9374 SmallVector<Location, 2>(2, loc));
9375 }
9376
9377 ComparisonType typeAttr;
9378 if (compareType)
9379 typeAttr = symbolizeComparisonType(*compareType).value();
9380 else
9381 typeAttr = ComparisonType::NOTYPE;
9382 Value compare = builder->create<mhlo::CompareOp>(
9383 loc, block->getArgument(0), block->getArgument(1), direction, typeAttr);
9384
9385 builder->create<mhlo::ReturnOp>(loc, compare);
9386 }
9387
createSortOp(PatternRewriter * rewriter,const Location & loc,const llvm::ArrayRef<Value> & operands,const llvm::ArrayRef<Type> & elementTypes,int64_t dimension,bool isStable,ComparisonDirection direction)9388 SortOp createSortOp(PatternRewriter* rewriter, const Location& loc,
9389 const llvm::ArrayRef<Value>& operands,
9390 const llvm::ArrayRef<Type>& elementTypes, int64_t dimension,
9391 bool isStable, ComparisonDirection direction) {
9392 assert(!operands.empty() && "No operands to sort");
9393 // Create the sort op.
9394 auto sortOp =
9395 rewriter->create<mhlo::SortOp>(loc, operands, dimension, isStable);
9396
9397 // Use TOTALORDER comparison type instead of the default comparison if the
9398   // element type is a float type.
9399 llvm::Optional<StringRef> compareType = llvm::None;
9400 for (auto const& elementType : elementTypes)
9401 if (elementType.isa<FloatType>()) {
9402 compareType.emplace("TOTALORDER");
9403 break;
9404 }
9405 buildSortComparisonBody(elementTypes, direction, compareType,
9406 &sortOp.comparator(), rewriter);
9407 return sortOp;
9408 }
9409
9410 //===----------------------------------------------------------------------===//
9411 // Shape inference
9412 //===----------------------------------------------------------------------===//
9413
deriveShapeFromOperand(OpBuilder * builder,Operation * op,Value operand,SmallVectorImpl<Value> * reifiedReturnShapes)9414 LogicalResult deriveShapeFromOperand(
9415 OpBuilder* builder, Operation* op, Value operand,
9416 SmallVectorImpl<Value>* reifiedReturnShapes) {
9417 auto shapedTy = operand.getType().dyn_cast<ShapedType>();
9418 if (!shapedTy) {
9419 op->emitOpError() << "operand is not a shaped type";
9420 return failure();
9421 }
9422 reifiedReturnShapes->assign(
9423 {builder->create<shape::ShapeOfOp>(op->getLoc(), operand)});
9424 return success();
9425 }
9426
9427 //===----------------------------------------------------------------------===//
9428 // MHLO Dialect Hooks
9429 //===----------------------------------------------------------------------===//
9430
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)9431 Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
9432 Type type, Location loc) {
9433 auto elementsAttr = value.dyn_cast<ElementsAttr>();
9434   // HLO dialect constants only support ElementsAttr, unlike the standard
9435   // dialect constant, which supports all attributes.
9436 if (!elementsAttr) return nullptr;
9437 // HLO dialect constants require the type of value and result to match.
9438 if (type != elementsAttr.getType()) return nullptr;
9439
9440 return builder.create<mhlo::ConstantOp>(loc, type, elementsAttr);
9441 }
9442
verifyRegionArgAttribute(Operation * op,unsigned,unsigned argIndex,NamedAttribute attr)9443 LogicalResult MhloDialect::verifyRegionArgAttribute(Operation* op,
9444 unsigned /*regionIndex*/,
9445 unsigned argIndex,
9446 NamedAttribute attr) {
9447 if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
9448 if (failed(
9449 verifyArgResultAliasAttr(attr.getName(), aliasAttr, argIndex, op)))
9450 return failure();
9451 }
9452 return success();
9453 }
9454
verifyOperationAttribute(Operation * op,NamedAttribute attr)9455 LogicalResult MhloDialect::verifyOperationAttribute(Operation* op,
9456 NamedAttribute attr) {
9457 if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
9458 if (!isa<mlir::FunctionOpInterface>(op))
9459 return op->emitOpError()
9460 << "attribute " << attr.getName()
9461 << " can only be used on function-like operations";
9462 }
9463 return success();
9464 }
9465
9466 } // namespace mhlo
9467 } // namespace mlir
9468