/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the MHLO dialect.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <set>
#include <unordered_map>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"

namespace mlir {
#include "hlo_patterns.cc.inc"
}  // namespace mlir

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.cc.inc"

namespace mlir {
namespace mhlo {
namespace {
void createArgs(ArrayRef<OpAsmParser::UnresolvedOperand> operands,
                ArrayRef<Type> types,
                SmallVector<OpAsmParser::Argument>& args) {
  for (auto argAndType : llvm::zip(operands, types)) {
    auto& arg = args.emplace_back();
    arg.ssaName = std::get<0>(argAndType);
    arg.type = std::get<1>(argAndType);
  }
}

const auto hasDuplicates = [](SmallVector<int64_t>& nums) {
  if (!llvm::is_sorted(nums)) std::sort(nums.begin(), nums.end());
  auto* last = std::unique(nums.begin(), nums.end());
  return last != nums.end();
};
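// For example, given nums = {3, 1, 3}, the lambda sorts it to {1, 3, 3} and
// returns true; note that it may reorder the argument in place.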

//===----------------------------------------------------------------------===//
// Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===//

// This is an upper limit on how many elements can be folded by an op folder.
// This limit doesn't apply to some special cases like adding a zero,
// multiplying by one, or doing many operations with splats.
constexpr int64_t kFoldOpEltLimit = 65536;

// Clamps value to the range [lower, upper].  Requires lower <= upper.
template <typename T>
static T clamp(const T& value, const T& lower, const T& upper) {
  assert(lower <= upper);
  return std::max(lower, std::min(value, upper));
}

// Verifies that the dimension attribute for the op correctly indexes into the
// operand or result shape.
template <typename OpT>
static LogicalResult verifyDimAttr(OpT op) {
  int64_t rank = -1;
  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else {
    return success();
  }

  int64_t dim = op.dimension();
  if (dim < 0 || dim >= rank)
    return op.emitOpError() << "requires dimension attribute in range [0, "
                            << rank << "); found (" << dim << ")";
  return success();
}

// Given the start indices and slice sizes for a dynamic-slice that can be
// converted to a static slice, returns the limits for the static slice.
DenseIntElementsAttr buildSliceLimits(DenseIntElementsAttr startIndices,
                                      DenseIntElementsAttr sliceSizes,
                                      Builder* builder) {
  SmallVector<int64_t, 4> sliceLimits;
  for (int64_t i = 0; i < sliceSizes.getNumElements(); ++i) {
    int64_t startIndex = startIndices.getValues<IntegerAttr>()[i].getInt();
    int64_t sliceSize = sliceSizes.getValues<IntegerAttr>()[i].getInt();
    sliceLimits.push_back(startIndex + sliceSize);
  }
  return builder->getI64TensorAttr(sliceLimits);
}
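// For example, startIndices dense<[1, 2]> with sliceSizes dense<[2, 3]>
// yields the limits dense<[3, 5]>, i.e. limit[i] = start[i] + size[i].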

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
                                Region& region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block* block = &region.front();
  Operation* terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.mergeBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
}

#include "mhlo_canonicalize.inc"

// Check if the dimension size is dynamic.
inline static bool isDynamicDimSize(int64_t val) {
  return val == ShapedType::kDynamicSize;
}

// Common shape function helper for RngNormal and RngUniform.
static LogicalResult rngInferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  if (operands.size() != 3)
    return emitOptionalError(location, "expected 3 operands");

  SmallVector<int64_t> shapeVector;
  Value shapeOperand = operands[2];
  auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
  Type elementType = getElementTypeOrSelf(operands[1]);

  // Operand `shape` (1D by ODS) may or may not be a constant. If `shape` is:
  // 1. not constant and has a dynamic dim (tensor<?x>): infer tensor<*x>.
  // 2. not constant nor dynamic (e.g. tensor<3xi64>): infer tensor<?x?x?x>.
  // 3. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>.

  // Match to check whether the `shape` operand is a constant.
  DenseIntElementsAttr shape;
  if (!matchPattern(shapeOperand, m_Constant(&shape))) {
    int size = shapeOperandType.getDimSize(0);
    if (isDynamicDimSize(size)) {
      inferredReturnShapes.emplace_back(elementType);
      return success();
    }
    shapeVector.resize(size, ShapedType::kDynamicSize);
    inferredReturnShapes.emplace_back(shapeVector, elementType);
    return success();
  }

  // `shape` operand is a constant.
  shapeVector.reserve(shape.size());
  for (const APInt& fp : shape.getValues<APInt>())
    shapeVector.push_back(fp.getSExtValue());
  inferredReturnShapes.emplace_back(shapeVector, elementType);
  return success();
}

// Returns a new scalar integer value having type `type`. Here `type` must be
// an integer or index type.
Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
  if (type == value.getType()) return value;
  assert(type.isIndex() || value.getType().isIndex());
  return b.create<arith::IndexCastOp>(loc, type, value);
}

DenseElementsAttr reshape(DenseElementsAttr attr, ShapedType newType) {
  // TODO(b/232866626): DenseElementsAttr::reshape is broken for bool splats.
  // Once that ticket is fixed, we can remove this conditional.
  if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) {
    auto splatValue = attr.getValues<bool>()[0];
    return DenseElementsAttr::get(newType, {splatValue});
  }
  return attr.reshape(newType);
}

//===----------------------------------------------------------------------===//
// Utilities for verifiers
//===----------------------------------------------------------------------===//

// Convert a 1D dense int64 attribute to a list of values.
SmallVector<int64_t> convertDenseIntAttr(
    llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr) {
  if (!optionalAttr.has_value()) return SmallVector<int64_t>{};

  mlir::DenseIntElementsAttr attr = *optionalAttr;
  auto values = attr.getValues<int64_t>();
  return {values.begin(), values.end()};
}

// Convert a 1D or Nx2 dense int64 attribute to a list of tuples.
FailureOr<SmallVector<std::pair<int64_t, int64_t>>> convertNx2Attribute(
    llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr, Location loc) {
  if (!optionalAttr.has_value())
    return SmallVector<std::pair<int64_t, int64_t>>{};
  mlir::DenseIntElementsAttr attr = *optionalAttr;

  auto attrType = attr.getType().cast<RankedTensorType>();  // ensured by ODS.
  if (attrType.getRank() > 1) {
    if (attrType.getRank() != 2 || attrType.getShape()[1] != 2)
      return (mlir::emitError(loc) << "expects the shape of padding-attribute "
                                      "to be {N, 2}, but got {"
                                   << attrType.getShape() << "}.",
              failure());
  } else {
    // Padding values can be provided as a 1D vector as well.
    if (attr.getValues<int64_t>().size() % 2 != 0)
      return (mlir::emitError(loc)
                  << "expects the padding-entries to have even number of "
                     "elements, but got "
                  << attr.getValues<int64_t>().size() << " elements.",
              failure());
  }

  auto it = attr.getValues<int64_t>().begin();
  SmallVector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
  for (auto& item : out) {
    int64_t first = *it;
    ++it;
    int64_t second = *it;
    ++it;
    item = {first, second};
  }
  return out;
}
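// For example, dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> (or its flattened 1D
// form dense<[0, 1, 2, 3]>) is converted to the pairs {(0, 1), (2, 3)}.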

// If a window with the given bound in some dimension is dilated with the given
// dilation factor in that dimension, then the value returned is the bound for
// the array in that dimension after dilation.
//
// For a 1D array with 3 entries 1, 2, 3, a dilation factor of 2 yields a new
// window with values 1, x, 2, x, 3, where x indicates holes left by the
// dilation. So DilatedBound(3, 2) == 5.
int64_t dilatedBound(int64_t bound, int64_t dilation) {
  assert(bound >= 0 && "The dimension to dilate must be >= 0");
  if (bound == 0) return 0;

  // Suppose the array has three entries 123 and the dilation factor is 4. Then
  // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
  // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
  // add 1 to account for the final input element.
  return (bound - 1) * dilation + 1;
}

// Returns the number of valid positions of a window with the given size and
// stride within an array with the given bound. This is the bound of an output
// array with one element per valid position of the window.
//
// For example, for arguments of (bound=5, window_size=2, stride=2), the
// returned value is 2. There are valid positions at offset 0 and offset 2,
// while offset 4 is not valid since the window's last entry would be at 5,
// which is beyond the bound of 5.
int64_t stridedBound(int64_t bound, int64_t windowSize, int64_t stride) {
  assert(windowSize >= 0 && "Expected window size to be >= 0");
  assert(bound >= 0 && "Expected bound to be >= 0");

  if (bound == 0 || windowSize > bound) return 0;

  // Without considering stride, the maximum valid offset is bound -
  // window_size. Taking stride into account, the valid offsets then have the
  // form q * stride for q = 0, ..., Q such that q * stride <= bound -
  // window_size. This implies that Q equals floor((bound - window_size) /
  // stride). There are Q + 1 valid values of q, yielding the formula below.
  return (bound - windowSize) / stride + 1;
}

// WindowDimension describes how the kernel window moves across the base area
// in a particular dimension, as in operations such as convolution: the window
// is moved across a base area, and for each position of the window a
// computation is performed. The fields below describe the window and the
// movement of the window across the base area.
struct WindowDimension {
  int64_t size = 0;
  int64_t stride = 1;
  int64_t paddingLow = 0;
  int64_t paddingHigh = 0;
  int64_t windowDilation = 1;
  int64_t baseDilation = 1;
  bool windowReversal = false;
};
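// For example, a convolution spatial dimension with window_strides = [2],
// padding = [[1, 1]], and rhs_dilation = [2] corresponds to a WindowDimension
// with {stride = 2, paddingLow = 1, paddingHigh = 1, windowDilation = 2};
// lhs_dilation maps to baseDilation and rhs_dilation to windowDilation (see
// verifyWindowAttributesAndInferWindowDimensions below).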

// Verifies various properties of window-attributes (viz., stride, padding,
// lhs_dilation and rhs_dilation) and collects all the window-attributes for
// each kernel spatial dimension.
FailureOr<SmallVector<WindowDimension>>
verifyWindowAttributesAndInferWindowDimensions(
    ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides,
    ArrayRef<std::pair<int64_t, int64_t>> padding,
    ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
    Location loc) {
  const auto verifySize = [&](const size_t attrSize,
                              StringRef attrName) -> LogicalResult {
    if (attrSize == 0 || attrSize == windowDimensions.size()) return success();
    return mlir::emitError(loc)
           << "expects " << attrName
           << " to have same dimension-size as size of "
              "window dimensions "
              "("
           << windowDimensions.size() << "), but got: " << attrSize << ".";
  };

  if (failed(verifySize(windowStrides.size(), "window-strides")))
    return failure();
  if (failed(verifySize(lhsDilation.size(), "base-dilation factors")))
    return failure();
  if (failed(verifySize(rhsDilation.size(), "window-dilation factors")))
    return failure();
  if (failed(verifySize(padding.size(), "padding-entries"))) return failure();

  SmallVector<WindowDimension> window(windowDimensions.size());
  for (size_t i = 0; i < windowDimensions.size(); i++) {
    WindowDimension& dim = window[i];

    dim.size = windowDimensions[i];
    if (!isDynamicDimSize(dim.size) && dim.size <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive value for " << i
                  << "-th window dimension, but got " << dim.size << ".",
              failure());

    if (!windowStrides.empty()) dim.stride = windowStrides[i];
    if (dim.stride <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive stride for " << i
                  << "-th window dimension, but got " << dim.stride << ".",
              failure());

    if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i];
    if (dim.baseDilation <= 0)
      return (mlir::emitError(loc) << "expects window to have positive base "
                                      "dilation factor for "
                                   << i << "-th window dimension, but got "
                                   << dim.baseDilation << ".",
              failure());

    if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i];
    if (dim.windowDilation <= 0)
      return (mlir::emitError(loc) << "expects window to have positive window "
                                      "dilation factor for "
                                   << i << "-th window dimension, but got "
                                   << dim.windowDilation << ".",
              failure());

    if (!padding.empty()) {
      dim.paddingLow = padding[i].first;
      dim.paddingHigh = padding[i].second;
    }
  }

  return window;
}

// Infer the shape of the output window.
//  For each dimension d,
//    output-window-shape[d] =
//        stridedBound(padding_low + dilatedBound(base_shape[d]) + padding_high,
//                     dilatedBound(window_shape[d]), stride)
//      where (padding_low, padding_high) is the padding-pair for d.
SmallVector<int64_t> inferWindowOutputShape(
    const ArrayRef<int64_t> baseShape, const ArrayRef<WindowDimension> window) {
  assert(baseShape.size() == window.size() &&
         "Size of window dimensions must match the size of base shape.");

  SmallVector<int64_t> outputDimensions(window.size());
  for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i) {
    if (isDynamicDimSize(baseShape[i]) || isDynamicDimSize(window[i].size)) {
      outputDimensions[i] = ShapedType::kDynamicSize;
    } else {
      const auto& dim = window[i];

      const int64_t dilatedBase = dilatedBound(baseShape[i], dim.baseDilation);
      const int64_t paddedDilatedBase =
          dim.paddingLow + dilatedBase + dim.paddingHigh;
      const int64_t dilatedWindow = dilatedBound(dim.size, dim.windowDilation);

      outputDimensions[i] =
          stridedBound(paddedDilatedBase, dilatedWindow, dim.stride);
    }
  }

  return outputDimensions;
}
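// Worked example: for baseShape = [4] and a window dimension with size = 2,
// stride = 2, paddingLow = paddingHigh = 1, baseDilation = 2, and
// windowDilation = 1:
//   dilatedBase   = dilatedBound(4, 2)    = 7
//   paddedBase    = 1 + 7 + 1             = 9
//   dilatedWindow = dilatedBound(2, 1)    = 2
//   output dim    = stridedBound(9, 2, 2) = (9 - 2) / 2 + 1 = 4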

// Return true if type1 and type2 are tensors and have the same
// element-type, else return false. With float element-types, ignore comparing
// floating-point precision if ignoreFpPrecision is true.
bool tensorsHaveSameElType(Type type1, Type type2, bool ignoreFpPrecision) {
  auto tensorTy1 = type1.dyn_cast<TensorType>();
  auto tensorTy2 = type2.dyn_cast<TensorType>();

  if (!tensorTy1 || !tensorTy2) return false;

  if (ignoreFpPrecision && tensorTy1.getElementType().isa<FloatType>() &&
      tensorTy2.getElementType().isa<FloatType>())
    return true;

  return tensorTy1.getElementType() == tensorTy2.getElementType();
}

// Return true if type1 and type2 are shape-compatible and have same element
// type. If 'ignoreFpPrecision' is true, then allow floats with different
// precisions while checking element-types.
bool compatibleShapeAndElementType(Type type1, Type type2,
                                   bool ignoreFpPrecision = false) {
  if (failed(verifyCompatibleShape(type1, type2))) return false;
  return tensorsHaveSameElType(type1.cast<ShapedType>(),
                               type2.cast<ShapedType>(), ignoreFpPrecision);
}
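// For example, tensor<2x?xf32> and tensor<2x3xf32> are compatible (a dynamic
// dimension matches any static size), while tensor<2x3xf32> vs tensor<2x3xf16>
// is accepted only when ignoreFpPrecision is true.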

LogicalResult verifyReducerShape(
    Location loc, Block& block, ArrayRef<TensorType> inputArgTypes,
    ArrayRef<TensorType> initValueTypes, int64_t numInputs,
    ArrayRef<int64_t> allowedDimensions, bool allInputsUnranked,
    SmallVectorImpl<TensorType>& accumulatorSubShapes) {
  // Check that the number of reduction-region arguments matches with that of
  // reduce-op's arguments.
  if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2)
    return mlir::emitError(loc)
           << "Reduction-region must take " << numInputs * 2
           << " parameters, but takes " << block.getArguments().size()
           << " parameter(s)";

  // Check if the reduction-region produces non-zero outputs.
  if (block.getTerminator()->getOperands().empty())
    return mlir::emitError(loc)
           << "The reduction-region expected to return some value(s)";

  // Check that the reduction-region returns a list of tensors.
  // The number of result-tensors must match the `numInputs`.
  if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) !=
      numInputs)
    return mlir::emitError(loc)
           << "Reduction-region here must produce " << numInputs
           << " tensors, but produces "
           << block.getTerminator()->getOperands().size() << " instead";

  for (Value retOperand : block.getTerminator()->getOperands()) {
    auto tensorTy = retOperand.getType().dyn_cast<TensorType>();
    if (!tensorTy)
      return mlir::emitError(loc) << "Reduction-region here must produce "
                                     "tensor-typed result(s), but "
                                     "produces "
                                  << retOperand.getType() << " instead";

    accumulatorSubShapes.push_back(tensorTy);
  }

  // Consider typical reduce-* op syntax:
  //
  //      op(I(i), V(j)):
  //       block(BI(i), BV(j)):
  //         ... some computation ...
  //         return(R(i))
  //
  // where
  //  I(i)  : i-th input of op
  //  V(j)  : j-th init-value of op
  //  BI(i) : i-th input of reducer-function
  //  BV(j) : j-th init-value of reducer-function
  //  R(i)  : i-th return-type
  //
  //  Note that: |I(i)| == |V(j)| == |BI(i)| == |BV(j)| == |R(i)|
  //
  //  Here are the type-constraints among V(j), BI(i), BV(j), and R(i).
  //    C1 : Check that BI(i) and R(i) have same shape and element-type.
  //    C2 : Check that BV(j) and R(i) have same shape and element-type.
  //    C3 : Check that V(j) and R(i) have same shape and element-type.
  //
  //  From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all
  //  have compatible shapes and element-types.
  //  The next check, C4, adds constraints on how the type of I(i) is related
  //  to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j):
  //
  //  C4.1 : Check that I(i) and BV(j) have same element-type.
  //  C4.2 : Check that shape of BV(j) is a 'sub-sequence' of
  //         'allowedDimensions'. 'allowedDimensions' is a list of dimensions
  //         which any of BI(i), BV(j), and R(i) is allowed to have.
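  // For example, with allowedDimensions = [4, 5, 6], argument shapes [4, 6]
  // and [5] are valid sub-sequences, while [6, 4] is not: the matching below
  // is order-preserving.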
  for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    // Check C1.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       block.getArgument(inputIdx).getType()))
      return mlir::emitError(loc)
             << "The type of reduction-region's parameter at index " << inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C2.
    if (!compatibleShapeAndElementType(
            accumulatorSubShapes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(),
            /*ignoreFpPrecision=*/true))
      return mlir::emitError(loc)
             << "The type of reduction-region's parameter at index "
             << numInputs + inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(numInputs + inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C3.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       initValueTypes[inputIdx],
                                       /*ignoreFpPrecision=*/true))
      return mlir::emitError(loc)
             << "The type of reduction-region's result type at index "
             << inputIdx
             << " differs from the op's corresponding init-value type: "
             << accumulatorSubShapes[inputIdx] << " vs "
             << initValueTypes[inputIdx];

    // Check C4.1.
    if (!tensorsHaveSameElType(
            inputArgTypes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(),
            /*ignoreFpPrecision=*/true))
      return mlir::emitError(loc)
             << "The element-type of reduction-region's argument at index "
             << numInputs + inputIdx << " is expected to be "
             << inputArgTypes[inputIdx].getElementType() << ", but got "
             << block.getArgument(numInputs + inputIdx).getType()
             << " as its type.";

    // Check C4.2.
    Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
    auto blockArgTensorTy = blockArgType.cast<TensorType>();

    if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success();

    auto argShape = blockArgTensorTy.getShape();
    if (argShape.size() > allowedDimensions.size())
      return mlir::emitError(loc)
             << "The rank of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is expected to be <= " << allowedDimensions.size() << ", got "
             << argShape.size();

    int64_t argShapeIdx = 0;
    for (int64_t outputShapeIdx = 0;
         outputShapeIdx < static_cast<int64_t>(allowedDimensions.size()) &&
         argShapeIdx < static_cast<int64_t>(argShape.size());
         outputShapeIdx++)
      if (allowedDimensions[outputShapeIdx] == argShape[argShapeIdx])
        argShapeIdx++;

    if (argShapeIdx != static_cast<int64_t>(argShape.size()))
      return mlir::emitError(loc)
             << "The shape of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is not compatible with that of reduce-op's input-parameter "
                "at index "
             << inputIdx;
  }

  return success();
}

unsigned potentiallyComplexBitwidth(Type type) {
  auto complexTy = type.dyn_cast<ComplexType>();
  return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
                   : type.getIntOrFloatBitWidth();
}
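// For example, complex<f32> reports a bitwidth of 2 * 32 = 64, while a plain
// f32 reports 32.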
}  // namespace

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

void AllReduceOp::build(
    ::mlir::OpBuilder& ods_builder, ::mlir::OperationState& ods_state,
    ::mlir::Type result_type, ::mlir::Value operand,
    ::mlir::DenseIntElementsAttr replica_groups,
    /*optional*/ ::mlir::mhlo::ChannelHandleAttr channel_handle) {
  AllReduceOp::build(ods_builder, ods_state, result_type, operand,
                     replica_groups, channel_handle, nullptr);
}

//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//

LogicalResult ReduceScatterOp::verify() {
  if (failed(mlir::hlo::verifyReplicaGroups(*this, /*is_uniform_sized=*/true)))
    return failure();
  auto operandType = operand().getType().cast<TensorType>();
  bool operandTypeRanked = operandType.isa<RankedTensorType>();
  Block& block = computation().front();
  SmallVector<TensorType> accumulatorSubshapes;
  if (failed(verifyReducerShape(
          this->getLoc(), block, {operandType},
          {RankedTensorType::get({}, operandType.getElementType())},
          /*numInputs=*/1, /*allowedDimensions=*/{},
          /*allInputsUnranked=*/!operandTypeRanked, accumulatorSubshapes)))
    return failure();

  return mlir::hlo::verifyReduceScatter(
      *this,
      /*operand_types=*/{operand().getType()},
      /*result_types=*/{getType()},
      /*scatter_dimension=*/scatter_dimension());
}

//===----------------------------------------------------------------------===//
// CompatibleOperandsAndResultType
//===----------------------------------------------------------------------===//

// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op)                        \
  LogicalResult Op::inferReturnTypeComponents(                                \
      MLIRContext* context, Optional<Location> location,                      \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      RegionRange regions,                                                    \
      SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {          \
    return inferReturnTypeComponentsFromOperands(context, location, operands, \
                                                 attributes, regions,         \
                                                 inferredReturnShapes);       \
  }

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CopyOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DomainOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ExpOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Expm1Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Log1pOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogisticOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MaxOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MinOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NotOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(OrOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PopulationCountOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PowOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReducePrecisionOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RemOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReverseOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundNearestEvenOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RsqrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftLeftOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightArithmeticOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightLogicalOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SignOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubtractOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp)

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}

// Builds a constant op with the specified attribute `value`.
void ConstantOp::build(OpBuilder& /*builder*/, OperationState& result,
                       Attribute value) {
  Type type;
  if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
    type = elemAttr.getType();
  } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
    // All XLA types must be tensor types. In the build() method, we want to
    // provide more flexibility by allowing attributes of scalar types. But we
    // need to wrap it up with ElementsAttr to construct valid XLA constants.
    type =
        RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType());
    value = DenseElementsAttr::get(type.cast<TensorType>(), value);
  } else if (auto complexAttr = value.dyn_cast<complex::NumberAttr>()) {
    type = RankedTensorType::get(/*shape=*/{},
                                 complexAttr.cast<TypedAttr>().getType());
    value =
        DenseElementsAttr::get(type.cast<TensorType>(), complexAttr.getValue());
  }

  // TODO: support other XLA specific types.
  assert(type && "unsupported attribute type for building mhlo.constant");
  result.types.push_back(type);
  result.addAttribute("value", value);
}
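// For example (hypothetical usage, not part of this file), building
//   builder.create<mhlo::ConstantOp>(loc, builder.getF32FloatAttr(1.0f))
// wraps the scalar f32 attribute in a rank-0 DenseElementsAttr, yielding a
// constant of type tensor<f32>.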

LogicalResult ConstantOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands,
    DictionaryAttr attributes, RegionRange,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  ConstantOpAdaptor adaptor(operands, attributes);
  Type type = adaptor.value().getType();
  inferredReturnTypes.push_back(type);
  return success();
}

bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1) return false;
  auto lhsTy = l.front().cast<TensorType>();
  auto rhsTy = r.front().cast<TensorType>();
  // For tensor types with a uniform-quantized element type, compare using the
  // storage type, since the constant value is stored through the underlying
  // storage type.
  if (auto rhsElemTy =
          rhsTy.getElementType().dyn_cast<quant::QuantizedType>()) {
    rhsTy = getSameShapeTensorType(rhsTy, rhsElemTy.getStorageType());
  }
  return lhsTy == rhsTy;
}

ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
  // Parse the generic form.
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRParen()) return failure();
    if (parser.parseOptionalAttrDict(result.attributes)) return failure();
    if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
        parser.parseArrow())
      return failure();
    Type resultTy;
    if (parser.parseType(resultTy)) {
      return failure();
    }
    result.addTypes(resultTy);
    return success();
  }

  ElementsAttr valueAttr;
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value",
                                              result.attributes)) {
    return failure();
  }
  result.addTypes(valueAttr.getType());
  return success();
}

/// Print a `constant` op.
///
/// op ::= attr-dict $value
///
/// When the `value` and `output` have different types, this falls back to the
/// default operator assembly format.
void ConstantOp::print(::mlir::OpAsmPrinter& p) {
  // If not all types are the same, use generic form.
  if (value().getType() != getType()) {
    p.printGenericOp(getOperation(), /*printOpName=*/false);
    return;
  }

  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
  p << ' ';
  p.printStrippedAttrOrType(valueAttr());
}

//===----------------------------------------------------------------------===//
// CustomCallOp
//===----------------------------------------------------------------------===//

LogicalResult CustomCallOp::verify() {
  // If neither operand nor result layout attributes are specified then there
  // is nothing to verify.
  if (!operand_layouts().has_value() && !result_layouts().has_value())
    return success();

  // Layout constraints for either both operands & results or none should be
  // specified.
  if (operand_layouts().has_value() != result_layouts().has_value())
    return emitOpError() << "Layout attributes should be specified for "
                            "either both operands and results or none.";

  // Helper function to verify types and the corresponding layouts.
  auto verifyTypesAndLayouts =
      [this](TypeRange types, mlir::ArrayAttr layouts,
             const std::string& valueName) -> LogicalResult {
    if (types.size() != layouts.size())
      return emitOpError() << "Number of " << valueName
                           << "s must match the number of " << valueName
                           << " layouts, " << types.size()
                           << " != " << layouts.size();

    for (const auto& indexedTypeAndLayout :
         llvm::enumerate(llvm::zip(types, layouts))) {
      // Get index for more descriptive error message.
      auto index = indexedTypeAndLayout.index();

      auto type = std::get<0>(indexedTypeAndLayout.value());
      auto layout = std::get<1>(indexedTypeAndLayout.value())
                        .cast<DenseIntElementsAttr>();

      if (type.isa<TupleType>())
        return emitOpError() << "Tuple types are not fully supported with "
                                "layout constraints yet";
      auto tensorType = type.dyn_cast<TensorType>();

      // For non-tensor types such as !mhlo.token, the layout should be empty.
      if (!tensorType) {
        if (layout.empty()) continue;
        return emitOpError()
               << "Only tensor types can have non-empty layout: " << valueName
               << " #" << index << " of type " << type << " has layout "
               << layout;
      }

      // For unranked tensors, we cannot verify the compatibility with layout
      // any further.
      if (!tensorType.hasRank()) continue;

      // Layout must be a permutation of [0, N) where N is the rank of the
      // tensor type.
      std::vector<int64_t> range(tensorType.getRank());
      std::iota(range.begin(), range.end(), 0);
      if (tensorType.getRank() != layout.size() ||
          !std::is_permutation(range.begin(), range.end(), layout.begin()))
        return emitOpError() << "incorrect layout " << layout << " for type "
                             << type << ", layout must be a permutation of [0, "
                             << tensorType.getRank() << ")";
    }
    return success();
  };

  // At this point both `operand_layouts` and `result_layouts` are defined.
  ArrayAttr operandLayouts = this->operand_layouts().value();
  ArrayAttr resultLayouts = this->result_layouts().value();

  // Layouts for arbitrarily nested tuples are not fully supported yet.
  //
  // If the result does not have any tuples, then the i-th element of
  // `result_layouts` specifies the layout constraints on the i-th result.
  //
  // For the common case of a single tuple result packing non-tuple values, the
  // i-th element of `result_layouts` specifies the layout for the i-th element
  // of the result tuple.
  TypeRange resultTypes;
  if (getNumResults() == 1 && getResult(0).getType().isa<TupleType>())
    resultTypes = getResult(0).getType().cast<TupleType>().getTypes();
  else
    resultTypes = getResultTypes();

  // Verify that operands and operand layouts match.
  if (failed(
          verifyTypesAndLayouts(getOperandTypes(), operandLayouts, "operand")))
    return failure();

  // Verify that results and result layouts match.
  return verifyTypesAndLayouts(resultTypes, resultLayouts, "result");
}

void CustomCallOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
        effects) {
  // CustomCall has "all possible effects" unless the has_side_effect attribute
  // is present and set to false.
  auto hasSideEffect = (*this)->getAttrOfType<BoolAttr>("has_side_effect");
  if (hasSideEffect && !hasSideEffect.getValue()) return;
  effects.emplace_back(MemoryEffects::Allocate::get());
  effects.emplace_back(MemoryEffects::Free::get());
  effects.emplace_back(MemoryEffects::Write::get());
  effects.emplace_back(MemoryEffects::Read::get());
}

//===----------------------------------------------------------------------===//
// CholeskyOp
//===----------------------------------------------------------------------===//

// The following properties are already enforced by the ODS:
//   P0. a.element_type is floating or complex
// We intend to verify the following properties:
//   P1. The 'a' argument to Cholesky must have rank >= 2, got shape %s
//   P2. The two minor dimensions of 'a' must have equal size, got %s.
LogicalResult CholeskyOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  CholeskyOp::Adaptor adaptor(operands, attributes, regions);
  Type aType = adaptor.a().getType();
  RankedTensorType aRankedType = aType.dyn_cast<RankedTensorType>();
  if (!aRankedType) {
    inferredReturnShapes.emplace_back(
        aType.cast<TensorType>().getElementType());
    return success();
  }

  ArrayRef<int64_t> aShape = aRankedType.getShape();
  if (aShape.size() < 2) {
    return emitOptionalError(
        location, "argument 'a' must have rank >= 2, got shape ", aShape, ".");
  }

  int64_t lastDim = aShape[aShape.size() - 1];
  int64_t penultimateDim = aShape[aShape.size() - 2];
  if (!isDynamicDimSize(lastDim) && !isDynamicDimSize(penultimateDim) &&
      lastDim != penultimateDim) {
    return emitOptionalError(
        location, "minor dimensions of 'a' must have equal size, got shape ",
        aShape, ".");
  }
  inferredReturnShapes.emplace_back(aRankedType.getShape(),
                                    aRankedType.getElementType());
  return success();
}
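// For example, an operand of type tensor<2x4x4xf32> (a batch of two 4x4
// matrices) satisfies both P1 and P2, whereas tensor<4x3xf32> fails P2 and
// tensor<4xf32> fails P1.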

//===----------------------------------------------------------------------===//
// DotOp
//===----------------------------------------------------------------------===//
namespace {
bool dimCompatible(int64_t a, int64_t b) {
  return isDynamicDimSize(a) || isDynamicDimSize(b) || a == b;
}

ShapedType inferDotReturnType(ShapedType lhs, ShapedType rhs) {
  auto elementType = lhs.getElementType();
  if (!lhs.hasRank() || !rhs.hasRank()) {
    return UnrankedTensorType::get(elementType);
  }

  // vector dot vector
  if (1 == lhs.getRank() && 1 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
    return RankedTensorType::get({}, elementType);
  }
  // matrix dot vector
  if (2 == lhs.getRank() && 1 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
    return RankedTensorType::get({lhs.getDimSize(0)}, elementType);
  }
  // vector dot matrix
  if (1 == lhs.getRank() && 2 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
    return RankedTensorType::get({rhs.getDimSize(1)}, elementType);
  }
  // matrix dot matrix
  if (2 == lhs.getRank() && 2 == rhs.getRank() &&
      dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
    int64_t shape[2] = {lhs.getDimSize(0), rhs.getDimSize(1)};
    return RankedTensorType::get(shape, elementType);
  }
  return {};
}
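// For example: tensor<3xf32> dot tensor<3xf32> infers tensor<f32>,
// tensor<2x3xf32> dot tensor<3xf32> infers tensor<2xf32>, and
// tensor<2x3xf32> dot tensor<3x5xf32> infers tensor<2x5xf32>; incompatible
// contracting dimensions produce a null type.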
}  // namespace

LogicalResult DotOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  DotOp::Adaptor op(operands);
  auto lhsType = op.lhs().getType().cast<ShapedType>();
  auto rhsType = op.rhs().getType().cast<ShapedType>();
  inferredReturnTypes.push_back(inferDotReturnType(lhsType, rhsType));
  return success();
}

LogicalResult DotOp::verify() {
  auto lhsType = lhs().getType().cast<ShapedType>();
  auto rhsType = rhs().getType().cast<ShapedType>();
  auto resultType = getType().cast<ShapedType>();
  auto expectReturnType = inferDotReturnType(lhsType, rhsType);
  if (!expectReturnType) {
    return emitError() << "Unexpected operands type: " << lhsType << " and "
                       << rhsType;
  }
  if (resultType.hasRank() && expectReturnType.hasRank()) {
    if (resultType.getShape() != expectReturnType.getShape()) {
      return emitError() << "Unexpected result type: has " << resultType
                         << " but inferred " << expectReturnType
                         << " from operands " << lhsType << " and " << rhsType;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//

LogicalResult DotGeneralOp::verify() {
  auto dimNumbers = this->dot_dimension_numbers();

  ArrayRef<int64_t> lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
  ArrayRef<int64_t> rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
  ArrayRef<int64_t> lhsContractingDims =
      dimNumbers.getLhsContractingDimensions();
  ArrayRef<int64_t> rhsContractingDims =
      dimNumbers.getRhsContractingDimensions();

  if (lhsBatchingDims.size() != rhsBatchingDims.size()) {
    return emitOpError() << "lhs and rhs should have the same number of "
                            "batching dimensions";
  }
  if (lhsContractingDims.size() != rhsContractingDims.size()) {
    return emitOpError() << "lhs and rhs should have the same number of "
                            "contracting dimensions";
  }

  llvm::SmallDenseSet<int64_t> dimSet;

  auto checkDimsDistinct =
      [this](ArrayRef<int64_t> batchingDims, ArrayRef<int64_t> contractingDims,
             llvm::SmallDenseSet<int64_t>& dimSet, llvm::StringRef lhs,
             llvm::StringRef rhs) -> LogicalResult {
    auto dims = llvm::concat<const int64_t>(batchingDims, contractingDims);
    for (auto dim : dims) {
      auto [_, wasInserted] = dimSet.insert(dim);
      if (!wasInserted) {
        return emitOpError() << "has duplicated dimension from " << lhs
                             << " and " << rhs << ": " << dim;
      }
    }
    return success();
  };

  if (failed(checkDimsDistinct(lhsBatchingDims, lhsContractingDims, dimSet,
                               "lhs_batching_dimensions",
                               "lhs_contracting_dimensions"))) {
    return failure();
  }
  dimSet.clear();
  if (failed(checkDimsDistinct(rhsBatchingDims, rhsContractingDims, dimSet,
                               "rhs_batching_dimensions",
                               "rhs_contracting_dimensions"))) {
    return failure();
  }

  auto checkDimsInRange = [this](int64_t rank, ArrayRef<int64_t> dims,
                                 llvm::StringRef dimName) -> LogicalResult {
    auto inRange = [&](int64_t i) -> bool { return 0 <= i && i < rank; };
    const auto* dimsNotInRange =
        std::find_if_not(dims.begin(), dims.end(), inRange);
    if (dimsNotInRange != dims.end()) {
      return emitOpError() << dimName << " value: " << *dimsNotInRange
                           << " is out of range: "
                           << "[0, " << rank << ")";
    }
    return success();
  };

  auto lhsType = this->lhs().getType().dyn_cast<RankedTensorType>();
  auto rhsType = this->rhs().getType().dyn_cast<RankedTensorType>();

  if (lhsType) {
    if (failed(checkDimsInRange(lhsType.getRank(), lhsBatchingDims,
                                "lhs_batching_dimensions")) ||
        failed(checkDimsInRange(lhsType.getRank(), lhsContractingDims,
                                "lhs_contracting_dimensions"))) {
      return failure();
    }
  }
  if (rhsType) {
    if (failed(checkDimsInRange(rhsType.getRank(), rhsBatchingDims,
                                "rhs_batching_dimensions")) ||
        failed(checkDimsInRange(rhsType.getRank(), rhsContractingDims,
                                "rhs_contracting_dimensions"))) {
      return failure();
    }
  }

  if (lhsType && rhsType) {
    // Dimension sizes must be compatible for lhs/rhs.
    auto lhsShape = lhsType.getShape();
    auto rhsShape = rhsType.getShape();

    for (auto [lhs, rhs] : llvm::zip(lhsBatchingDims, rhsBatchingDims)) {
      if (lhsShape[lhs] != rhsShape[rhs]) {
        return emitOpError() << "batching dimension sizes must match for "
                                "lhs/rhs";
      }
    }
    for (auto [lhs, rhs] : llvm::zip(lhsContractingDims, rhsContractingDims)) {
      if (lhsShape[lhs] != rhsShape[rhs]) {
        return emitOpError() << "contracting dimension sizes must match for "
                                "lhs/rhs";
      }
    }
  }
  return success();
}

namespace {
// Handle the generic case of DotGeneral and convert to a regular DotOp.
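// A dot_general with no batching dimensions whose lhs contracts on its last
// dimension and whose rhs contracts on its first dimension is exactly the 2D
// matrix product expressed by mhlo.dot, so it can be rewritten as one.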
struct DotGeneralToDot : public OpRewritePattern<DotGeneralOp> {
  using OpRewritePattern<DotGeneralOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DotGeneralOp dot,
                                PatternRewriter& rewriter) const override {
    auto lhs = dot.lhs();
    auto rhs = dot.rhs();
    auto lhsTy = lhs.getType().cast<ShapedType>();
    auto rhsTy = rhs.getType().cast<ShapedType>();

    if (lhsTy.getRank() != 2) return failure();
    if (rhsTy.getRank() != 2) return failure();

    auto nums = dot.dot_dimension_numbers();
    if (!nums.getLhsBatchingDimensions().empty()) return failure();
    if (!nums.getRhsBatchingDimensions().empty()) return failure();

    auto lhsContract = nums.getLhsContractingDimensions();
    auto rhsContract = nums.getRhsContractingDimensions();
    if (lhsContract.size() != 1 || rhsContract.size() != 1) return failure();

    if (lhsContract.front() != 1) return failure();
    if (rhsContract.front() != 0) return failure();

    rewriter.replaceOpWithNewOp<mhlo::DotOp>(
        dot, dot.getType(), lhs, rhs, dot.precision_config().value_or(nullptr));

    return success();
  }
};
}  // namespace

void DotGeneralOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                               MLIRContext* context) {
  results.add<DotGeneralToDot>(context);
}

LogicalResult DotGeneralOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  auto lhsType = lhs().getType().dyn_cast<ShapedType>();
  auto rhsType = rhs().getType().dyn_cast<ShapedType>();
  if (!lhsType || !rhsType) {
    return failure();
  }

  Adaptor adaptor(operands);
  auto dimNumbers = dot_dimension_numbers();
  SmallVector<Value> dimensions;
  for (const int64_t lhsDim : dimNumbers.getLhsBatchingDimensions()) {
    dimensions.push_back(
        builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), lhsDim));
  }

  for (int64_t i = 0; i < lhsType.getRank(); i++) {
    if (!llvm::is_contained(dimNumbers.getLhsContractingDimensions(), i) &&
        !llvm::is_contained(dimNumbers.getLhsBatchingDimensions(), i)) {
      dimensions.push_back(
          builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), i));
    }
  }
  for (int64_t i = 0; i < rhsType.getRank(); i++) {
    if (!llvm::is_contained(dimNumbers.getRhsContractingDimensions(), i) &&
        !llvm::is_contained(dimNumbers.getRhsBatchingDimensions(), i)) {
      dimensions.push_back(
          builder.create<tensor::DimOp>(getLoc(), adaptor.rhs(), i));
    }
  }

  reifiedReturnShapes.push_back(
      builder.create<tensor::FromElementsOp>(getLoc(), dimensions));
  return success();
}
1220 
1221 //===----------------------------------------------------------------------===//
1222 // FftOp
1223 //===----------------------------------------------------------------------===//
1224 
1225 // We intend to verify the following properties
1226 // P1. 1 <= rank <= 3
1227 // P2. Element types agree with fft_type
1228 // P3. Operand shape dimensions agree with fft_length for the given fft_type
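// An illustrative example of P3 (the shapes are made up): an RFFT with
// fft_length = [16] over a tensor<8x16xf32> operand infers
// tensor<8x9xcomplex<f32>>, since the innermost result dimension becomes
// fft_length[-1]/2 + 1 = 9; the inverse IRFFT with fft_length = [16] maps
// tensor<8x9xcomplex<f32>> back to tensor<8x16xf32>.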
1229 LogicalResult FftOp::inferReturnTypeComponents(
1230     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
1231     DictionaryAttr attributes, RegionRange regions,
1232     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1233   FftOp::Adaptor adaptor(operands, attributes, regions);
1234   auto fftLength = adaptor.fft_length().getValues<int64_t>();
1235   int64_t fftRank = fftLength.size();
1236 
1237   // P1.
1238   if (fftRank > 3 || fftRank < 1) {
1239     return emitOptionalError(location, "rank must be between 1 and 3, but got ",
1240                              fftRank, ".");
1241   }
1242 
1243   // P2. Element type agreement
1244   // FFT : C -> C
1245   // IFFT : C -> C
1246   // RFFT : R -> C
1247   // IRFFT : C -> R
1248   auto fftType = adaptor.fft_type();
1249   auto operandType = adaptor.operand().getType().cast<TensorType>();
1250   Type operandElementType = operandType.getElementType();
1251   // Check the input element type and infer return element type
1252   if (fftType == FftType::RFFT) {
1253     if (!operandElementType.isF32() && !operandElementType.isF64()) {
1254       return emitOptionalError(
1255           location, "RFFT requires f32 or f64 input type, but is given ",
1256           operandElementType, ".");
1257     }
1258   } else {
1259     if (!operandElementType.isa<ComplexType>()) {
1260       return emitOptionalError(
1261           location, stringifyFftType(fftType),
1262           " takes a complex tensor as input, but is given ", operandType, ".");
1263     }
1264   }
1265   // Generate the output element type
1266   Type resultElementType = operandElementType;
1267   if (fftType == FftType::RFFT) {  // RFFT : R -> C
1268     resultElementType = ComplexType::get(resultElementType);
1269   } else if (fftType == FftType::IRFFT) {  // IRFFT : C -> R
1270     resultElementType = operandElementType.cast<ComplexType>().getElementType();
1271   }
1272 
1273   // P3. Check input shape and infer return shape
1274   operandType = operandType.dyn_cast<RankedTensorType>();
1275   if (!operandType) {
1276     inferredReturnShapes.emplace_back(resultElementType);
1277     return success();
1278   }
1279   auto operandShape = operandType.getShape();
1280   if (static_cast<int64_t>(operandShape.size()) < fftRank) {
1281     return emitOptionalError(
1282         location, "operand rank must not be less than fft rank of ", fftRank,
1283         " for operand of type ", operandType, ".");
1284   }
1285 
1286   SmallVector<int64_t> resultShape = to_vector(operandShape);
1287 
1288   if (fftType == FftType::RFFT) {
1289     auto shapeBack = operandShape.take_back(fftRank);
1290     for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
1291       if (operandDim != fftDim) {
1292         return emitOptionalError(
1293             location,
1294             "RFFT requires innermost dimensions match fft_length. Got: ",
1295             operandShape, " but wanted ", fftLength, ".");
1296       }
1297     }
1298     if (fftLength[fftRank - 1] != 0) {
1299       resultShape[resultShape.size() - 1] = fftLength[fftRank - 1] / 2 + 1;
1300     }
1301   }
1302   if (fftType == FftType::IRFFT) {
1303     auto shapeBack = operandShape.take_back(fftRank).drop_back();
1304     for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
1305       if (operandDim != fftDim) {
1306         return emitOptionalError(location,
1307                                  "IRFFT requires non-final dimensions "
1308                                  "match fft_length. Got: ",
1309                                  operandShape, " but wanted ", fftLength,
1310                                  ", and ", operandDim, " != ", fftDim, ".");
1311       }
1312     }
1313     if ((operandShape[operandShape.size() - 1] != 0 ||
1314          fftLength[fftRank - 1] != 0) &&
1315         operandShape[operandShape.size() - 1] != fftLength[fftRank - 1] / 2 + 1)
1316       return emitOptionalError(location,
1317                                "IRFFT requires innermost dimension match "
1318                                "fft_length[-1]/2+1. Got: ",
1319                                operandShape, " but fft_length is ", fftLength,
1320                                ".");
1321     resultShape[resultShape.size() - 1] = fftLength[fftRank - 1];
1322   }
1323 
1324   inferredReturnShapes.emplace_back(resultShape, resultElementType);
1325   return success();
1326 }
1327 
1328 //===----------------------------------------------------------------------===//
1329 // GatherOp
1330 //===----------------------------------------------------------------------===//
1331 
1332 // Converts gather ops to slice ops in case we have a single set of constant
1333 // indices.
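// A sketch of the rewrite (shapes and attribute values are illustrative):
//
//   %idx = mhlo.constant dense<[1, 2]> : tensor<2xi32>
//   %0 = "mhlo.gather"(%operand, %idx) {
//          dimension_numbers = #mhlo.gather<
//            offset_dims = [0, 1],
//            start_index_map = [0, 1],
//            index_vector_dim = 0>,
//          slice_sizes = dense<[2, 3]> : tensor<2xi64>, ...}
//        : (tensor<5x6xf32>, tensor<2xi32>) -> tensor<2x3xf32>
//
// becomes an mhlo.slice with starts [1, 2], limits [3, 5], and unit strides.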
1334 struct GatherSlice : public OpRewritePattern<GatherOp> {
1335   using OpRewritePattern<GatherOp>::OpRewritePattern;
1336 
1337   LogicalResult matchAndRewrite(GatherOp gather,
1338                                 PatternRewriter& rewriter) const override {
1339     DenseIntElementsAttr index;
1340     if (!matchPattern(gather.start_indices(), m_Constant(&index)))
1341       return failure();
1342 
1343     const auto& dnums = gather.dimension_numbers();
1344     if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1)
1345       return failure();
1346 
1347     // TODO(tberghammer): Remove this once the verifier catches this case,
1348     // which is invalid if all the previous conditions hold.
1349     if (index.getNumElements() !=
1350         static_cast<int64_t>(dnums.getStartIndexMap().size()))
1351       return failure();
1352 
1353     RankedTensorType operandType =
1354         gather->getOperand(0).getType().dyn_cast<RankedTensorType>();
1355     if (!operandType || !operandType.hasStaticShape()) return failure();
1356 
1357     auto sliceEnd =
1358         llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
1359     llvm::SmallVector<int64_t, 8> sliceStart(sliceEnd.size(), 0);
1360     for (auto it :
1361          llvm::zip(dnums.getStartIndexMap(), index.getValues<APInt>())) {
1362       int64_t mapIndex = std::get<0>(it);
1363       // Clamp the indices within bounds to faithfully mirror gather semantics.
1364       int64_t offset =
1365           std::clamp(std::get<1>(it).getSExtValue(), static_cast<int64_t>(0),
1366                 operandType.getDimSize(mapIndex) - sliceEnd[mapIndex]);
1367       sliceStart[mapIndex] += offset;
1368       sliceEnd[mapIndex] += offset;
1369     }
1370 
1371     llvm::SmallVector<int64_t, 8> sliceStride(sliceEnd.size(), 1);
1372     llvm::SmallVector<int64_t, 8> sliceShape(sliceEnd.size());
1373     for (size_t i = 0; i < sliceEnd.size(); ++i) {
1374       sliceShape[i] = sliceEnd[i] - sliceStart[i];
1375     }
1376     Type elementType = gather.getType().cast<TensorType>().getElementType();
1377     auto sliceType = RankedTensorType::get(sliceShape, elementType);
1378     Value result = rewriter.create<SliceOp>(
1379         gather.getLoc(), sliceType, gather.getOperand(0),
1380         rewriter.getI64TensorAttr(sliceStart),
1381         rewriter.getI64TensorAttr(sliceEnd),
1382         rewriter.getI64TensorAttr(sliceStride));
1383 
1384     auto collapsedSliceDims = dnums.getCollapsedSliceDims();
1385     if (!collapsedSliceDims.empty()) {
1386       llvm::SmallVector<int64_t, 8> reshapeShape;
1387       for (size_t i = 0; i < sliceShape.size(); ++i) {
1388         if (llvm::count(collapsedSliceDims, i) == 0) {
1389           reshapeShape.push_back(sliceShape[i]);
1390         }
1391       }
1392       auto reshapeType = RankedTensorType::get(reshapeShape, elementType);
1393       result = rewriter.create<ReshapeOp>(gather.getLoc(), reshapeType, result);
1394     }
1395 
1396     result.setType(gather.getType());
1397     rewriter.replaceOp(gather, result);
1398     return success();
1399   }
1400 };
1401 
1402 void GatherOp::getCanonicalizationPatterns(RewritePatternSet& results,
1403                                            MLIRContext* context) {
1404   results.add<GatherSlice>(context);
1405 }
1406 
1407 namespace {
1408 
1409 // Following https://www.tensorflow.org/xla/operation_semantics#gather,
1410 // the bounds for the output array along dimension i are computed as follows:
1411 // (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some k)
1412 // then we pick
1413 // the corresponding dimension bounds out of start_indices.shape, skipping
1414 // index_vector_dim
1415 // (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and
1416 // start_indices.shape.dims[k+1] otherwise).
1417 // (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some k)
1418 // then we pick
1419 // the corresponding bound out of slice_sizes after accounting for
1420 // collapsed_slice_dims
1421 // (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is
1422 // slice_sizes with the bounds at indices collapsed_slice_dims removed).
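//
// A worked example under these rules (all values are illustrative): with
// operand : tensor<5x8xf32>, start_indices : tensor<3x2xi32>,
// index_vector_dim = 1, offset_dims = [1], collapsed_slice_dims = [0], and
// slice_sizes = [1, 4], the output has rank 2: dimension 0 is a batch
// dimension with bound start_indices.shape[0] = 3, and dimension 1 is an
// offset dimension with bound adjusted_slice_sizes[0] = 4, giving
// tensor<3x4xf32>.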
1423 
1424 void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
1425                         ValueRange operands,
1426                         SmallVectorImpl<Value>& sliceSizes) {
1427   for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
1428     sliceSizes.push_back(builder.create<arith::ConstantIndexOp>(loc, val));
1429   }
1430 }
1431 
1432 void getSliceSizeValues(DynamicGatherOp* /*dGather*/, OpBuilder& builder,
1433                         Location loc, ValueRange operands,
1434                         SmallVectorImpl<Value>& sliceSizeValues) {
1435   DynamicGatherOp::Adaptor adaptor(operands);
1436   Value sliceSizes = adaptor.slice_sizes();
1437   auto sliceSizesTy = sliceSizes.getType().cast<ShapedType>();
1438   for (int64_t i = 0; i < sliceSizesTy.getDimSize(0); ++i) {
1439     Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
1440     sliceSizeValues.push_back(
1441         builder.create<tensor::ExtractOp>(loc, sliceSizes, idx));
1442   }
1443 }
1444 
1445 // Verify the following properties:
1446 //  P1. Verify no repeat in start_index_map.
1447 //  P2. Verify 0 <= start_index_map[i] < rank(operand), for every i.
1448 //  P3. Verify 0 <= index_vector_dim <= rank(start_indices).
1449 //  P4. Verify size(start_index_map) == shape(start_indices)[index_vector_dim].
1450 //  P5. Verify offset_dims is sorted and has no repeats.
1451 //  P6. Verify collapsed_slice_dims is sorted and has no repeats.
1452 //  P7. Verify rank(operand) == size(offset_dims) + size(collapsed_slice_dims).
1453 //  P8. Verify slice_sizes has rank of 1.
1454 //  P9. Verify size(slice_sizes) == rank(operand).
1455 //  P10. Verify 0 <= collapsed_slice_dims[i] < size(slice_sizes) for all items.
1456 static LogicalResult verifyGather(
1457     ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
1458     ShapeAdaptor sliceSizesShape, GatherDimensionNumbersAttr dimensionNumbers,
1459     llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
1460   int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();
1461 
1462   // Check startIndexMap
1463   auto startIndexMap = to_vector(dimensionNumbers.getStartIndexMap());
1464   // P1.
1465   if (hasDuplicates(startIndexMap))
1466     return errorEmitter() << "expects start_index_map to not repeat, got: ["
1467                           << startIndexMap << "]";
1468 
1469   // P2.
1470   for (int i = 0; i < startIndexMap.size(); ++i)
1471     if (startIndexMap[i] < 0 ||
1472         (operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank()))
1473       return errorEmitter()
1474              << "start_index_map[" << i << "]: " << startIndexMap[i]
1475              << " is out of bounds for "
1476              << "operand rank " << operandShape.getRank();
1477 
1478   if (startIndicesShape.hasRank()) {
1479     // P3.
1480     // index_vector_dim == start_indices.rank implies a trailing 1 on the shape
1481     // of start_indices.
1482     if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0)
1483       return errorEmitter() << "index_vector_dim " << indexVectorDim
1484                             << " is out of bounds for start indices with rank "
1485                             << startIndicesShape.getRank();
1486 
1487     bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank();
1488     if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) {
1489       int64_t effectiveDimSize;
1490       if (impliedTrailingDim)
1491         effectiveDimSize = 1;
1492       else
1493         effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim);
1494       // P4.
1495       if (effectiveDimSize !=
1496           static_cast<int64_t>(dimensionNumbers.getStartIndexMap().size()))
1497         return errorEmitter() << "start_index_map size ("
1498                               << dimensionNumbers.getStartIndexMap().size()
1499                               << ") is not equal to size of index dimension ("
1500                               << indexVectorDim << ") of start_indices ("
1501                               << effectiveDimSize << ")";
1502     }
1503   }
1504 
1505   // P5.
1506   auto offsetDims = to_vector(dimensionNumbers.getOffsetDims());
1507   if (!llvm::is_sorted(offsetDims))
1508     return errorEmitter() << "expects offset_dims to be sorted, got: ["
1509                           << offsetDims << "]";
1510   if (hasDuplicates(offsetDims))
1511     return errorEmitter() << "expects offset_dims to not repeat, got: ["
1512                           << offsetDims << "]";
1513 
1514   // P6.
1515   auto collapsedSliceDims = to_vector(dimensionNumbers.getCollapsedSliceDims());
1516   if (!llvm::is_sorted(collapsedSliceDims))
1517     return errorEmitter() << "expects collapsed_slice_dims to be sorted, got: ["
1518                           << collapsedSliceDims << "]";
1519   if (hasDuplicates(collapsedSliceDims))
1520     return errorEmitter()
1521            << "expects collapsed_slice_dims to not repeat, got: ["
1522            << collapsedSliceDims << "]";
1523 
1524   // P7.
1525   int64_t impliedOperandRank = dimensionNumbers.getOffsetDims().size() +
1526                                dimensionNumbers.getCollapsedSliceDims().size();
1527   if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank)
1528     return errorEmitter() << "offset_dims size ("
1529                           << dimensionNumbers.getOffsetDims().size()
1530                           << ") plus collapse_slice_dims size ("
1531                           << dimensionNumbers.getCollapsedSliceDims().size()
1532                           << ") is not equal to operand rank ("
1533                           << operandShape.getRank() << ")";
1534 
1535   // P8.
1536   // This should be fully expressible with type constraints, but it isn't
1537   // obvious how to do that with the current infrastructure.
1538   if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1)
1539     return errorEmitter() << "slice_sizes.rank != 1";
1540   if (sliceSizesShape.hasStaticShape()) {
1541     int64_t sliceSize = sliceSizesShape.getNumElements();
1542 
1543     // P9.
1544     if (sliceSize != impliedOperandRank)
1545       return errorEmitter() << "slice_sizes size (" << sliceSize
1546                             << ") not equal to (implied) operand rank ("
1547                             << impliedOperandRank << ")";
1548 
1549     // P10.
1550     for (auto dim : dimensionNumbers.getCollapsedSliceDims())
1551       if (dim < 0 || dim >= sliceSize)
1552         return errorEmitter() << "collapsed dimension " << dim
1553                               << " is out of bounds for slice_sizes.size ("
1554                               << sliceSize << ")";
1555   }
1556 
1557   return success();
1558 }
1559 
1560 // Verify the following properties:
1561 //  P1. Verifications by verifyGather().
1562 //  P2. Verify slice_sizes[i] <= 1 for i in collapsed_slice_dims.
1563 //  P3. Verify 0 <= slice_sizes[i] < shape(operand)[i], for every i.
1564 static LogicalResult verifyStaticGather(
1565     ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
1566     DenseIntElementsAttr sliceSizes,
1567     GatherDimensionNumbersAttr dimensionNumbers,
1568     llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
1569   // P1.
1570   // Pass the attribute's type here; verifyGather takes it as a ShapeAdaptor.
1571   if (failed(verifyGather(
1572           /*operandShape=*/operandShape,
1573           /*startIndicesShape=*/startIndicesShape,
1574           /*sliceSizesShape=*/sliceSizes.getType(), dimensionNumbers,
1575           errorEmitter)))
1576     return failure();
1577 
1578   // P2.
1579   for (auto dim : dimensionNumbers.getCollapsedSliceDims()) {
1580     int64_t sliceDimSize = sliceSizes.getValues<int64_t>()[dim];
1581     if (sliceDimSize > 1) {
1582       return errorEmitter() << "slice_sizes collapsed dimension " << dim
1583                             << " should be <= 1 but got " << sliceDimSize;
1584     }
1585   }
1586 
1587   // P3.
1588   if (operandShape.hasRank()) {
1589     for (const auto& it : llvm::enumerate(sliceSizes.getValues<int64_t>())) {
1590       if (operandShape.isDynamicDim(it.index())) continue;
1591       auto operandDimSize = operandShape.getDimSize(it.index());
1592       auto sliceDimSize = it.value();
1593       if (sliceDimSize < 0 || sliceDimSize > operandDimSize)
1594         return errorEmitter() << "slice size (" << sliceDimSize
1595                               << ") is out of bounds for operand dimension ("
1596                               << operandDimSize << ") at index " << it.index();
1597     }
1598   }
1599   return success();
1600 }
1601 
1602 template <typename dimTy>
1603 static void inferGatherShape(
1604     int64_t resultRank, llvm::function_ref<dimTy(int64_t)> getStartIndicesDim,
1605     llvm::function_ref<dimTy(int64_t)> getSliceDim,
1606     GatherDimensionNumbersAttr dimensionNumbers,
1607     SmallVectorImpl<dimTy>& shape) {
1608   ArrayRef<int64_t> collapsedSliceDims =
1609       dimensionNumbers.getCollapsedSliceDims();
1610   int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();
1611 
1612   // We don't necessarily know the rank of sliceSizes, but we do know that it
1613   // can't be larger than the highest collapsed dimension. So go through those
1614   // and populate the leading dimensions of adjustedSliceSizes. The trailing
1615   // dimensions can just be adjusted by an offset.
1616   const auto* maxCollapsedDimIt =
1617       std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end());
1618   int64_t maxCollapsedDim = -1;
1619   if (maxCollapsedDimIt != collapsedSliceDims.end())
1620     maxCollapsedDim = *maxCollapsedDimIt;
1621 
1622   SmallVector<dimTy> adjustedSliceSizePrefix;
1623   for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) {
1624     if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue;
1625     adjustedSliceSizePrefix.push_back(getSliceDim(dimIndex));
1626   }
1627   auto getAdjustedSliceDim = [&](int64_t index) -> dimTy {
1628     if (index < static_cast<int64_t>(adjustedSliceSizePrefix.size()))
1629       return adjustedSliceSizePrefix[index];
1630     return getSliceDim(index + collapsedSliceDims.size());
1631   };
1632 
1633   ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
1634 
1635   // Dimensions in the output that aren't offset dimensions are called batch
1636   // dimensions.
1637   SmallVector<int64_t> batchDims;
1638   for (int dim = 0; dim < resultRank; ++dim)
1639     if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
1640 
1641   for (int i = 0; i < resultRank; ++i) {
1642     const auto* offsetDimsIt =
1643         std::find(offsetDims.begin(), offsetDims.end(), i);
1644     if (offsetDimsIt != offsetDims.end()) {
1645       auto index = std::distance(offsetDims.begin(), offsetDimsIt);
1646       shape.push_back(getAdjustedSliceDim(index));
1647       continue;
1648     }
1649     auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i);
1650     assert(batchDimsIt != batchDims.end());
1651     auto index = std::distance(batchDims.begin(), batchDimsIt);
1652     // This can never run into the special case where start_indices gets
1653     // implicitly expanded with a trailing 1 if
1654     // index_vector_dim = start_indices.rank because then index would equal
1655     // index_vector_dim, which means we'd be looking at index+1, which would be
1656     // out of bounds anyway.
1657     if (index >= indexVectorDim) ++index;
1658     shape.push_back(getStartIndicesDim(index));
1659   }
1660 }
1661 
1662 // Verify the following properties:
1663 //  P1. Verify 0 <= offset_dims[i] < output_shape_rank, for every i.
1664 //      (output_shape_rank = size(offset_dims) + rank(start_indices) - 1)
1665 static LogicalResult inferGatherReturnTypeComponents(
1666     ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
1667     llvm::function_ref<int64_t(int64_t)> getSliceDim,
1668     GatherDimensionNumbersAttr dimensionNumbers,
1669     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes,
1670     llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
1671   Type elementType = operandShape.getElementType();
1672 
1673   // We need this to determine the result rank. We could still place bounds on
1674   // the result rank if that was something ShapedTypeComponents could express.
1675   if (!startIndicesShape.hasRank()) {
1676     inferredReturnShapes.push_back(elementType);
1677     return success();
1678   }
1679 
1680   ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
1681   int64_t startIndicesRank = startIndicesShape.getRank();
1682   // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is
1683   // appended to start_indices shape.
1684   if (dimensionNumbers.getIndexVectorDim() == startIndicesRank)
1685     ++startIndicesRank;
1686   int64_t resultRank = offsetDims.size() + startIndicesRank - 1;
1687   // P1.
1688   for (int i = 0; i < offsetDims.size(); ++i)
1689     if (offsetDims[i] < 0 || offsetDims[i] >= resultRank)
1690       return errorEmitter() << "offset_dims[" << i << "]: " << offsetDims[i]
1691                             << " is out of bounds for "
1692                             << "implied result rank " << resultRank;
1693 
1694   auto getStartIndicesDim = [&](int64_t index) {
1695     return startIndicesShape.getDimSize(index);
1696   };
1697 
1698   SmallVector<int64_t> shape;
1699   inferGatherShape<int64_t>(resultRank, getStartIndicesDim, getSliceDim,
1700                             dimensionNumbers, shape);
1701 
1702   inferredReturnShapes.emplace_back(shape, elementType);
1703   return success();
1704 }
1705 
1706 template <typename Op>
1707 LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands,
1708                                SmallVectorImpl<Value>& reifiedReturnShapes) {
1709   // No support for unranked gather output shape at the moment.
1710   auto resultTy =
1711       op->getResult().getType().template dyn_cast<RankedTensorType>();
1712   if (!resultTy) return failure();
1713 
1714   typename Op::Adaptor adaptor(operands);
1715   Value startIndices = adaptor.start_indices();
1716 
1717   Location loc = op->getLoc();
1718   int resultRank = resultTy.getRank();
1719   Type shapeElTy = startIndices.getType().cast<ShapedType>().getElementType();
1720   auto toShapeElType = [&](Value v) {
1721     return maybeCastTo(builder, loc, v, shapeElTy);
1722   };
1723 
1724   SmallVector<Value, 4> sliceSizes;
1725   getSliceSizeValues(op, builder, loc, operands, sliceSizes);
1726   llvm::transform(sliceSizes, sliceSizes.begin(),
1727                   [&](Value v) { return toShapeElType(v); });
1728 
1729   auto getStartIndicesDim = [&](int64_t index) {
1730     return toShapeElType(
1731         builder.create<tensor::DimOp>(loc, startIndices, index));
1732   };
1733   SmallVector<Value, 4> shapeValues;
1734   auto getSliceDim = [&sliceSizes](int64_t index) -> Value {
1735     return sliceSizes[index];
1736   };
1737   inferGatherShape<Value>(resultRank, getStartIndicesDim, getSliceDim,
1738                           op->dimension_numbers(), shapeValues);
1739 
1740   Value outputShape = builder.create<tensor::FromElementsOp>(
1741       loc, RankedTensorType::get({resultRank}, shapeElTy), shapeValues);
1742   reifiedReturnShapes.push_back(outputShape);
1743 
1744   return success();
1745 }
1746 
1747 }  // namespace
1748 
1749 LogicalResult GatherOp::reifyReturnTypeShapes(
1750     OpBuilder& builder, ValueRange operands,
1751     SmallVectorImpl<Value>& reifiedReturnShapes) {
1752   return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
1753 }
1754 
1755 // The following properties are already enforced by the ODS:
1756 //  P0. Verify that start_indices has an integer element type.
1757 // Verify the following properties:
1758 //  Verifications by verifyStaticGather() and verifyGather() inside it.
1759 //  Verifications by inferGatherReturnTypeComponents.
1760 LogicalResult GatherOp::inferReturnTypeComponents(
1761     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
1762     DictionaryAttr attributes, RegionRange regions,
1763     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1764   // TODO(zhouxin) remove this comment after the ordering issue is clear.
1765   // This can get called before other op verify methods, so we have to do a
1766   // bunch of verification up front. With a better story for ordering and/or
1767   // multi-phase op verification, this should hopefully all go away.
1768   Location loc = location.value_or(UnknownLoc::get(context));
1769   auto errorEmitter = [&loc]() {
1770     return mlir::emitError(loc)
1771            << "'" << GatherOp::getOperationName() << "' op ";
1772   };
1773   GatherOp::Adaptor adaptor(operands, attributes, regions);
1774   if (failed(adaptor.verify(loc))) return failure();
1775 
1776   // We want the ShapeAdaptors, so can't route via the adaptor :-/
1777   ShapeAdaptor operandShape = operands.getShape(0);
1778   ShapeAdaptor startIndicesShape = operands.getShape(1);
1779   GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
1780   DenseIntElementsAttr sliceSizesAttr = adaptor.slice_sizes();
1781 
1782   if (failed(verifyStaticGather(/*operandShape=*/operandShape,
1783                                 /*startIndicesShape=*/startIndicesShape,
1784                                 /*sliceSizes=*/sliceSizesAttr, dimensionNumbers,
1785                                 errorEmitter)))
1786     return failure();
1787 
1788   auto getSliceDim = [&sliceSizesAttr](int64_t index) -> int64_t {
1789     return sliceSizesAttr.getValues<int64_t>()[index];
1790   };
1791 
1792   return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
1793                                          getSliceDim, dimensionNumbers,
1794                                          inferredReturnShapes, errorEmitter);
1795 }
1796 
1797 //===----------------------------------------------------------------------===//
1798 // DynamicGatherOp
1799 //===----------------------------------------------------------------------===//
1800 
1801 LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
1802     OpBuilder& builder, ValueRange operands,
1803     SmallVectorImpl<Value>& reifiedReturnShapes) {
1804   return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
1805 }
1806 
1807 LogicalResult DynamicGatherOp::inferReturnTypeComponents(
1808     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
1809     DictionaryAttr attributes, RegionRange regions,
1810     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1811   // This can get called before other op verify methods, so we have to do a
1812   // bunch of verification up front. With a better story for ordering and/or
1813   // multi-phase op verification, this should hopefully all go away.
1814   Location loc = location.value_or(UnknownLoc::get(context));
1815   auto errorEmitter = [&loc]() {
1816     return mlir::emitError(loc)
1817            << "'" << DynamicGatherOp::getOperationName() << "' op ";
1818   };
1819   DynamicGatherOp::Adaptor adaptor(operands, attributes, regions);
1820   if (failed(adaptor.verify(loc))) return failure();
1821 
1822   // We want the ShapeAdaptors, so can't route via the adaptor :-/
1823   ShapeAdaptor operandShape = operands.getShape(0);
1824   ShapeAdaptor startIndicesShape = operands.getShape(1);
1825   ShapeAdaptor sliceSizesShape = operands.getShape(2);
1826   GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
1827 
1828   if (failed(verifyGather(/*operandShape=*/operandShape,
1829                           /*startIndicesShape=*/startIndicesShape,
1830                           /*sliceSizesShape=*/sliceSizesShape, dimensionNumbers,
1831                           errorEmitter)))
1832     return failure();
1833 
1834   auto getSliceDim = [](int64_t index) { return ShapedType::kDynamicSize; };
1835   return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
1836                                          getSliceDim, dimensionNumbers,
1837                                          inferredReturnShapes, errorEmitter);
1838 }
1839 
1840 //===----------------------------------------------------------------------===//
1841 // GetDimensionSizeOp
1842 //===----------------------------------------------------------------------===//
1843 //
1844 LogicalResult GetDimensionSizeOp::verify() { return verifyDimAttr(*this); }
1845 
1846 /// Fold get_dimension_size when the said shape dimension is a constant.
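/// For example, get_dimension_size with dimension = 0 on a tensor<4x?xf32>
/// operand folds to dense<4> : tensor<i32>, while dimension = 1 (dynamic)
/// does not fold.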
1847 OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
1848   RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
1849   if (!type) return {};
1850 
1851   int32_t dim = dimension();
1852   if (type.isDynamicDim(dim)) return {};
1853   // The result type is always a 0-d i32 tensor.
1854   return DenseIntElementsAttr::get<int32_t>(
1855       getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
1856 }
1857 
1858 //===----------------------------------------------------------------------===//
1859 // IotaOp
1860 //===----------------------------------------------------------------------===//
1861 
1862 LogicalResult IotaOp::verify() {
1863   auto shape = getType().cast<ShapedType>();
1864   if (!shape.hasRank()) return success();
1865 
1866   if (shape.getRank() == 0) return emitOpError() << "does not support scalars.";
1867 
1868   auto iotaDimension = this->iota_dimension();
1869   if (static_cast<int64_t>(iotaDimension) >= shape.getRank() ||
1870       iotaDimension < 0)
1871     return emitOpError()
1872            << "iota dimension cannot go beyond the output rank or be negative.";
1873   return success();
1874 }
1875 
1876 // Iota operations across multiple dimensions can be reduced to an iota and a
1877 // ranked broadcast.
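//
// A sketch of the rewrite (the shape is illustrative):
//
//   %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x4xf32>
//
// becomes
//
//   %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
//   %0 = "mhlo.broadcast_in_dim"(%1) {
//          broadcast_dimensions = dense<1> : tensor<1xi64>}
//        : (tensor<4xf32>) -> tensor<3x4xf32>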
1878 struct IotaBroadcast : public OpRewritePattern<IotaOp> {
1879   using OpRewritePattern<IotaOp>::OpRewritePattern;
1880 
1881   LogicalResult matchAndRewrite(IotaOp iota,
1882                                 PatternRewriter& rewriter) const override {
1883     auto resultTy = iota.getType().cast<ShapedType>();
1884     if (!resultTy.hasRank() || resultTy.getRank() < 2) {
1885       return failure();
1886     }
1887 
1888     auto iotaDimension = iota.iota_dimension();
1889 
1890     auto iotaType = RankedTensorType::get({resultTy.getDimSize(iotaDimension)},
1891                                           resultTy.getElementType());
1892 
1893     auto newIota = rewriter.create<IotaOp>(iota.getLoc(), iotaType,
1894                                            rewriter.getI64IntegerAttr(0));
1895 
1896     auto broadcastAttr = DenseIntElementsAttr::get(
1897         RankedTensorType::get({1}, rewriter.getIntegerType(64)),
1898         {iotaDimension});
1899     rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, resultTy, newIota,
1900                                                   broadcastAttr);
1901     return success();
1902   }
1903 };
1904 
1905 void IotaOp::getCanonicalizationPatterns(RewritePatternSet& results,
1906                                          MLIRContext* context) {
1907   results.add<IotaBroadcast>(context);
1908 }
1909 
1910 OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
1911   auto dimension = iota_dimension();
1912   auto resultTy = getResult().getType().cast<ShapedType>();
1913   if (resultTy.hasRank() && resultTy.getDimSize(dimension) == 1) {
1914     Builder builder(getContext());
1915     return builder.getZeroAttr(resultTy);
1916   }
1917 
1918   return {};
1919 }
1920 
1921 //===----------------------------------------------------------------------===//
1922 // DynamicIotaOp
1923 //===----------------------------------------------------------------------===//
1924 
1925 // Does the same as PatternRewriter::replaceOpWithNewOp, but with a twist.
1926 //
1927 // Sometimes, we want to replace an op with a new op and simultaneously refine
1928 // the result type from a dynamically-shaped type to a statically-shaped type.
1929 // (Search for usages of this function for examples).
1930 //
1931 // Oftentimes, this works just fine because MHLO is designed to accommodate
1932 // this kind of type refinement. But sometimes, this doesn't work - when
1933 // the op is used outside of the MHLO dialect (e.g. in func.return). In these
1934 // cases, we insert a tensor.cast to smooth things out.
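//
// For example (a sketch; the values are illustrative): if a dynamic_iota
// producing tensor<?xf32> is refined to an iota producing tensor<4xf32> and
// one of its users is func.return, the non-MHLO user keeps observing the
// original type through an inserted
//   %cast = tensor.cast %new : tensor<4xf32> to tensor<?xf32>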
1935 template <typename OpTy, typename... Args>
1936 OpTy refineOpWithNewOp(PatternRewriter& rewriter, Operation* op,
1937                        Args&&... args) {
1938   auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
1939 
1940   llvm::SmallVector<Value> replacementResults;
1941   assert(op->getNumResults() == newOp->getNumResults() &&
1942          "replacement op doesn't match results of original op");
1943   for (auto [opResult, newOpResult] :
1944        llvm::zip(op->getResults(), newOp->getResults())) {
1945     Value replacementResult = newOpResult;
1946     if (llvm::any_of(opResult.getUsers(), [&](Operation* user) {
1947           return user->getDialect() != op->getDialect();
1948         })) {
1949       replacementResult = rewriter.create<tensor::CastOp>(
1950           op->getLoc(), opResult.getType(), newOpResult);
1951     }
1952     replacementResults.push_back(replacementResult);
1953   }
1954 
1955   rewriter.replaceOp(op, replacementResults);
1956   return newOp;
1957 }
1958 
1959 namespace {
1960 
1961 struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
1962   using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
1963 
1964   LogicalResult matchAndRewrite(DynamicIotaOp iota,
1965                                 PatternRewriter& rewriter) const override {
1966     // Result type has static shape, replace with iota.
1967     auto resultTy = iota.getType().cast<ShapedType>();
1968     if (resultTy.hasStaticShape()) {
1969       rewriter.replaceOpWithNewOp<IotaOp>(iota, resultTy,
1970                                           iota.iota_dimension());
1971       return success();
1972     }
1973 
1974     // Output shape is constant, compute result type with static shape, then
1975     // replace with iota.
1976     DenseIntElementsAttr outputShapeAttr;
1977     if (matchPattern(iota.output_shape(), m_Constant(&outputShapeAttr))) {
1978       SmallVector<int64_t> outputShape;
1979       for (APInt dim : outputShapeAttr.getValues<APInt>()) {
1980         outputShape.push_back(dim.getSExtValue());
1981       }
1982       resultTy = RankedTensorType::get(outputShape, resultTy.getElementType());
1983       refineOpWithNewOp<IotaOp>(rewriter, iota, resultTy,
1984                                 iota.iota_dimension());
1985       return success();
1986     }
1987 
1988     return rewriter.notifyMatchFailure(
1989         iota, "requires static shape or constant output shape");
1990   }
1991 };
1992 
1993 // Dynamic Iota operations across multiple dimensions can be reduced to an iota
1994 // and a ranked broadcast.
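//
// A sketch of the rewrite (shapes are illustrative):
//
//   %0 = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64}
//        : (tensor<2xi32>) -> tensor<?x?xf32>
//
// becomes, schematically: slice the extent at index 1 out of %shape, feed it
// to a 1-D dynamic_iota with iota_dimension = 0, and dynamic_broadcast_in_dim
// the result back to the full shape with broadcast_dimensions = [1].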
1995 struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
1996   using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
1997 
1998   LogicalResult matchAndRewrite(DynamicIotaOp iota,
1999                                 PatternRewriter& rewriter) const override {
2000     auto resultTy = iota.getType().cast<ShapedType>();
2001     if (!resultTy.hasRank() || resultTy.getRank() < 2) {
2002       return failure();
2003     }
2004 
2005     auto iotaDimension = iota.iota_dimension();
2006     auto iotaDimensionInt = iotaDimension;
2007 
2008     auto convertedShape = rewriter.create<arith::IndexCastOp>(
2009         iota.getLoc(),
2010         RankedTensorType::get(
2011             iota.output_shape().getType().cast<ShapedType>().getShape(),
2012             rewriter.getI64Type()),
2013         iota.output_shape());
2014 
2015     auto slicedShape = rewriter.create<SliceOp>(
2016         iota.getLoc(), convertedShape,
2017         rewriter.getI64TensorAttr(iotaDimensionInt),
2018         rewriter.getI64TensorAttr(iotaDimensionInt + 1),
2019         rewriter.getI64TensorAttr(1));
2020 
2021     auto convertedSlicedShape = rewriter.create<arith::IndexCastOp>(
2022         iota.getLoc(),
2023         RankedTensorType::get(
2024             {1},
2025             iota.output_shape().getType().cast<ShapedType>().getElementType()),
2026         slicedShape);
2027 
2028     auto iotaType = RankedTensorType::get(
2029         {resultTy.getDimSize(iotaDimensionInt)}, resultTy.getElementType());
2030 
2031     auto newIota = rewriter.create<DynamicIotaOp>(
2032         iota.getLoc(), iotaType, convertedSlicedShape,
2033         rewriter.getI64IntegerAttr(0));
2034 
2035     auto broadcastAttr = DenseIntElementsAttr::get(
2036         RankedTensorType::get({1}, rewriter.getIntegerType(64)),
2037         {iotaDimension});
2038     rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
2039         iota, resultTy, newIota, iota.output_shape(), broadcastAttr);
2040     return success();
2041   }
2042 };
2043 
2044 }  // namespace
2045 
2046 void DynamicIotaOp::getCanonicalizationPatterns(RewritePatternSet& results,
2047                                                 MLIRContext* context) {
2048   results.add<DynamicIotaIsStatic>(context);
2049   results.add<DynamicIotaBroadcast>(context);
2050 }
2051 
2052 static Value castToIndexTensor(OpBuilder& builder, Location loc,
2053                                Value shapeOp) {
2054   ShapedType resultTy = shape::getExtentTensorType(
2055       builder.getContext(), shapeOp.getType().cast<ShapedType>().getDimSize(0));
2056   if (shapeOp.getType() == resultTy) return shapeOp;  // Nothing to do.
2057   return builder.create<arith::IndexCastOp>(loc, resultTy, shapeOp);
2058 }
2059 
2060 LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
2061     OpBuilder& builder, ValueRange operands,
2062     SmallVectorImpl<Value>& reifiedReturnShapes) {
2063   DynamicIotaOp::Adaptor adaptor(operands);
2064   reifiedReturnShapes.push_back(
2065       castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
2066   return success();
2067 }
2068 
2069 //===----------------------------------------------------------------------===//
2070 // DynamicUpdateSliceOp
2071 //===----------------------------------------------------------------------===//
2072 
2073 LogicalResult DynamicUpdateSliceOp::verify() {
2074   OperandRange indices = start_indices();
2075   if (indices.size() <= 1) return success();
2076 
2077   // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
2078   // is OK to cast indices to ShapedType here.
2079   auto idxTensor = indices.take_front().front().getType().cast<ShapedType>();
2080   Type firstElemTy = idxTensor.getElementType();
2081   Type elemTy;
2082 
2083   for (auto idx : llvm::drop_begin(indices, 1)) {
2084     idxTensor = idx.getType().cast<ShapedType>();
2085     elemTy = idxTensor.getElementType();
2086 
2087     if (firstElemTy != elemTy) {
2088       return emitOpError() << "start indices must have same element type "
2089                               "(encountered mismatch: "
2090                            << firstElemTy << " vs " << elemTy << ")";
2091     }
2092   }
2093   return success();
2094 }
2095 
2096 OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef<Attribute> operands) {
2097   auto operandShape = this->operand().getType().cast<RankedTensorType>();
2098   auto updateShape = this->update().getType().cast<RankedTensorType>();
2099 
2100   // If any of the dimensions are length-0, the update does nothing.
2101   for (auto dim : updateShape.getShape()) {
2102     if (dim == 0) {
2103       return this->operand();
2104     }
2105   }
2106 
2107   if (operandShape != updateShape || !operandShape.hasStaticShape()) {
2108     return {};
2109   }
2110 
2111   // Ensure that the indices are all zero constants, so the update is known
2112   // to overwrite the entire operand. For non-constants, the pattern does not
2113   // fold, to avoid hiding the behavior of incorrect user input.
2114   for (Value index : this->start_indices()) {
2115     DenseIntElementsAttr deAttr;
2116     if (!matchPattern(index, m_Constant(&deAttr))) return {};
2117     if (!deAttr.getSplatValue<IntegerAttr>().getValue().isZero()) return {};
2118   }
2119   return this->update();
2120 }
2121 
2122 //===----------------------------------------------------------------------===//
2123 // AbsOp
2124 //===----------------------------------------------------------------------===//
2125 
2126 LogicalResult AbsOp::inferReturnTypes(
2127     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
2128     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2129   auto operandTy = (*operands.begin()).getType().cast<ShapedType>();
2130   Type elementTy = operandTy.getElementType();
2131   if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
2132     elementTy = complexTy.getElementType();
2133   }
2134 
2135   Type resultTy;
2136   if (auto rankedOperandTy = operandTy.dyn_cast<RankedTensorType>()) {
2137     resultTy = RankedTensorType::get(operandTy.getShape(), elementTy,
2138                                      rankedOperandTy.getEncoding());
2139   } else if (operandTy.hasRank()) {
2140     resultTy = RankedTensorType::get(operandTy.getShape(), elementTy);
2141   } else {
2142     resultTy = UnrankedTensorType::get(elementTy);
2143   }
2144   inferredReturnTypes.push_back(resultTy);
2145   return success();
2146 }
2147 
2148 //===----------------------------------------------------------------------===//
2149 // CollectivePermuteOp
2150 //===----------------------------------------------------------------------===//
2151 
2152 LogicalResult CollectivePermuteOp::verify() {
2153   return mlir::hlo::verifyCollectivePermuteSourceTargetPairs(
2154       *this, source_target_pairs());
2155 }
2156 
2157 //===----------------------------------------------------------------------===//
2158 // ConvolutionOp
2159 //===----------------------------------------------------------------------===//
2160 
2161 namespace {
2162 // Checks:
2163 //  P1. Same sizes for input, kernel and output spatial_dims.
2164 //  P2. Spatial and non-spatial dimensions (for input, kernel, & output) should
2165 //      be unique and in range [0, num_dims), where num_dims = rank of input
2166 //      (lhs/rhs) tensors.
2167 //
2168 //  Note that the spatial + non-spatial dimensions may not cover all the
2169 //  dimensions in the range [0, num_dims) because of the presence of 'unknown'
2170 //  dimensions (ref. cl/415132294).
2171 LogicalResult isSpatialDimensionsValid(ConvolutionOp op) {
2172   auto inputSpatialDimensions =
2173       op.dimension_numbers().getInputSpatialDimensions();
2174   auto kernelSpatialDimensions =
2175       op.dimension_numbers().getKernelSpatialDimensions();
2176   auto outputSpatialDimensions =
2177       op.dimension_numbers().getOutputSpatialDimensions();
2178 
2179   // P1.
2180   if ((inputSpatialDimensions.size() != kernelSpatialDimensions.size()) ||
2181       (inputSpatialDimensions.size() != outputSpatialDimensions.size()))
2182     return op.emitOpError() << "expects the same size for input, kernel and "
2183                                "output spatial-dimensions, but got "
2184                             << inputSpatialDimensions.size() << ", "
2185                             << kernelSpatialDimensions.size() << ", and "
2186                             << outputSpatialDimensions.size() << " resp.";
2187 
2188   // P2.
2189   SmallVector<int64_t> inputDnums(inputSpatialDimensions.size() + 2);
2190   inputDnums[0] = op.dimension_numbers().getInputBatchDimension();
2191   inputDnums[1] = op.dimension_numbers().getInputFeatureDimension();
2192   std::copy(inputSpatialDimensions.begin(), inputSpatialDimensions.end(),
2193             inputDnums.begin() + 2);
2194 
2195   SmallVector<int64_t> windowDnums(kernelSpatialDimensions.size() + 2);
2196   windowDnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
2197   windowDnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
2198   std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(),
2199             windowDnums.begin() + 2);
2200 
2201   SmallVector<int64_t> outputDnums(outputSpatialDimensions.size() + 2);
2202   outputDnums[0] = op.dimension_numbers().getOutputBatchDimension();
2203   outputDnums[1] = op.dimension_numbers().getOutputFeatureDimension();
2204   std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(),
2205             outputDnums.begin() + 2);
2206 
2207   auto numDims = op.lhs().getType().cast<RankedTensorType>().getRank();
2208   const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; };
2209 
2210   if (!llvm::all_of(inputDnums, inRange) ||
2211       !llvm::all_of(windowDnums, inRange) ||
2212       !llvm::all_of(outputDnums, inRange))
2213     return op.emitOpError() << "expects input, kernel, and output "
2214                                "dimension-numbers to be in-range [0, "
2215                             << numDims << ").";
2216 
2217   if (hasDuplicates(inputDnums))
2218     return op.emitOpError()
2219            << "expects input dimension-numbers to be unique, got {"
2220            << inputDnums << "}.";
2221 
2222   if (hasDuplicates(windowDnums))
2223     return op.emitOpError()
2224            << "expects kernel dimension-numbers to be unique, got {"
2225            << windowDnums << "}.";
2226 
2227   if (hasDuplicates(outputDnums))
2228     return op.emitOpError()
2229            << "expects output dimension-numbers to be unique, got {"
2230            << outputDnums << "}.";
2231 
2232   return success();
2233 }
2234 
2235 // Verifies the following properties:
2236 //  P1. The input, kernel, and output spatial-dimensions are valid.
2237 //  P2. Given,
2238 //          input-dimensions: b * input-spatial-dims * f
2239 //          kernel-dimensions: kernel-spatial-dims * i * o
2240 //          output-dimensions: b' * out-spatial-dims * f'
2241 //            where b = input-batch-dims
2242 //            where f = input-feature-dims
2243 //            where i = kernel-input-feature-dims
2244 //            where o = kernel-output-feature-dims
2245 //            where b' = output-batch-dims
2246 //            where f' = output-feature-dims
2247 //      Check the following properties w.r.t feature_group_count (fgc) and
2248 //      batch_group_count (bgc).
2249 //        fgc > 0, bgc > 0 and !(fgc > 1 && bgc > 1)
2250 //        b % bgc == 0
2251 //        f % fgc == 0 and i = f / fgc
2252 //        o (or f') % bgc == 0 and o (or f') % fgc == 0
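// A worked example of these rules (the numbers are illustrative): with
// fgc = 2 and bgc = 1, an input feature dimension f = 4 requires a kernel
// input feature dimension i = f / fgc = 2, and the kernel output feature
// dimension o must be a multiple of both fgc = 2 and bgc = 1.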
2253 LogicalResult verifyConvolutionAttributes(ConvolutionOp op) {
2254   // P1.
2255   if (failed(isSpatialDimensionsValid(op))) return failure();
2256 
2257   // P2.
2258   const int64_t featureGroupCount = op.feature_group_count();
2259   const int64_t batchGroupCount = op.batch_group_count();
2260 
2261   if (featureGroupCount <= 0)
2262     return op.emitOpError()
2263            << "expects feature_group_count to be a positive number, got "
2264            << featureGroupCount << ".";
2265 
2266   if (batchGroupCount <= 0)
2267     return op.emitOpError()
2268            << "expects batch_group_count to be a positive number, got "
2269            << batchGroupCount << ".";
2270 
2271   if (batchGroupCount > 1 && featureGroupCount > 1)
2272     return op.emitOpError()
2273            << "expects batch_group_count and feature_group_count not to be "
2274               "both greater than 1. Got "
2275            << batchGroupCount << " and " << featureGroupCount << " resp.";
2276 
2277   auto lhsType = op.lhs().getType().cast<RankedTensorType>();
2278   const int64_t inputFeatures =
2279       lhsType.getShape()[op.dimension_numbers().getInputFeatureDimension()];
2280   const int64_t inputBatch =
2281       lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];
2282 
2283   auto rhsType = op.rhs().getType().cast<RankedTensorType>();
2284   const int64_t kernelInputFeatures =
2285       rhsType
2286           .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
2287   const int64_t kernelOutputFeatures =
2288       rhsType
2289           .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
2290 
2291   if (!isDynamicDimSize(kernelOutputFeatures)) {
2292     if (kernelOutputFeatures % batchGroupCount != 0)
2293       return op.emitOpError() << "expects output feature dimension size ("
2294                               << kernelOutputFeatures
2295                               << ") to be a multiple of "
2296                                  "batch_group_count. Got batch_group_count = "
2297                               << batchGroupCount << ".";
2298 
2299     if (kernelOutputFeatures % featureGroupCount != 0)
2300       return op.emitOpError()
2301              << "expects kernel output feature dimension ("
2302              << kernelOutputFeatures
2303              << ") to be divisible by "
2304                 "feature_group_count. Got feature_group_count = "
2305              << featureGroupCount << ".";
2306   }
2307 
2308   if (!isDynamicDimSize(inputFeatures)) {
2309     if (inputFeatures % featureGroupCount != 0)
2310       return op.emitOpError()
2311              << "expects input feature dimension (" << inputFeatures
2312              << ") to be a multiple of "
2313                 "feature_group_count. Got feature_group_count = "
2314              << featureGroupCount << ".";
2315 
2316     if (!isDynamicDimSize(kernelInputFeatures) &&
2317         inputFeatures / featureGroupCount != kernelInputFeatures)
2318       return op.emitOpError()
2319              << "expects input feature dimension (" << inputFeatures
2320              << ") / "
2321                 "feature_group_count = kernel input feature dimension ("
2322              << kernelInputFeatures
2323              << "). Got feature_group_count = " << featureGroupCount << ".";
2324   }
2325 
2326   if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0)
2327     return op.emitOpError() << "expects input batch dimension (" << inputBatch
2328                             << ") to be divisible by "
2329                                "batch_group_count. Got batch_group_count = "
2330                             << batchGroupCount << ".";
2331 
2332   return success();
2333 }
2334 
2335 // Infer the return-shape of ConvolutionOp.
2336 // Precondition:
2337 //  1. Input args to ConvolutionOp 'op' are RankedTypes.
2338 //  2. rank-of(input-type) == rank-of(output-type)
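// Illustrative example (hypothetical values, not from the original source):
// for an NHWC convolution with lhs tensor<1x8x8x16xf32>, kernel
// tensor<3x3x16x32xf32>, unit strides, and no padding or dilation with
// batch_group_count = 1, the inferred output shape is [1, 6, 6, 32]: each
// spatial size is 8 - 3 + 1 = 6, the output feature size equals the kernel
// output feature size (32), and the batch size is 1 / 1 = 1.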
2339 SmallVector<int64_t> inferConvolutionOpReturnShape(
2340     ConvolutionOp op, const ArrayRef<WindowDimension> window) {
2341   // We keep the 'unknown' dimensions (cl/415132294) as they are in the
2342   // output-shape. To do that we initialize the output dimensions with the shape
2343   // of the return-type and update only the spatial and batch/feature dimensions.
2344   // Precondition 2 ensures that size of output-shape == size of input-shape.
2345   SmallVector<int64_t> outputDimensions =
2346       to_vector(op.getResult().getType().cast<ShapedType>().getShape());
2347 
2348   // Infer the output spatial dimensions.
2349   auto lhsType = op.lhs().getType().cast<RankedTensorType>();
2350   auto inputSpatialDims = op.dimension_numbers().getInputSpatialDimensions();
2351   auto numSpatialDims = inputSpatialDims.size();
2352   SmallVector<int64_t> inputSpatialDimVals(numSpatialDims);
2353   for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
2354     inputSpatialDimVals[i] = lhsType.getShape()[inputSpatialDims[i]];
2355 
2356   auto windowOutputShape = inferWindowOutputShape(inputSpatialDimVals, window);
2357 
2358   for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i)
2359     outputDimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
2360         windowOutputShape[i];
2361 
2362   // Infer the output-batch-dimension and output-feature-dimension.
2363   auto rhsType = op.rhs().getType().cast<RankedTensorType>();
2364   const int64_t inputBatch =
2365       lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];
2366   const int64_t kernelOutputFeatures =
2367       rhsType
2368           .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
2369 
2370   outputDimensions[op.dimension_numbers().getOutputBatchDimension()] =
2371       isDynamicDimSize(inputBatch) ? ShapedType::kDynamicSize
2372                                    : inputBatch / op.batch_group_count();
2373   outputDimensions[op.dimension_numbers().getOutputFeatureDimension()] =
2374       kernelOutputFeatures;
2375 
2376   return outputDimensions;
2377 }
2378 
2379 // Some mhlo.convolutions are dot products, specifically when there is no
2380 // padding and no spatial dimensions. DotGeneralOp is general enough that it
2381 // can describe such a convolution directly.
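//
// Illustrative sketch (hypothetical IR, syntax approximate):
//   %0 = mhlo.convolution(%lhs, %rhs) {...}
//          : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// with no spatial dimensions and feature_group_count = 1 is rewritten to
//   %0 = "mhlo.dot_general"(%lhs, %rhs) {dot_dimension_numbers = ...}
//          : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// contracting the lhs feature dimension with the rhs input feature dimension.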
2382 struct ConvolutionIsDot : public OpRewritePattern<mhlo::ConvolutionOp> {
2383   using OpRewritePattern<mhlo::ConvolutionOp>::OpRewritePattern;
2384   LogicalResult matchAndRewrite(mhlo::ConvolutionOp op,
2385                                 PatternRewriter& rewriter) const override {
2386     auto lhs = op.lhs();
2387     auto rhs = op.rhs();
2388     auto lhsTy = lhs.getType().cast<RankedTensorType>();
2389     auto rhsTy = rhs.getType().cast<RankedTensorType>();
2390     auto resultTy = op.getType().cast<RankedTensorType>();
2391 
2392     if (lhsTy.getRank() != 2) return failure();
2393     if (rhsTy.getRank() != 2) return failure();
2394 
2395     if (op.batch_group_count() != 1) return failure();
2396 
2397     // There should not be any padding if this is a matmul.
2398     auto dNums = op.dimension_numbers();
2399     assert(!op.padding() || op.padding()->empty());
2400     assert(dNums.getKernelSpatialDimensions().empty());
2401 
2402     auto lhsBatchDim = dNums.getInputBatchDimension();
2403     auto rhsBatchDim = dNums.getKernelOutputFeatureDimension();
2404     auto lhsContractDim = dNums.getInputFeatureDimension();
2405     auto rhsContractDim = dNums.getKernelInputFeatureDimension();
2406     auto outBatchDim = dNums.getOutputBatchDimension();
2407     auto outFeatureDim = dNums.getOutputFeatureDimension();
2408 
2409     // If the input features are not grouped then we can directly convert to an
2410     // mhlo.dot_general.
2411     if (op.feature_group_count() == 1) {
2412       // We can swap the lhs and rhs sides to avoid a transpose.
2413       if (outBatchDim == 1 && outFeatureDim == 0) {
2414         std::swap(lhs, rhs);
2415         std::swap(outBatchDim, outFeatureDim);
2416         std::swap(lhsContractDim, rhsContractDim);
2417       }
2418 
2419       auto dotNums = DotDimensionNumbersAttr::get(
2420           op.getContext(), {}, {}, {lhsContractDim}, {rhsContractDim});
2421       auto dotOp = rewriter.create<mhlo::DotGeneralOp>(
2422           op.getLoc(), op.getType(), lhs, rhs, dotNums,
2423           op.precision_config().value_or(nullptr));
2424 
2425       rewriter.replaceOp(op, dotOp.getResult());
2426       return success();
2427     }
2428 
2429     int64_t featureGroupCount = op.feature_group_count();
2430     int64_t lhsBatchSize = lhsTy.getDimSize(lhsBatchDim);
2431     int64_t lhsContractSize = lhsTy.getDimSize(lhsContractDim);
2432     int64_t rhsBatchSize = rhsTy.getDimSize(rhsBatchDim);
2433     int64_t rhsContractSize = rhsTy.getDimSize(rhsContractDim);
2434 
2435     llvm::SmallVector<int64_t> lhsShape;
2436     llvm::SmallVector<int64_t> rhsShape;
2437     lhsShape.resize(3, lhsBatchSize);
2438     rhsShape.resize(3, rhsContractSize);
2439     lhsShape[lhsContractDim] = featureGroupCount;
2440     lhsShape[lhsContractDim + 1] = lhsContractSize / featureGroupCount;
2441     rhsShape[rhsContractDim] = featureGroupCount;
2442     rhsShape[rhsContractDim + 1] = rhsBatchSize / featureGroupCount;
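    // Illustrative example (hypothetical sizes): with featureGroupCount = 2
    // and lhsContractDim == 1, an lhs of shape [4, 6] becomes [4, 2, 3],
    // splitting the contracting dimension into (group, features-per-group).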
2443 
2444     lhsTy = RankedTensorType::get(lhsShape, lhsTy.getElementType());
2445     rhsTy = RankedTensorType::get(rhsShape, rhsTy.getElementType());
2446 
2447     lhs = rewriter.create<mhlo::ReshapeOp>(op.getLoc(), lhsTy, lhs);
2448     rhs = rewriter.create<mhlo::ReshapeOp>(op.getLoc(), rhsTy, rhs);
2449 
2450     auto dotTy = RankedTensorType::get(
2451         {featureGroupCount, lhsBatchSize, rhsBatchSize / featureGroupCount},
2452         resultTy.getElementType());
2453 
2454     auto dotNums = DotDimensionNumbersAttr::get(
2455         op.getContext(), {lhsContractDim}, {rhsContractDim},
2456         {lhsContractDim + 1}, {rhsContractDim == 0 ? 2 : 0});
2457     auto dotOp = rewriter.create<mhlo::DotGeneralOp>(
2458         op.getLoc(), dotTy, lhs, rhs, dotNums,
2459         op.precision_config().value_or(nullptr));
2460 
2461     llvm::SmallVector<int64_t> perms;
2462     perms.resize(3, dNums.getOutputBatchDimension() == 0 ? 0 : 2);
2463     perms[0] = dNums.getOutputFeatureDimension();
2464     perms[2] = dNums.getOutputFeatureDimension() + 1;
2465 
2466     auto transposeTy = RankedTensorType::get(
2467         {dotTy.getDimSize(perms[0]), dotTy.getDimSize(perms[1]),
2468          dotTy.getDimSize(perms[2])},
2469         dotTy.getElementType());
2470     auto transposeOp = rewriter.create<mhlo::TransposeOp>(
2471         op.getLoc(), transposeTy, dotOp, rewriter.getI64TensorAttr(perms));
2472 
2473     rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, resultTy, transposeOp);
2474     return success();
2475   }
2476 };
2477 
2478 }  // namespace
2479 
2480 void ConvolutionOp::getCanonicalizationPatterns(RewritePatternSet& results,
2481                                                 MLIRContext* context) {
2482   results.add<ConvolutionIsDot>(context);
2483 }
2484 
2485 /*
2486  * We intend to verify the following properties
2487  *  P1. Verify the input, kernel types.
2488  *  P2. Verify the convolution attributes.
2489  *  P3. Verify and collect the window attributes.
2490  *  P4. Verify the return shape.
2491  *      TODO(b/232574102): Verify the element-type of return-value.
2492  */
2493 LogicalResult ConvolutionOp::verify() {
2494   auto lhsType = lhs().getType().dyn_cast<RankedTensorType>();
2495   auto rhsType = rhs().getType().dyn_cast<RankedTensorType>();
2496 
2497   if (!lhsType || !rhsType) return success();
2498 
2499   // P1.
2500   int numDims = lhsType.getRank();
2501   if (numDims != rhsType.getRank())
2502     return emitOpError()
2503            << "expects convolution arguments to have the same number of "
2504               "dimensions. Got: "
2505            << lhsType << " and " << rhsType << ".";
2506 
2507   if (numDims < 2)
2508     return emitOpError()
2509            << "expects convolution arguments to have >= 2 dimensions. "
2510               "Got: "
2511            << lhsType << " and " << rhsType << ".";
2512 
2513   // P2.
2514   if (failed(verifyConvolutionAttributes(*this))) return failure();
2515 
2516   // P3.
2517   auto kernelSpatialDimensions =
2518       dimension_numbers().getKernelSpatialDimensions();
2519   SmallVector<int64_t> windowDimensions(kernelSpatialDimensions.size());
2520   for (size_t i = 0; i < windowDimensions.size(); i++)
2521     windowDimensions[i] = rhsType.getShape()[kernelSpatialDimensions[i]];
2522 
2523   auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
2524   if (failed(paddingOrErr)) return failure();
2525   SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;
2526 
2527   auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
2528       windowDimensions, convertDenseIntAttr(window_strides()), padding,
2529       convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
2530       getLoc());
2531   if (failed(windowOrErr)) return failure();
2532 
2533   // P4.
2534   auto actualReturnType = getResult().getType().cast<TensorType>();
2535   auto actualReturnElementType = actualReturnType.getElementType();
2536   if (!actualReturnType.hasRank()) return success();
2537 
2538   auto actualReturnRankedType = actualReturnType.cast<RankedTensorType>();
2539   if (numDims != actualReturnRankedType.getRank())
2540     return emitOpError() << "expects rank of convolution return-type to be "
2541                             "equal to input-ranks ("
2542                          << numDims << "), but got "
2543                          << actualReturnRankedType.getRank() << ".";
2544 
2545   auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr);
2546   auto expectedReturnType =
2547       RankedTensorType::get(expectedReturnShape, actualReturnElementType);
2548   if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType)))
2549     return emitOpError()
2550            << "has shape mismatch between the expected return-type ("
2551            << expectedReturnType << ") and actual return-type ("
2552            << actualReturnRankedType << ").";
2553 
2554   return success();
2555 }
2556 
2557 //===----------------------------------------------------------------------===//
2558 // ConvertOp
2559 //===----------------------------------------------------------------------===//
2560 
2561 void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
2562                       Type resultElementTy) {
2563   Type resultTy;
2564   Type operandTy = operand.getType();
2565   if (auto rankedTy = operandTy.dyn_cast<RankedTensorType>()) {
2566     resultTy = RankedTensorType::get(rankedTy.getShape(), resultElementTy);
2567   } else {
2568     resultTy = UnrankedTensorType::get(resultElementTy);
2569   }
2570   build(builder, result, resultTy, operand);
2571 }
2572 
2573 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
2574   auto operandTy = getOperand().getType().cast<TensorType>();
2575   auto resultTy = getResult().getType().cast<TensorType>();
2576   if (operandTy == resultTy) return getOperand();
2577 
2578   // If the result has non-static shape, a convert op is necessary to go from
2579   // static shape to non-static shape.
2580   if (!resultTy.hasStaticShape()) return {};
2581 
2582   // If the operand is constant, we can do the conversion now.
2583   auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>();
2584   if (!elementsAttr) return {};
2585 
2586   // Prevent folding if the result is too large.
2587   if (elementsAttr.getNumElements() > kFoldOpEltLimit) return {};
2588   return hlo::convertElementsAttr(elementsAttr,
2589                                   getElementTypeOrSelf(getResult()));
2590 }
2591 
2592 namespace {
2593 
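// Eliminates a redundant pair of ConvertOps when the intermediate type is at
// least as wide as the source type. Illustrative sketch (hypothetical types):
// convert(convert(%x : f16 -> f32) : f32 -> f64) becomes a single
// convert(%x : f16 -> f64), since the widening first conversion is lossless.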
2594 struct EliminateRedundantConvert : public OpRewritePattern<ConvertOp> {
2595   using OpRewritePattern<ConvertOp>::OpRewritePattern;
2596   LogicalResult matchAndRewrite(ConvertOp op,
2597                                 PatternRewriter& rewriter) const override {
2598     auto convertOp = op.operand().getDefiningOp<ConvertOp>();
2599     if (!convertOp) {
2600       return failure();
2601     }
2602     auto firstType =
2603         convertOp.operand().getType().cast<TensorType>().getElementType();
2604     auto secondType =
2605         op.operand().getType().cast<TensorType>().getElementType();
2606     auto thirdType =
2607         op.getResult().getType().cast<TensorType>().getElementType();
2608     auto loc = rewriter.getFusedLoc({convertOp->getLoc(), op->getLoc()});
2609     if (firstType.isa<FloatType>() && secondType.isa<FloatType>() &&
2610         thirdType.isa<FloatType>()) {
2611       // Fold when the second float type is wider than the first,
2612       // e.g. fp16 -> fp32 -> fp64, bf16 -> fp32 -> fp16.
2613       if (secondType.cast<FloatType>().getWidth() >
2614           firstType.cast<FloatType>().getWidth()) {
2615         Value result = rewriter.create<ConvertOp>(loc, op.getResult().getType(),
2616                                                   convertOp.operand());
2617         rewriter.replaceOp(op, result);
2618         return success();
2619       }
2620     } else if (firstType.isa<IntegerType>() && secondType.isa<IntegerType>() &&
2621                thirdType.isa<IntegerType>()) {
2622       // Fold when the second integer type is wider than the first,
2623       // e.g. i16 -> i32 -> i64, u16 -> i32 -> u32.
2624       if (secondType.cast<IntegerType>().getWidth() >
2625           firstType.cast<IntegerType>().getWidth()) {
2626         Value result = rewriter.create<ConvertOp>(loc, op.getResult().getType(),
2627                                                   convertOp.operand());
2628         rewriter.replaceOp(op, result);
2629         return success();
2630       }
2631     }
2632     return failure();
2633   }
2634 };
2635 
2636 }  // namespace
2637 
2638 void ConvertOp::getCanonicalizationPatterns(RewritePatternSet& results,
2639                                             MLIRContext* context) {
2640   results.add<EliminateIdentityConvert>(context);
2641   results.add<EliminateRedundantConvert>(context);
2642 }
2643 
2644 //===----------------------------------------------------------------------===//
2645 // GetTupleElementOp
2646 //===----------------------------------------------------------------------===//
2647 
2648 LogicalResult GetTupleElementOp::verify() {
2649   auto indexVal = index();
2650   auto operandType = getOperand().getType().cast<TupleType>();
2651   if (indexVal >= operandType.size()) {
2652     return emitOpError(
2653         llvm::formatv("index {0} is out of bounds of operand with size {1}",
2654                       indexVal, operandType.size()));
2655   }
2656 
2657   auto expectedType = operandType.getType(indexVal);
2658   if (getType() != expectedType) {
2659     return emitOpError(llvm::formatv("has return type {0}, but expected {1}",
2660                                      getType(), expectedType));
2661   }
2662   return success();
2663 }
2664 
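// Illustrative sketch (hypothetical IR): get_tuple_element(mhlo.tuple(%a, %b),
// index = 1) folds to %b.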
2665 OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
2666   if (auto tupleOp = getOperand().getDefiningOp<mhlo::TupleOp>()) {
2667     return tupleOp.getOperand(index());
2668   }
2669 
2670   return {};
2671 }
2672 
2673 //===----------------------------------------------------------------------===//
2674 // TupleOp
2675 //===----------------------------------------------------------------------===//
2676 
2677 LogicalResult TupleOp::verify() {
2678   auto opType = getType().dyn_cast<TupleType>();
2679   if (!opType) return emitOpError("tuple op with non-tuple result");
2680   if (getNumOperands() != opType.size())
2681     return emitOpError(
2682         "number of operands to tuple expected to match number of types in "
2683         "resultant tuple type");
2684   for (const auto& it :
2685        llvm::enumerate(llvm::zip_first(getOperandTypes(), opType.getTypes()))) {
2686     if (std::get<0>(it.value()) != std::get<1>(it.value()))
2687       return emitOpError("has return type mismatch at ")
2688              << it.index() << "th value (" << std::get<0>(it.value())
2689              << " != " << std::get<1>(it.value()) << ")";
2690   }
2691   return success();
2692 }
2693 
2694 namespace {
2695 
2696 // Pattern for unpacking and repacking the same tuple.
2697 struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
2698   using OpRewritePattern<TupleOp>::OpRewritePattern;
2699 
2700   LogicalResult matchAndRewrite(TupleOp op,
2701                                 PatternRewriter& rewriter) const override {
2702     if (op.val().empty()) return failure();
2703 
2704     Value firstElement = op.val().front();
2705     auto firstElementOp = firstElement.getDefiningOp<GetTupleElementOp>();
2706     if (!firstElementOp || firstElementOp.indexAttr().getInt() != 0)
2707       return failure();
2708 
2709     Value tuplePredecessor = firstElementOp.getOperand();
2710     if (tuplePredecessor.getType() != op.getType()) return failure();
2711 
2712     for (const auto& elementAndIdx : llvm::enumerate(op.val().drop_front(1))) {
2713       auto elementOp = elementAndIdx.value().getDefiningOp<GetTupleElementOp>();
2714       if (!elementOp ||
2715           elementOp.indexAttr().getInt() !=
2716               static_cast<int64_t>(elementAndIdx.index() + 1) ||
2717           elementOp.getOperand() != tuplePredecessor)
2718         return failure();
2719     }
2720 
2721     rewriter.replaceOp(op, tuplePredecessor);
2722     return success();
2723   }
2724 };
2725 
2726 }  // namespace
2727 
2728 void TupleOp::getCanonicalizationPatterns(RewritePatternSet& results,
2729                                           MLIRContext* context) {
2730   results.add<UnpackRepackSameTuple>(context);
2731 }
2732 
2733 //===----------------------------------------------------------------------===//
2734 // AllToAllOp
2735 //===----------------------------------------------------------------------===//
2736 
2737 LogicalResult AllToAllOp::inferReturnTypeComponents(
2738     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
2739     DictionaryAttr attributes, RegionRange regions,
2740     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2741   AllToAllOp::Adaptor adaptor(operands, attributes, regions);
2742   Type operandType = adaptor.operand().getType();
2743   RankedTensorType operandRankedType = operandType.dyn_cast<RankedTensorType>();
2744   if (!operandRankedType) {
2745     inferredReturnShapes.emplace_back(
2746         operandType.cast<TensorType>().getElementType());
2747     return success();
2748   }
2749 
2750   int64_t inputRank = operandRankedType.getRank();
2751   int64_t splitDimension = static_cast<int64_t>(adaptor.split_dimension());
2752   int64_t concatDimension = static_cast<int64_t>(adaptor.concat_dimension());
2753   if (splitDimension >= inputRank || splitDimension < 0) {
2754     return emitOptionalError(location, "AllToAll split_dimension ",
2755                              splitDimension,
2756                              " is out-of-bounds for input rank ", inputRank);
2757   }
2758   if (concatDimension >= inputRank || concatDimension < 0) {
2759     return emitOptionalError(location, "AllToAll concat_dimension ",
2760                              concatDimension,
2761                              " is out-of-bounds for input rank ", inputRank);
2762   }
2763 
2764   // If operand is ranked, size of split dimension should be a multiple of split
2765   // count.
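  // Illustrative example (hypothetical values): operand tensor<4x16xf32> with
  // split_dimension = 1, concat_dimension = 0, and split_count = 4 infers the
  // result shape [4 * 4, 16 / 4] = [16, 4].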
2766   int64_t splitCount = adaptor.split_count();
2767   auto splitDimSize = operandRankedType.getDimSize(splitDimension);
2768   if (!isDynamicDimSize(splitDimSize) && splitDimSize % splitCount != 0) {
2769     return emitOptionalError(
2770         location, "split dimension has size ", splitDimSize,
2771         ", expected to be a multiple of split_count ", splitCount);
2772   }
2773   SmallVector<int64_t> resultShape(operandRankedType.getShape().begin(),
2774                                    operandRankedType.getShape().end());
2775   if (!isDynamicDimSize(splitDimSize)) resultShape[splitDimension] /= splitCount;
2776   if (!isDynamicDimSize(resultShape[concatDimension])) resultShape[concatDimension] *= splitCount;
2777   inferredReturnShapes.emplace_back(resultShape,
2778                                     operandRankedType.getElementType());
2779   return success();
2780 }
2781 
2782 //===----------------------------------------------------------------------===//
2783 // AllGatherOp
2784 //===----------------------------------------------------------------------===//
2785 
2786 LogicalResult AllGatherOp::verify() {
2787   // If operand and result are both ranked, then the size of the gather
2788   // dimension in the result should be a multiple of the size of the gather
2789   // dimension in the operand.
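  // Illustrative example (hypothetical values): with all_gather_dim = 0, an
  // operand of shape [4, 8] gathered across 2 replicas may produce a result
  // of shape [8, 8]; 8 is a multiple of 4, as required.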
2790   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
2791   auto resultType = getType().dyn_cast<RankedTensorType>();
2792   uint64_t allGatherDimIndex = all_gather_dim();
2793   if (!operandType || !resultType ||
2794       operandType.isDynamicDim(allGatherDimIndex) ||
2795       resultType.isDynamicDim(allGatherDimIndex))
2796     return success();
2797   if (operandType.getDimSize(allGatherDimIndex) == 0)
2798     return emitOpError() << "operand gather dimension cannot be zero.";
2799   if ((resultType.getDimSize(allGatherDimIndex) %
2800        operandType.getDimSize(allGatherDimIndex)) != 0)
2801     return emitOpError()
2802            << "result gather dimension has size "
2803            << resultType.getDimSize(allGatherDimIndex)
2804            << ", expected to be a multiple of operand gather dimension size "
2805            << operandType.getDimSize(allGatherDimIndex);
2806 
2807   return success();
2808 }
2809 
2810 //===----------------------------------------------------------------------===//
2811 // BatchNormGradOp
2812 //===----------------------------------------------------------------------===//
2813 
2814 LogicalResult BatchNormGradOp::verify() {
2815   // The following properties are already enforced by the ODS:
2816   //  1. Inputs 'operand' & 'grad_output' and outputs 'grad_operand'
2817   //     are ranked tensors with floating-point (fp) type.
2818   //  2. The shapes of inputs 'operand' & 'grad_output' match.
2819   //  3. Inputs 'scale', 'mean', 'variance' and Outputs 'grad_scale',
2820   //     'grad_offset' are all 1D fp tensors with the same shape.
2821   //  4. The element-types of input 'operand' and outputs 'grad_scale',
2822   //     'grad_offset' match.
2823   //  5. The type of input 'operand' and output 'grad_operand' match.
2824   //
2825   // We intend to verify the following properties
2826   //  P1. Inputs 'operand' & 'grad_output' have the same shape with fp
2827   //      element-types, ignoring fp-precision: Inferred from (1) & (2).
2828   //  P2. The feature dimension 'feature_index' is a valid index in 'operand':
2829   //      Inferred from check C2 below.
2830   //  P3. Inputs 'scale', 'mean', 'variance' must be 1D tensors with same shape
2831   //      and fp element-type (ignoring precision) and the number of elements
2832   //      in their sole dimension == the number of features in the operand's
2833   //      feature-dimension 'feature_index': Inferred from (3) and check C3
2834   //      below.
2835   //  P4. Outputs 'grad_scale' & 'grad_offset' are 1D tensors with
2836   //      element-type == element-type of(operand) and same shape as any of
2837   //      the inputs 'scale', 'mean', or 'variance': Inferred from (3), (4) and
2838   //      check C3 below.
2839   //  P5. The type (shape + element-type) of input 'operand' and
2840   //      output 'grad_operand' must match: Inferred from (5).
2841 
2842   // C2.
2843   auto operandType = operand().getType().cast<RankedTensorType>();
2844   if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2845     return emitOpError() << "expects feature_index to be smaller "
2846                             "than the rank of operand type; got feature_index "
2847                          << feature_index() << ", and rank "
2848                          << operandType.getRank() << ".";
2849 
2850   if (static_cast<int64_t>(feature_index()) < 0)
2851     return emitOpError() << "expects feature_index to be a "
2852                          << "non-negative number, got "
2853                          << static_cast<int64_t>(feature_index()) << ".";
2854 
2855   auto gradOutputType = grad_output().getType().cast<RankedTensorType>();
2856   if (operandType.getRank() != gradOutputType.getRank())
2857     return emitOpError() << "expects 'operand' and 'grad_output' to have the "
2858                             "same rank, but got rank(operand) "
2859                          << operandType.getRank() << " and rank(grad_output) "
2860                          << gradOutputType.getRank() << ".";
2861 
2862   // C3.
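  // Illustrative example (hypothetical types): for operand tensor<2x3x4xf32>
  // and feature_index = 1, 'scale' must be a 1D tensor with 3 elements.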
2863   const int64_t featureCount = operandType.getShape()[feature_index()];
2864   const int64_t scaleShape =
2865       scale().getType().cast<RankedTensorType>().getShape()[0];
2866   if (scaleShape != featureCount)
2867     return emitOpError() << "expects the size of scale factor to be "
2868                             "the same as the feature count,"
2869                             " but the size of scale factor is "
2870                          << scaleShape << " and the feature count is "
2871                          << featureCount << ".";
2872 
2873   return success();
2874 }
2875 
2876 //===----------------------------------------------------------------------===//
2877 // BatchNormTrainingOp
2878 //===----------------------------------------------------------------------===//
2879 
2880 LogicalResult BatchNormTrainingOp::verify() {
2881   // The following properties are already enforced by the ODS:
2882   //  1. 'operand' and 'output' are ranked tensors.
2883   //  2. 'scale', 'offset', 'batch_mean', 'batch_var' are 1D tensors.
2884   //  3. Types of 'operand' and 'output' matches.
2885   //  4. Same element-types for 'operand', 'batch_mean', & 'batch_var'.
2886   //  5. Same shapes for 'scale', 'offset', 'batch_mean', & 'batch_var'.
2887 
2888   auto operandType = operand().getType().cast<RankedTensorType>();
2889   if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2890     return emitOpError() << "expects feature_index to be smaller "
2891                             "than the rank of operand type; got feature_index "
2892                          << feature_index() << ", and rank "
2893                          << operandType.getRank() << ".";
2894 
2895   if (static_cast<int64_t>(feature_index()) < 0)
2896     return emitOpError() << "expects feature_index to be a "
2897                          << "non-negative number, got "
2898                          << static_cast<int64_t>(feature_index()) << ".";
2899 
2900   // Note: A valid value of feature_index implies 'operandType.getRank() >= 1'.
2901 
2902   const int64_t featureCount = operandType.getShape()[feature_index()];
2903   const int64_t scaleShape =
2904       scale().getType().cast<RankedTensorType>().getShape()[0];
2905   // Check number of elements in input 'scale' equals feature_count.
2906   // Together with (5) implies that 'scale', 'offset', 'batch_mean', &
2907   // 'batch_var' all have the same shape.
2908   if (scaleShape != featureCount)
2909     return emitOpError() << "expects the size of scale factor to be "
2910                             "the same as the feature count,"
2911                             " but the size of scale factor is "
2912                          << scaleShape << " and the feature count is "
2913                          << featureCount << ".";
2914 
2915   return success();
2916 }
2917 
2918 //===----------------------------------------------------------------------===//
2919 // BatchNormInferenceOp
2920 //===----------------------------------------------------------------------===//
2921 
2922 LogicalResult BatchNormInferenceOp::verify() {
2923   // The following properties are already enforced by the ODS:
2924   //  1. 'operand' and 'result' are ranked tensors.
2925   //  2. 'scale', 'offset', 'mean', 'variance' are 1D tensors.
2926   //  3. Types of 'operand' and 'result' matches.
2927   //  4. Same shapes for 'scale', 'offset', 'mean', & 'variance'.
2928 
2929   auto operandType = operand().getType().cast<RankedTensorType>();
2930   if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2931     return emitOpError() << "expects feature_index to be smaller "
2932                             "than the rank of operand type; got feature_index "
2933                          << feature_index() << ", and rank "
2934                          << operandType.getRank() << ".";
2935 
2936   if (static_cast<int64_t>(feature_index()) < 0)
2937     return emitOpError() << "expects feature_index to be a "
2938                          << "non-negative number, got "
2939                          << static_cast<int64_t>(feature_index()) << ".";
2940 
2941   // Note: A valid value of feature_index implies 'operandType.getRank() >= 1'.
2942 
2943   const int64_t featureCount = operandType.getShape()[feature_index()];
2944   const int64_t scaleSize =
2945       scale().getType().cast<RankedTensorType>().getShape()[0];
2946   // Check number of elements in input 'scale' equals feature_count.
2947   // Together with (4) implies that 'scale', 'offset', 'mean', &
2948   // 'variance' all have the same shape.
2949   if (scaleSize != featureCount)
2950     return emitOpError() << "expects the size of scale factor to be "
2951                             "the same as the feature count,"
2952                             " but the size of scale factor is "
2953                          << scaleSize << " and the feature count is "
2954                          << featureCount << ".";
2955 
2956   return success();
2957 }
2958 
2959 //===----------------------------------------------------------------------===//
2960 // BitcastConvertOp
2961 //===----------------------------------------------------------------------===//
2962 
2963 LogicalResult BitcastConvertOp::reifyReturnTypeShapes(
2964     OpBuilder& builder, ValueRange operands,
2965     SmallVectorImpl<Value>& reifiedReturnShapes) {
2966   auto operandType = operands[0].getType().dyn_cast<RankedTensorType>();
2967   auto resultType = getType().dyn_cast<RankedTensorType>();
2968 
2969   // Only ranked tensors are supported.
2970   if (!operandType || !resultType) return failure();
2971 
2972   // Shape-changing bitcast convert is not implemented.
2973   // TODO(kramerb): This could be done by adjusting the last dimension.
2974   DataLayout dataLayout = DataLayout::closest(*this);
2975   unsigned operandElementSize =
2976       dataLayout.getTypeSizeInBits(operandType.getElementType());
2977   unsigned resultElementSize =
2978       dataLayout.getTypeSizeInBits(resultType.getElementType());
2979   if (operandElementSize != resultElementSize) return failure();
2980 
2981   return ::mlir::mhlo::deriveShapeFromOperand(
2982       &builder, getOperation(), operands.front(), &reifiedReturnShapes);
2983 }
2984 
2985 /*
2986  * We intend to verify the following properties
2987  * P1. We cannot convert between complex and real types (cf. XLA).
2988  * P2. The dimensions of the operand and the target shape must match, except
2989  * that the shape with the smaller element bitwidth has an appropriately-sized
2990  * additional innermost dimension, e.g.
2991  * ... x f32 => [bitcast_convert] => ... x 4 x i8
2992  * ... x 4 x i8 => [bitcast_convert] => ... x f32
2993  */
2994 LogicalResult BitcastConvertOp::verify() {
2995   auto operandTensorType = operand().getType().cast<TensorType>();
2996   auto targetTensorType = getResult().getType().cast<TensorType>();
2997 
2998   // P1.
2999   auto targetElt = targetTensorType.getElementType();
3000   auto operandElt = operandTensorType.getElementType();
3001   if (targetElt.isa<ComplexType>() != operandElt.isa<ComplexType>()) {
3002     return emitOpError()
3003            << "cannot convert between real and complex types, but got: "
3004            << operandTensorType << " and " << targetTensorType;
3005   }
3006 
3007   auto targetEltBitwidth = potentiallyComplexBitwidth(targetElt);
3008   auto operandEltBitwidth = potentiallyComplexBitwidth(operandElt);
3009 
3010   // P2.
3011   auto operandType = operandTensorType.dyn_cast<RankedTensorType>();
3012   auto targetType = targetTensorType.dyn_cast<RankedTensorType>();
3013   if (!operandType || !targetType) return success();
3014 
3015   auto targetShape = targetType.getShape();
3016   auto operandShape = operandType.getShape();
3017   ArrayRef<int64_t> smallerEltShape, biggerEltShape;
3018   Type smallerElt, biggerElt;
3019   if (operandEltBitwidth < targetEltBitwidth) {
3020     smallerEltShape = operandShape;
3021     smallerElt = operandElt;
3022     biggerEltShape = targetShape;
3023     biggerElt = targetElt;
3024   } else {
3025     smallerEltShape = targetShape;
3026     smallerElt = targetElt;
3027     biggerEltShape = operandShape;
3028     biggerElt = operandElt;
3029   }
3030 
3031   ArrayRef<int64_t> smallerEltPrefix;
3032   auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth);
3033   auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth);
3034   if (operandEltBitwidth != targetEltBitwidth) {
3035     if (smallerEltShape.empty()) {
3036       return emitOpError() << "does not allow the smaller element type to be "
3037                               "part of a 0d tensor, but got: "
3038                            << operandType << " and " << targetType << ".";
3039     }
3040     smallerEltPrefix = smallerEltShape.drop_back();
3041     if (!isDynamicDimSize(smallerEltShape.back()) &&
3042         smallerEltShape.back() * smallerEltBitwidth != biggerEltBitwidth) {
3043       return emitOpError() << "requires compatible bitwidths. "
3044                            << "Got: " << operandType << " and " << targetType
3045                            << ", but " << smallerEltBitwidth << " * "
3046                            << smallerEltShape.back()
3047                            << " != " << biggerEltBitwidth << ".";
3048     }
3049   } else {
3050     smallerEltPrefix = smallerEltShape;
3051   }
3052 
3053   for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) {
3054     auto targetDim = std::get<0>(it);
3055     auto operandDim = std::get<1>(it);
3056     if (!isDynamicDimSize(targetDim) && !isDynamicDimSize(operandDim)) {
3057       if (targetDim != operandDim) {
3058         return emitOpError() << "operand and result shapes must match except "
3059                                 "for the innermost dimension of the shape with "
3060                                 "the smaller element type. Got: "
3061                              << operandType << " and " << targetType << ".";
3062       }
3063     }
3064   }
3065 
3066   return success();
3067 }
3068 
3069 //===----------------------------------------------------------------------===//
3070 // BroadcastOp
3071 //===----------------------------------------------------------------------===//
3072 
3073 // TODO(b/129012527) These should be expressed as type constraints.
3074 LogicalResult BroadcastOp::verify() {
3075   auto sizes = broadcast_sizes();
3076   auto sizesType = sizes.getType();
3077   auto sizesRank = sizesType.getRank();
3078   if (sizesRank != 1) {
3079     return emitOpError(llvm::formatv(
3080         "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
3081   }
3082 
3083   return success();
3084 }
3085 
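// Illustrative sketch (hypothetical IR): a broadcast with empty
// broadcast_sizes is an identity and folds to its operand; broadcasting a
// splat constant folds to a splat of the (static) result type.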
3086 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> attrs) {
3087   auto type = getType().cast<RankedTensorType>();
3088   auto sizesType = broadcast_sizes().getType();
3089   if (sizesType.getNumElements() == 0) {
3090     return getOperand();
3091   }
3092 
3093   // Constant fold when an operand is a splat tensor attribute.
3094   if (!attrs[0] || !type.hasStaticShape()) return {};
3095   auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
3096   if (!splatOperandAttr) return {};
3097 
3098   // Handle complex type
3099   if (type.getElementType().isa<ComplexType>()) {
3100     ComplexType complex = type.getElementType().cast<ComplexType>();
3101     if (complex.getElementType().isa<FloatType>()) {
3102       return DenseElementsAttr::get(
3103           type, {splatOperandAttr.getSplatValue<std::complex<APFloat>>()});
3104     }
3105     if (complex.getElementType().isa<IntegerType>()) {
3106       return DenseElementsAttr::get(
3107           type, {splatOperandAttr.getSplatValue<std::complex<APInt>>()});
3108     }
3109     return {};
3110   }
3111 
3112   return SplatElementsAttr::get(
3113       type, splatOperandAttr.getSplatValue<mlir::Attribute>());
3114 }
3115 
3116 LogicalResult BroadcastOp::inferReturnTypeComponents(
3117     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
3118     DictionaryAttr attributes, RegionRange regions,
3119     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
3120   BroadcastOp::Adaptor adaptor(operands, attributes, regions);
3121   Value operand = adaptor.operand();
3122   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3123   if (!operandType) return failure();
3124 
3125   Type elementTy = operandType.getElementType();
3126   auto dimensionAttr = adaptor.broadcast_sizes();
3127   for (int64_t size : dimensionAttr.getValues<int64_t>()) {
3128     if (size < 0)
3129       return emitOptionalError(location,
3130                                "Broadcast with negative dimension size ", size);
3131   }
3132   SmallVector<int64_t> shapeValues(dimensionAttr.getValues<int64_t>());
3133   llvm::append_range(shapeValues, operandType.getShape());
3134 
3135   inferredReturnShapes.emplace_back(shapeValues, elementTy);
3136   return success();
3137 }
3138 
3139 LogicalResult BroadcastOp::reifyReturnTypeShapes(
3140     OpBuilder& builder, ValueRange operands,
3141     SmallVectorImpl<Value>& reifiedReturnShapes) {
3142   BroadcastOp::Adaptor adaptor(operands);
3143   Value operand = adaptor.operand();
3144 
3145   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3146   // Unranked tensors are not supported.
3147   if (!operandType) return failure();
3148 
3149   Location loc = getLoc();
3150   SmallVector<Value, 4> shapeValues;
3151 
3152   // Collect the broadcast sizes.
3153   for (const auto& size : broadcast_sizes()) {
3154     shapeValues.push_back(
3155         builder.create<arith::ConstantIndexOp>(loc, size.getZExtValue()));
3156   }
3157 
3158   // Collect the operand sizes.
3159   for (auto index : llvm::seq<int64_t>(0, operandType.getRank())) {
3160     shapeValues.push_back(
3161         builder.createOrFold<tensor::DimOp>(loc, operand, index));
3162   }
3163 
3164   reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
3165       loc,
3166       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
3167                             builder.getIndexType()),
3168       shapeValues));
3169 
3170   return success();
3171 }
3172 
3173 //===----------------------------------------------------------------------===//
3174 // BroadcastInDimOp
3175 //===----------------------------------------------------------------------===//
3176 
3177 LogicalResult BroadcastInDimOp::verify() {
3178   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
3179   if (!operandType) {
3180     // The following verification checks all depend on knowing the rank of
3181     // the operand. Bail out now if we don't know the rank of the operand.
3182     return success();
3183   }
3184 
3185   auto operandRank = operandType.getRank();
3186   if (!broadcast_dimensions()) {
3187     if (operandRank == 0) {
3188       return success();
3189     }
3190     return emitOpError(
3191         llvm::formatv("broadcast_dimensions is absent, but required because "
3192                       "operand has non-zero rank ({0})",
3193                       operandRank));
3194   }
3195 
3196   auto dimensions = broadcast_dimensions();
3197   auto dimensionsType = broadcast_dimensions().getType();
3198   auto dimensionsRank = dimensionsType.getRank();
3199   if (dimensionsRank != 1) {
3200     return emitOpError(llvm::formatv(
3201         "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
3202   }
3203 
3204   auto dimensionsSize = dimensionsType.getNumElements();
3205   if (dimensionsSize != operandRank) {
3206     return emitOpError(llvm::formatv(
3207         "broadcast_dimensions size ({0}) does not match operand rank ({1})",
3208         dimensionsSize, operandRank));
3209   }
3210 
3211   auto resultType = getResult().getType().cast<RankedTensorType>();
3212   auto resultRank = resultType.getRank();
3213   if (resultRank < operandRank) {
3214     return emitOpError(
3215         llvm::formatv("result rank ({0}) is less than operand rank ({1})",
3216                       resultRank, operandRank));
3217   }
3218 
3219   for (int i = 0; i != dimensionsSize; ++i) {
3220     auto dimIndex = dimensions.getValues<int64_t>()[i];
3221     if (dimIndex >= resultRank) {
3222       return emitOpError(
3223           llvm::formatv("broadcast_dimensions contains invalid value {0} for "
3224                         "result with rank {1}",
3225                         dimIndex, resultRank));
3226     }
3227 
3228     if (!operandType.isDynamicDim(i)) {
3229       auto dimSize = operandType.getDimSize(i);
3230       auto resultDimSize = resultType.getDimSize(dimIndex);
3231       if (dimSize != 1 && dimSize != resultDimSize) {
3232         return emitOpError(
3233             llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
3234                           "1 or size of result dimension {2} ({3})",
3235                           i, dimSize, dimIndex, resultDimSize));
3236       }
3237     }
3238   }
3239 
3240   return success();
3241 }
3242 
3243 OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
3244   auto type = getType().cast<RankedTensorType>();
3245   if (type == getOperand().getType()) {
3246     auto broadcastValues = broadcast_dimensions().getValues<int64_t>();
3247     if (!std::equal(broadcastValues.begin(), broadcastValues.end(),
3248                     llvm::seq<int64_t>(0, type.getRank()).begin())) {
3249       return {};
3250     }
3251     return getOperand();
3252   }
3253 
3254   // Constant fold when an operand is a splat tensor attribute.
3255   if (!attrs[0] || !type.hasStaticShape()) return {};
3256   auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
3257   if (!splatOperandAttr) return {};
3258 
3259   // Handle complex type
3260   if (type.getElementType().isa<ComplexType>()) {
3261     ComplexType complex = type.getElementType().cast<ComplexType>();
3262     if (complex.getElementType().isa<FloatType>()) {
3263       return DenseElementsAttr::get(
3264           type, {splatOperandAttr.getSplatValue<std::complex<APFloat>>()});
3265     }
3266     if (complex.getElementType().isa<IntegerType>()) {
3267       return DenseElementsAttr::get(
3268           type, {splatOperandAttr.getSplatValue<std::complex<APInt>>()});
3269     }
3270     return {};
3271   }
3272 
3273   return SplatElementsAttr::get(
3274       type, splatOperandAttr.getSplatValue<mlir::Attribute>());
3275 }
3276 
3277 // Simplification of BroadcastInDim: replace BroadcastInDim with Reshape or
3278 // Transpose when they are equivalent, or replace
3279 // BroadcastInDim(BroadcastInDim(X)) with a single BroadcastInDim(X).
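// Illustrative examples (hypothetical shapes): broadcasting tensor<2x3xf32> to
// tensor<3x2xf32> with broadcast_dimensions = [1, 0] is a transpose, while
// broadcasting tensor<2x3xf32> to tensor<1x2x3xf32> with sorted
// broadcast_dimensions = [1, 2] and an equal element count is a reshape.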
3280 class BroadcastInDimSimplifier : public OpRewritePattern<BroadcastInDimOp> {
3281  public:
3282   using OpRewritePattern<BroadcastInDimOp>::OpRewritePattern;
3283   LogicalResult matchAndRewrite(BroadcastInDimOp op,
3284                                 PatternRewriter& rewriter) const override {
3285     auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
3286     auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
3287     if (!operandType || !resultType) {
3288       return failure();
3289     }
3290     auto bsDimIndices = op.broadcast_dimensions().getValues<int64_t>();
3291     if (operandType.hasStaticShape() && resultType.hasStaticShape()) {
3292       bool sameTotalElements =
3293           operandType.getNumElements() == resultType.getNumElements();
3294       // BroadcastInDim equivalent to reshape
3295       if (llvm::is_sorted(bsDimIndices) && sameTotalElements) {
3296         rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
3297         return success();
3298       }
3299       // BroadcastInDim equivalent to transpose
3300       if (operandType.getRank() == resultType.getRank() && sameTotalElements) {
3301         rewriter.replaceOpWithNewOp<TransposeOp>(op, op.getType(), op.operand(),
3302                                                  op.broadcast_dimensions());
3303         return success();
3304       }
3305     }
3306     // eliminate redundant BroadcastInDim
3307     if (auto broadcastInDimOp = llvm::dyn_cast_or_null<BroadcastInDimOp>(
3308             op.operand().getDefiningOp())) {
3309       auto newIndices =
3310           broadcastInDimOp.broadcast_dimensions()
3311               .mapValues(op.broadcast_dimensions().getElementType(),
3312                          [&bsDimIndices](const APInt& dim) -> APInt {
3313                            return APInt(dim.getBitWidth(),
3314                                         bsDimIndices[dim.getSExtValue()], true);
3315                          })
3316               .cast<DenseIntElementsAttr>();
3317       rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
3318           op, op.getType(), broadcastInDimOp.operand(), newIndices);
3319       return success();
3320     }
3321     return failure();
3322   }
3323 };
3324 
3325 void BroadcastInDimOp::getCanonicalizationPatterns(RewritePatternSet& results,
3326                                                    MLIRContext* context) {
3327   results.add<BroadcastInDimSimplifier>(context);
3328 }
3329 
3330 //===----------------------------------------------------------------------===//
3331 // DynamicBroadcastInDimOp
3332 //===----------------------------------------------------------------------===//
3333 
3334 LogicalResult DynamicBroadcastInDimOp::verify() {
3335   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
3336   auto resultType = getResult().getType().dyn_cast<RankedTensorType>();
3337 
3338   // If either the operand or result are unranked, there is very little
3339   // to verify statically.
3340   if (!operandType || !resultType) {
3341     return success();
3342   }
3343 
3344   auto outputDimensionsType =
3345       output_dimensions().getType().cast<RankedTensorType>();
3346   auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
3347   auto operandRank = operandType.getRank();
3348   auto resultRank = resultType.getRank();
3349 
3350   // Verify broadcast_dimensions.
3351   auto bcastDimensions = broadcast_dimensions();
3352   auto bcastDimensionsType = broadcast_dimensions().getType();
3353   auto bcastDimensionsRank = bcastDimensionsType.getRank();
3354   // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
3355   if (bcastDimensionsRank != 1) {
3356     return emitOpError(
3357         llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
3358                       bcastDimensionsRank));
3359   }
3360 
3361   auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
3362   if (bcastDimensionsSize != operandRank) {
3363     return emitOpError(llvm::formatv(
3364         "broadcast_dimensions size ({0}) does not match operand rank ({1})",
3365         bcastDimensionsSize, operandRank));
3366   }
3367 
3368   if (resultRank < operandRank) {
3369     return emitOpError(
3370         llvm::formatv("result rank ({0}) is less than operand rank ({1})",
3371                       resultRank, operandRank));
3372   }
3373 
3374   for (int i = 0; i != bcastDimensionsSize; ++i) {
3375     auto dimIndex = bcastDimensions.getValues<int64_t>()[i];
3376     if (dimIndex >= resultRank) {
3377       return emitOpError(
3378           llvm::formatv("broadcast_dimensions contains invalid value {0} for "
3379                         "result with rank {1}",
3380                         dimIndex, resultRank));
3381     }
3382 
3383     auto dimSize = operandType.getDimSize(i);
3384     auto resultDimSize = resultType.getDimSize(dimIndex);
3385     // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
3386     // add a manual check for this.
3387     if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
3388       return emitOpError(
3389           llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
3390                         "with size of result dimension {2} ({3})",
3391                         i, dimSize, dimIndex, resultDimSize));
3392     }
3393   }
3394 
3395   if (outputDimensionsSize != resultRank) {
3396     return emitOpError(
3397         llvm::formatv("result rank ({0}) is not equal to number of output "
3398                       "dimensions ({1})",
3399                       resultRank, outputDimensionsSize));
3400   }
3401 
3402   // Verify that the known expanding and non-expanding dimensions are a subset
3403   // of the operand's dimensions.
3404   int64_t numKnownExpansionBehavior = 0;
3405   DenseSet<int64_t> knownExpansionBehavior;
3406   auto collectExpansionBehaviorDims =
3407       [&](const Optional<DenseIntElementsAttr>& attr) {
3408         if (!attr) return;
3409         for (const APInt& it : *attr) {
3410           numKnownExpansionBehavior++;
3411           knownExpansionBehavior.insert(it.getLimitedValue());
3412         }
3413       };
3414   collectExpansionBehaviorDims(known_expanding_dimensions());
3415   collectExpansionBehaviorDims(known_nonexpanding_dimensions());
3416   if (knownExpansionBehavior.size() != numKnownExpansionBehavior) {
3417     return emitOpError(
3418         "duplicate expansion hint for at least one operand dimension");
3419   }
3420   for (int64_t i : knownExpansionBehavior) {
3421     if (i < 0 || i >= operandRank) {
3422       return emitOpError(
3423           llvm::formatv("hint for expanding dimension {0} does not refer to a "
3424                         "valid operand dimension",
3425                         i));
3426     }
3427   }
3428 
3429   return success();
3430 }
3431 
3432 namespace {
3433 // If a DynamicBroadcastInDimOp is not actually dynamic, use an ordinary
3434 // BroadcastInDimOp.
3435 class DynamicBroadcastInDimOpNotActuallyDynamic
3436     : public OpRewritePattern<DynamicBroadcastInDimOp> {
3437   using OpRewritePattern::OpRewritePattern;
3438   LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
3439                                 PatternRewriter& rewriter) const override {
3440     auto type = op.getType().dyn_cast<RankedTensorType>();
3441     auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
3442     auto* outputDimOp = op.output_dimensions().getDefiningOp();
3443     if (!type || !operandType || !operandType.hasStaticShape()) {
3444       return rewriter.notifyMatchFailure(op, "requires operand static shape");
3445     }
3446     // The output has a static shape; replace with broadcast_in_dim.
3447     if (type.hasStaticShape()) {
3448       rewriter.replaceOpWithNewOp<BroadcastInDimOp>(op, type, op.operand(),
3449                                                     op.broadcast_dimensions());
3450       return success();
3451     }
3452     // If output_dimensions is a constant, set the output shape from it and
3453     // then replace with broadcast_in_dim.
3454     if (outputDimOp && outputDimOp->hasTrait<mlir::OpTrait::ConstantLike>()) {
3455       DenseIntElementsAttr shapeAttr;
3456       if (matchPattern(outputDimOp, m_Constant(&shapeAttr))) {
3457         SmallVector<int64_t> outputShape;
3458         for (APInt shape : shapeAttr.getValues<APInt>()) {
3459           outputShape.push_back(shape.getZExtValue());
3460         }
3461         refineOpWithNewOp<BroadcastInDimOp>(
3462             rewriter, op,
3463             RankedTensorType::get(outputShape, type.getElementType()),
3464             op.operand(), op.broadcast_dimensions());
3465         return success();
3466       }
3467     }
3468     return rewriter.notifyMatchFailure(
3469         op, "requires output static shape or constant broadcast dimensions");
3470   }
3471 };
3472 
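// Composes two chained DynamicBroadcastInDimOps into a single broadcast of the
// original operand to the final shape. Illustrative sketch (hypothetical
// dims): broadcasting %x with broadcast_dimensions = [1] and then broadcasting
// the result with broadcast_dimensions = [0, 2] is equivalent to broadcasting
// %x directly with broadcast_dimensions = [2].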
3473 class ChainedDynamicBroadcastInDimCanonicalization
3474     : public OpRewritePattern<DynamicBroadcastInDimOp> {
3475   using OpRewritePattern::OpRewritePattern;
3476   LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
3477                                 PatternRewriter& rewriter) const override {
3478     auto precedingBcast =
3479         bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
3480     if (!precedingBcast) return failure();
3481 
3482     // Compose broadcast dimensions.
3483     DenseIntElementsAttr precedingBcastDims =
3484         precedingBcast.broadcast_dimensions();
3485     DenseIntElementsAttr bcastDims = bcast.broadcast_dimensions();
3486     SmallVector<APInt, 4> composition;
3487     for (APInt precedingDim : precedingBcastDims) {
3488       composition.push_back(
3489           bcastDims.getValues<APInt>()[precedingDim.getZExtValue()]);
3490     }
3491     auto composedBcastDims =
3492         DenseIntElementsAttr::get(precedingBcastDims.getType(), composition);
3493 
3494     rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
3495         bcast, bcast.getType(), precedingBcast.operand(),
3496         bcast.output_dimensions(), composedBcastDims);
3497     return success();
3498   }
3499 };
3500 }  // namespace
3501 
3502 void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
3503     RewritePatternSet& results, MLIRContext* context) {
3504   results.add<ChainedDynamicBroadcastInDimCanonicalization,
3505               DynamicBroadcastInDimOpNotActuallyDynamic,
3506               DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
3507               DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
3508       context);
3509 }
3510 
3511 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
3512     OpBuilder& builder, ValueRange operands,
3513     SmallVectorImpl<Value>& reifiedReturnShapes) {
3514   DynamicBroadcastInDimOp::Adaptor adaptor(operands);
3515   reifiedReturnShapes.push_back(
3516       castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
3517   return success();
3518 }
3519 
3520 //===----------------------------------------------------------------------===//
3521 // ClampOp
3522 //===----------------------------------------------------------------------===//
3523 
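// Verifies that `min` and `max` are either rank-0 (scalar) tensors or have
// shapes compatible with the operand. For example (illustrative), clamping a
// tensor<4xf32> operand with a tensor<f32> min and a tensor<4xf32> max is
// valid, while a tensor<3xf32> min is not.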
3524 LogicalResult ClampOp::verify() {
3525   auto operandType = operand().getType().cast<RankedTensorType>();
3526   auto operandShape = operandType.getShape();
3527   auto minType = min().getType().cast<RankedTensorType>();
3528 
3529   auto minShape = minType.getShape();
3530   if (failed(verifyCompatibleShape(minType, operandType)) &&
3531       minType.getRank() != 0) {
3532     return emitOpError(llvm::formatv(
3533         "min shape [{0}] is not scalar and is not compatible with operand shape "
3534         "[{1}]",
3535         llvm::make_range(minShape.begin(), minShape.end()),
3536         llvm::make_range(operandShape.begin(), operandShape.end())));
3537   }
3538 
3539   auto maxType = max().getType().cast<RankedTensorType>();
3540   auto maxShape = maxType.getShape();
3541   if (failed(verifyCompatibleShape(maxType, operandType)) &&
3542       maxType.getRank() != 0) {
3543     return emitOpError(llvm::formatv(
3544         "max shape [{0}] is not scalar and is not compatible with operand shape "
3545         "[{1}]",
3546         llvm::make_range(maxShape.begin(), maxShape.end()),
3547         llvm::make_range(operandShape.begin(), operandShape.end())));
3548   }
3549 
3550   return success();
3551 }
3552 
3553 LogicalResult ClampOp::inferReturnTypeComponents(
3554     MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
3555     DictionaryAttr attributes, RegionRange regions,
3556     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
3557   ClampOp::Adaptor adaptor(operands, attributes, regions);
3558   RankedTensorType operandType =
3559       adaptor.operand().getType().cast<RankedTensorType>();
3560   inferredReturnShapes.emplace_back(operandType.getShape(),
3561                                     operandType.getElementType());
3562   return success();
3563 }
3564 
3565 LogicalResult ClampOp::reifyReturnTypeShapes(
3566     OpBuilder& builder, ValueRange operands,
3567     SmallVectorImpl<Value>& reifiedReturnShapes) {
3568   // For `mhlo.clamp`, the first operand may be a scalar.
3569   return deriveShapeFromOperand(&builder, getOperation(), operands[1],
3570                                 &reifiedReturnShapes);
3571 }
3572 
3573 //===----------------------------------------------------------------------===//
3574 // ComplexOp
3575 //===----------------------------------------------------------------------===//
3576 
3577 LogicalResult ComplexOp::inferReturnTypes(
3578     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
3579     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3580   TensorType operandType = operands[0].getType().cast<TensorType>();
3581   ComplexType elementTy = ComplexType::get(operandType.getElementType());
3582   inferredReturnTypes.push_back(getSameShapeTensorType(operandType, elementTy));
3583   return success();
3584 }
3585 
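// Folds complex(real(%x), imag(%x)) back to %x when both projections are
// taken from the same value (illustrative), since the real/imag pair simply
// reconstructs the original complex tensor.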
3586 OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
3587   auto realOp = getOperand(0).getDefiningOp<mhlo::RealOp>();
3588   auto imagOp = getOperand(1).getDefiningOp<mhlo::ImagOp>();
3589   if (realOp && imagOp && realOp.getOperand() == imagOp.getOperand()) {
3590     return realOp.getOperand();
3591   }
3592 
3593   return {};
3594 }
3595 
3596 //===----------------------------------------------------------------------===//
3597 // ImagOp
3598 //===----------------------------------------------------------------------===//
3599 
3600 namespace {
3601 Type createRealType(TensorType type) {
3602   auto elementTy = type.getElementType();
3603   if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
3604     elementTy = complexTy.getElementType();
3605   }
3606   return getSameShapeTensorType(type, elementTy);
3607 }
3608 }  // namespace
3609 
3610 LogicalResult ImagOp::inferReturnTypes(
3611     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
3612     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3613   inferredReturnTypes.push_back(
3614       createRealType(operands[0].getType().cast<TensorType>()));
3615   return success();
3616 }
3617 
3618 OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
3619   if (auto complexOp = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
3620     return complexOp.getOperand(1);
3621   }
3622 
3623   return {};
3624 }
3625 
3626 //===----------------------------------------------------------------------===//
3627 // IsFiniteOp
3628 //===----------------------------------------------------------------------===//
3629 
3630 TensorType getSameShapeTensorType(TensorType tensorType, Type elementType) {
3631   if (auto rankedTensorTy = tensorType.dyn_cast<RankedTensorType>()) {
3632     return RankedTensorType::get(rankedTensorTy.getShape(), elementType,
3633                                  rankedTensorTy.getEncoding());
3634   }
3635   if (auto unrankedTensorTy = tensorType.dyn_cast<UnrankedTensorType>()) {
3636     return UnrankedTensorType::get(elementType);
3637   }
3638   llvm_unreachable("unhandled type");
3639 }
3640 
3641 LogicalResult IsFiniteOp::inferReturnTypes(
3642     MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
3643     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3644   auto argTy = operands.front().getType().cast<TensorType>();
3645   Builder b(ctx);
3646   inferredReturnTypes.push_back(getSameShapeTensorType(argTy, b.getI1Type()));
3647   return success();
3648 }
3649 
3650 //===----------------------------------------------------------------------===//
3651 // RealOp
3652 //===----------------------------------------------------------------------===//
3653 
3654 LogicalResult RealOp::inferReturnTypes(
3655     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
3656     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
3657   inferredReturnTypes.push_back(
3658       createRealType(operands[0].getType().cast<TensorType>()));
3659   return success();
3660 }
3661 
3662 OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
3663   if (auto complexOp = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
3664     return complexOp.getOperand(0);
3665   }
3666 
3667   return {};
3668 }
3669 
3670 //===----------------------------------------------------------------------===//
3671 // ConcatenateOp
3672 //===----------------------------------------------------------------------===//
3673 
3674 namespace {
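// Replaces a concatenate with a single operand by a cast to the result type.
// For example (illustrative), "mhlo.concatenate"(%arg) {dimension = 0} on a
// tensor<4xf32> becomes a tensor.cast of %arg.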
3675 class SingleOperandConcatenateToCast : public OpRewritePattern<ConcatenateOp> {
3676  public:
3677   using OpRewritePattern::OpRewritePattern;
3678   LogicalResult matchAndRewrite(ConcatenateOp op,
3679                                 PatternRewriter& rewriter) const override {
3680     if (op.val().size() != 1) return failure();
3681 
3682     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
3683                                                 op.val().front());
3684     return success();
3685   }
3686 };
3687 
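// Drops operands whose size along the concatenation dimension is statically
// zero, since they contribute no elements. For example (illustrative), when
// concatenating tensor<0x3xf32> and tensor<2x3xf32> along dimension 0, only
// the second operand is kept.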
3688 class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
3689  public:
3690   using OpRewritePattern::OpRewritePattern;
3691   LogicalResult matchAndRewrite(ConcatenateOp op,
3692                                 PatternRewriter& rewriter) const override {
3693     auto axis = op.dimension();
3694     llvm::SmallVector<Value, 6> newOperands;
3695     for (auto operand : op.getOperands()) {
3696       auto ty = operand.getType().cast<ShapedType>();
3697       if (!ty.hasRank() || ty.getDimSize(axis) != 0) {
3698         newOperands.push_back(operand);
3699       }
3700     }
3701 
3702     if (!newOperands.empty() && newOperands.size() < op.getNumOperands()) {
3703       rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
3704                                                  newOperands, op.dimension());
3705       return success();
3706     }
3707 
3708     return failure();
3709   }
3710 };
3711 
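// Flattens nested concatenates along the same dimension into one op. For
// example (illustrative), concatenate(concatenate(%a, %b), %c) becomes
// concatenate(%a, %b, %c), provided the inner op has a single use.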
3712 class ConcatenateForwarding : public OpRewritePattern<ConcatenateOp> {
3713   using OpRewritePattern::OpRewritePattern;
3714   LogicalResult matchAndRewrite(ConcatenateOp op,
3715                                 PatternRewriter& rewriter) const override {
3716     auto getFlattenedOperands = [&](const Value& val) -> ValueRange {
3717       auto definingOp = dyn_cast_or_null<ConcatenateOp>(val.getDefiningOp());
3718       // To avoid inflating the memory footprint, only flatten the
3719       // ConcatenateOp when it has only one use.
3720       if (definingOp && definingOp->hasOneUse() &&
3721           definingOp.dimension() == op.dimension())
3722         return definingOp.val();
3723       return val;
3724     };
3725 
3726     bool needToFlatten = false;
3727     int operandCount = 0;
3728     llvm::for_each(op.val(), [&](Value val) {
3729       auto result = getFlattenedOperands(val);
3730       if (result.size() != 1 || result[0] != val) needToFlatten = true;
3731       operandCount += result.size();
3732     });
3733 
3734     if (!needToFlatten) return failure();
3735 
3736     llvm::SmallVector<Value, 6> newOperands;
3737     newOperands.reserve(operandCount);
3738 
3739     for (auto operand : op.val()) {
3740       auto flattenedOperands = getFlattenedOperands(operand);
3741       newOperands.append(flattenedOperands.begin(), flattenedOperands.end());
3742     }
3743 
3744     rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
3745                                                newOperands, op.dimension());
3746     return success();
3747   }
3748 };
3749 
3750 }  // namespace
3751 
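// Infers the result type from the operand types: ranks and non-concat
// dimensions must agree, and the concat dimension is the sum of the operand
// sizes (dynamic if any contributing size is dynamic). For example
// (illustrative), concatenating tensor<2x3xf32> and tensor<?x3xf32> along
// dimension 0 infers tensor<?x3xf32>.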
3752 LogicalResult ConcatenateOp::inferReturnTypes(
3753     MLIRContext*, Optional<Location> location, ValueRange operands,
3754     DictionaryAttr attributes, RegionRange regions,
3755     SmallVectorImpl<Type>& inferredReturnTypes) {
3756   if (operands.empty()) {
3757     return failure();
3758   }
3759 
3760   auto dimensionAttr = attributes.get("dimension").cast<IntegerAttr>();
3761   auto dimension = dimensionAttr.getInt();
3762 
3763   auto firstType = (*operands.begin()).getType().cast<ShapedType>();
3764   auto outElement = firstType.getElementType();
3765 
3766   // Find the first ranked input to determine the output rank.
3767   for (auto type : operands.getTypes()) {
3768     auto shapedType = type.cast<ShapedType>();
3769     if (shapedType.hasRank()) {
3770       firstType = shapedType;
3771       break;
3772     }
3773   }
3774 
3775   // If all inputs are unranked, the result must be unranked.
3776   if (!firstType.hasRank()) {
3777     inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
3778     return success();
3779   }
3780 
3781   auto outShape = llvm::to_vector<6>(firstType.getShape());
3782 
3783   // Determine what the non-concatenate dimensions should be.
3784   for (auto type : operands.getTypes()) {
3785     auto shapedTy = type.cast<ShapedType>();
3786     if (!shapedTy.hasRank()) {
3787       continue;
3788     }
3789 
3790     for (const auto& it : llvm::enumerate(shapedTy.getShape())) {
3791       // If the output dimension is still dynamic, take this operand's size.
3792       if (ShapedType::isDynamic(outShape[it.index()])) {
3793         outShape[it.index()] = it.value();
3794       }
3795     }
3796   }
3797 
3798   outShape[dimension] = 0;
3799 
3800   for (auto operand : operands.getTypes()) {
3801     auto type = operand.cast<ShapedType>();
3802     if (!type.hasRank()) {
3803       inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
3804       return success();
3805     }
3806 
3807     // If the dimension is dynamic, we know the output dimension is dynamic.
3808     auto dim = type.getShape()[dimension];
3809     if (ShapedType::isDynamic(dim)) {
3810       outShape[dimension] = ShapedType::kDynamicSize;
3811       break;
3812     }
3813 
3814     outShape[dimension] += dim;
3815   }
3816 
3817   inferredReturnTypes.push_back(RankedTensorType::get(outShape, outElement));
3818 
3819   return success();
3820 }
3821 
3822 void ConcatenateOp::getCanonicalizationPatterns(RewritePatternSet& results,
3823                                                 MLIRContext* context) {
3824   results.add<ConcatenateOperandRemoval, ConcatenateForwarding,
3825               SingleOperandConcatenateToCast>(context);
3826 }
3827 
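// Folds a concatenate of constant operands by interleaving their slices:
// `topSize` is the number of element blocks preceding the concat axis, and
// each operand contributes `bottomSize` contiguous elements per block. A
// worked example (illustrative): concatenating [[1, 2], [3, 4]] with
// [[5], [6]] along dimension 1 appends 1,2 then 5 for the first block and
// 3,4 then 6 for the second, yielding [[1, 2, 5], [3, 4, 6]].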
3828 template <typename T>
3829 static Attribute foldConcatenateHelper(ConcatenateOp* op,
3830                                        ArrayRef<Attribute> operands) {
3831   auto axis = op->dimension();
3832   auto type = op->getType().cast<ShapedType>();
3833   auto shape = type.getShape();
3834 
3835   size_t topSize = 1;
3836   for (int i = 0, e = axis; i < e; i++) {
3837     topSize = topSize * shape[i];
3838   }
3839 
3840   // Prevent folding if the result is too large.
3841   if (type.getNumElements() > kFoldOpEltLimit) return {};
3842 
3843   SmallVector<T, 6> values;
3844   for (size_t i = 0; i < topSize; i++) {
3845     for (auto operand : operands) {
3846       DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
3847       size_t bottomSize = attr.getNumElements() / topSize;
3848       auto iter = attr.getValues<T>().begin() + i * bottomSize;
3849       values.append(iter, iter + bottomSize);
3850     }
3851   }
3852 
3853   return DenseElementsAttr::get(type, values);
3854 }
3855 
3856 static Attribute foldConcatenate(ConcatenateOp* op,
3857                                  ArrayRef<Attribute> operands) {
3858   for (auto operand : operands) {
3859     if (!operand) return {};
3860   }
3861 
3862   auto type = op->getResult().getType().cast<ShapedType>();
3863   auto etype = type.getElementType();
3864   if (etype.isa<IntegerType>()) {
3865     return foldConcatenateHelper<APInt>(op, operands);
3866   }
3867 
3868   if (etype.isa<FloatType>()) {
3869     return foldConcatenateHelper<APFloat>(op, operands);
3870   }
3871 
3872   return {};
3873 }
3874 
3875 OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
3876   if (getNumOperands() == 1) return getOperand(0);
3877 
3878   ShapedType type = getResult().getType().cast<ShapedType>();
3879   if (!type.hasStaticShape()) return {};
3880 
3881   auto axis = dimension();
3882   if (auto attr = foldConcatenate(this, operands)) {
3883     return attr;
3884   }
3885 
3886   for (auto operand : getOperands()) {
3887     auto ty = operand.getType().cast<ShapedType>();
3888     if (ty.getDimSize(axis) != 0) {
3889       return {};
3890     }
3891   }
3892 
3893   return DenseElementsAttr::get(type, ArrayRef<Attribute>());
3894 }
3895 
3896 LogicalResult ConcatenateOp::verify() {
3897   RankedTensorType firstRankedType;
3898   int firstRankedIndex;
3899   int numOperands = getNumOperands();
3900   int64_t concatDimension = static_cast<int64_t>(dimension());
3901   if (concatDimension < 0) {
3902     return emitOpError(
3903         llvm::formatv("dimension {0} is negative", concatDimension));
3904   }
3905   for (int i = 0; i < numOperands; i++) {
3906     auto secondType = getOperand(i).getType().dyn_cast<ShapedType>();
3907     if (!secondType.hasRank()) {
3908       continue;
3909     }
3910 
3911     if (!firstRankedType) {
3912       firstRankedType = secondType.cast<RankedTensorType>();
3913       firstRankedIndex = i;
3914       if (firstRankedType.getRank() == 0)
3915         return emitOpError(
3916             llvm::formatv("rank-0 values cannot be concatenated"));
3917       if (concatDimension >= firstRankedType.getRank()) {
3918         return emitOpError(
3919             llvm::formatv("dimension {0} is out-of-bounds for input rank {1}",
3920                           concatDimension, firstRankedType.getRank()));
3921       }
3922       continue;
3923     }
3924 
3925     if (firstRankedType.getRank() != secondType.getRank()) {
3926       return emitOpError(llvm::formatv(
3927           "operands ({0}) and ({1}) do not match rank", firstRankedIndex, i));
3928     }
3929 
3930     auto firstShape = firstRankedType.getShape();
3931     auto secondShape = secondType.getShape();
3932     for (int d = 0; d < firstRankedType.getRank(); ++d) {
3933       if (!ShapedType::isDynamic(firstShape[d]) &&
3934           !ShapedType::isDynamic(secondShape[d]) &&
3935           firstShape[d] != secondShape[d] && d != concatDimension) {
3936         return emitOpError(llvm::formatv(
3937             "shapes of operand ({0}) and ({1}) do not match at non-concat "
3938             "index {4}: ({2}) != ({3})",
3939             firstRankedIndex, i,
3940             llvm::make_range(firstShape.begin(), firstShape.end()),
3941             llvm::make_range(secondShape.begin(), secondShape.end()), d));
3942       }
3943     }
3944   }
3945   return success();
3946 }
3947 
3948 LogicalResult ConcatenateOp::reifyReturnTypeShapes(
3949     OpBuilder& builder, ValueRange operands,
3950     SmallVectorImpl<Value>& reifiedReturnShapes) {
3951   ConcatenateOp::Adaptor adaptor(operands);
3952   auto inputs = adaptor.val();
3953 
3954   auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
3955   // Unranked types are not supported at the moment.
3956   if (!operandType) return failure();
3957 
3958   Location loc = this->getLoc();
3959   Type shapeScalarType = builder.getIndexType();
3960   auto toShapeScalarType = [&](Value v) {
3961     return maybeCastTo(builder, loc, v, shapeScalarType);
3962   };
3963 
3964   SmallVector<SmallVector<Value, 4>, 4> allShapeValues;
3965   for (size_t inputId = 0; inputId < inputs.size(); ++inputId) {
3966     Value operand = inputs[inputId];
3967     auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3968     if (!operandType) return failure();
3969 
3970     SmallVector<Value, 4> shapeVals;
3971     for (const auto& element : llvm::enumerate(operandType.getShape())) {
3972       Value valueDim = toShapeScalarType(
3973           builder.create<tensor::DimOp>(loc, operand, element.index()));
3974       shapeVals.push_back(valueDim);
3975     }
3976     allShapeValues.emplace_back(std::move(shapeVals));
3977   }
3978 
3979   int axis = this->dimension();
3980   auto& shapeValues = allShapeValues[0];
3981   for (size_t vecId = 1; vecId < allShapeValues.size(); ++vecId) {
3982     auto& otherShapeValues = allShapeValues[vecId];
3983     if (otherShapeValues.size() != shapeValues.size()) {
3984       this->emitOpError()
3985           << "Concatenate expects all operands to be of the same rank";
3986       return failure();
3987     }
3988     shapeValues[axis] = builder.create<arith::AddIOp>(loc, shapeValues[axis],
3989                                                       otherShapeValues[axis]);
3990   }
3991 
3992   Value outputShape = builder.create<tensor::FromElementsOp>(
3993       loc,
3994       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
3995                             shapeScalarType),
3996       shapeValues);
3997   reifiedReturnShapes.push_back(outputShape);
3998 
3999   return success();
4000 }
4001 
4002 //===----------------------------------------------------------------------===//
4003 // DynamicReshapeOp
4004 //===----------------------------------------------------------------------===//
4005 
4006 LogicalResult DynamicReshapeOp::verify() {
4007   auto resultType = result().getType().dyn_cast<RankedTensorType>();
4008   auto outputShapeType = output_shape().getType().dyn_cast<RankedTensorType>();
4009   if (resultType && outputShapeType && outputShapeType.hasStaticShape() &&
4010       outputShapeType.getDimSize(0) != resultType.getRank()) {
4011     return emitError() << "output should have a rank equal to the number of "
4012                           "elements in output_shape";
4013   }
4014   return success();
4015 }
4016 
4017 LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
4018     OpBuilder& builder, ValueRange operands,
4019     SmallVectorImpl<Value>& reifiedReturnShapes) {
4020   DynamicReshapeOp::Adaptor adaptor(operands);
4021   reifiedReturnShapes.push_back(
4022       castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
4023   return success();
4024 }
4025 
4026 namespace {
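// Replaces a dynamic_reshape whose result type is fully static with a plain
// reshape. For example (illustrative):
//   "mhlo.dynamic_reshape"(%arg, %shape)
//       : (tensor<4xf32>, tensor<2xi32>) -> tensor<2x2xf32>
// becomes
//   "mhlo.reshape"(%arg) : (tensor<4xf32>) -> tensor<2x2xf32>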
4027 class DynamicReshapeOpNotActuallyDynamic
4028     : public OpRewritePattern<DynamicReshapeOp> {
4029  public:
4030   using OpRewritePattern::OpRewritePattern;
4031   LogicalResult matchAndRewrite(DynamicReshapeOp op,
4032                                 PatternRewriter& rewriter) const override {
4033     auto type = op.result().getType().dyn_cast<RankedTensorType>();
4034     if (!type || !type.hasStaticShape()) {
4035       return rewriter.notifyMatchFailure(op, "requires static shape tensor");
4036     }
4037     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
4038     return success();
4039   }
4040 };
4041 
4042 // Canonicalizes
4043 // %0 = some_op(%tensor)
4044 // %1 = "mhlo.dynamic_reshape"(%0, %shape)
4045 //      (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
4046 // ... uses of %1.
4047 //
4048 // into
4049 //
4050 // ... uses of %0.
4051 // This canonicalization is only correct if the input is correct!
4052 // TODO(b/178779691): Use a more sophisticated canonicalization that preserves
4053 // errors in input, and still allows us to get rid of redundant reshapes.
4054 class RemoveRedundantRank1DynamicReshape
4055     : public OpRewritePattern<DynamicReshapeOp> {
4056  public:
4057   using OpRewritePattern::OpRewritePattern;
4058   LogicalResult matchAndRewrite(DynamicReshapeOp op,
4059                                 PatternRewriter& rewriter) const override {
4060     auto type = op.result().getType().dyn_cast<RankedTensorType>();
4061     if (!type || type.getRank() != 1 || type.hasStaticShape()) {
4062       return rewriter.notifyMatchFailure(
4063           op, "requires rank 1 shape tensor with dynamic dimension");
4064     }
4065     auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
4066     if (!operandType || operandType.getRank() != 1 ||
4067         operandType.hasStaticShape()) {
4068       return rewriter.notifyMatchFailure(
4069           op, "requires rank 1 shape tensor with dynamic dimension");
4070     }
4071     rewriter.replaceOp(op, {op.operand()});
4072     return success();
4073   }
4074 };
4075 
4076 // Canonicalizes
4077 // %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
4078 // %1 = same_operands_and_result_shape_op(%tensor)
4079 // %2 = "mhlo.dynamic_reshape"(%1, %shape)
4080 // ... uses of %2.
4081 //
4082 // into
4083 //
4084 // %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
4085 // %1 = same_operands_and_result_shape_op(%tensor)
4086 // ... uses of %1.
4087 class DynamicReshapeOpSameShapeOpResult
4088     : public OpRewritePattern<DynamicReshapeOp> {
4089  public:
4090   using OpRewritePattern::OpRewritePattern;
4091 
4092   LogicalResult matchAndRewrite(DynamicReshapeOp op,
4093                                 PatternRewriter& rewriter) const override {
4094     Operation* defOp = op.operand().getDefiningOp();
4095     if (!defOp ||
4096         !defOp->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
4097       return failure();
4098     }
4099     Operation* inputDefOp = defOp->getOperand(0).getDefiningOp();
4100     if (!inputDefOp) {
4101       return failure();
4102     }
4103     auto reshape = dyn_cast<DynamicReshapeOp>(*inputDefOp);
4104     if (reshape && reshape.output_shape() == op.output_shape()) {
4105       rewriter.replaceOp(op, {defOp->getResult(0)});
4106       return success();
4107     }
4108     return failure();
4109   }
4110 };
4111 }  // namespace
4112 
4113 void DynamicReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results,
4114                                                    MLIRContext* context) {
4115   // clang-format off
4116   results.add<
4117       DynamicReshapeOpNotActuallyDynamic,
4118       DynamicReshapeOpSameShapeOpResult,
4119       RemoveRedundantDynamicBroadcast,
4120       RemoveRedundantDynamicReshape,
4121       RemoveRedundantRank1DynamicReshape,
4122       ShapeOfDynamicReshape
4123     >(context);
4124   // clang-format on
4125 }
4126 
4127 //===----------------------------------------------------------------------===//
4128 // DynamicSliceOp
4129 //===----------------------------------------------------------------------===//
4130 
4131 namespace {
4132 // Canonicalizes DynamicSlice ops that can be replaced with Slice ops.
4133 // This canonicalization is applied in the case when the `begin` input values
4134 // are compile-time constants and thus can be made into a tensor.
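// For example (illustrative), a dynamic_slice of a tensor<4x4xf32> whose
// start indices are constants and whose slice_sizes are dense<2> becomes an
// "mhlo.slice" with start_indices, limit_indices, and strides attributes.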
4135 struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
4136   using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;
4137 
4138   LogicalResult matchAndRewrite(DynamicSliceOp dynamicSlice,
4139                                 PatternRewriter& rewriter) const override {
4140     Value input = dynamicSlice.operand();
4141     auto inputTensor = input.getType().dyn_cast<RankedTensorType>();
4142     if (!inputTensor || !inputTensor.hasStaticShape()) return failure();
4143 
4144     auto sliceSizes = dynamicSlice.slice_sizes().getValues<int64_t>();
4145     SmallVector<int64_t, 4> tempStartIndices;
4146     for (const auto& indexAndSliceStart :
4147          llvm::enumerate(dynamicSlice.start_indices())) {
4148       APInt val;
4149       Value start = indexAndSliceStart.value();
4150       int64_t index = indexAndSliceStart.index();
4151       if (!matchPattern(start, m_ConstantInt(&val))) {
4152         return failure();
4153       }
4154       // Clamp the indices within bounds to faithfully mirror dynamic slice
4155       // semantics.
4156       int64_t clampedStart =
4157           clamp(val.getSExtValue(), static_cast<int64_t>(0),
4158                 inputTensor.getDimSize(index) - sliceSizes[index]);
4159       tempStartIndices.push_back(clampedStart);
4160     }
4161 
4162     // At this point we've determined that the start indices are all constants;
4163     // pack them into a single tensor.
4164     auto loc = dynamicSlice.getLoc();
4165     int64_t inputRank = inputTensor.getRank();
4166     auto sliceStartIndices = rewriter.getI64TensorAttr(tempStartIndices);
4167     DenseIntElementsAttr sliceLimits = buildSliceLimits(
4168         sliceStartIndices, dynamicSlice.slice_sizes(), &rewriter);
4169     DenseIntElementsAttr sliceStrides =
4170         rewriter.getI64TensorAttr(SmallVector<int64_t, 4>(inputRank, 1));
4171     auto result = rewriter.create<SliceOp>(loc, input, sliceStartIndices,
4172                                            sliceLimits, sliceStrides);
4173     rewriter.replaceOp(dynamicSlice, {result});
4174     return success();
4175   }
4176 };
4177 
4178 }  // namespace
4179 
4180 void DynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
4181                                                  MLIRContext* context) {
4182   results.add<DynamicSliceToSlice>(context);
4183 }
4184 
4185 // Verifies that the number of slice sizes and the number of start indices match
4186 LogicalResult DynamicSliceOp::verify() {
4187   int numSliceSizes = slice_sizes().getNumElements();
4188   int numStartIndices = start_indices().size();
4189   if (numStartIndices != numSliceSizes) {
4190     return emitOpError() << "has mismatched number of slice sizes ("
4191                          << numSliceSizes << ") and number of start indices ("
4192                          << numStartIndices << ")";
4193   }
4194   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
4195   if (!operandType) return failure();
4196 
4197   if (operandType.getRank() != numStartIndices) {
4198     return emitOpError() << "has mismatched number of start indices ("
4199                          << numStartIndices << ") and the rank of operand ("
4200                          << operandType.getRank() << ")";
4201   }
4202 
4203   for (int i = 0; i < numSliceSizes; ++i) {
4204     int64_t sliceSize = slice_sizes().getValues<int64_t>()[i];
4205     if (sliceSize < 0) {
4206       return emitOpError() << "has negative size index to dynamic slice: "
4207                            << sliceSize;
4208     }
4209     if (!operandType.isDynamicDim(i)) {
4210       int64_t dimSize = operandType.getDimSize(i);
4211       if (sliceSize > dimSize) {
4212         return emitOpError() << "has slice size " << sliceSize
4213                              << " greater than dimension size " << dimSize
4214                              << " in dimension " << i << " of operand";
4215       }
4216     }
4217   }
4218   return success();
4219 }
4220 
4221 LogicalResult DynamicSliceOp::inferReturnTypeComponents(
4222     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
4223     DictionaryAttr attributes, RegionRange regions,
4224     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
4225   DynamicSliceOp::Adaptor adaptor(operands, attributes, regions);
4226   Value operand = adaptor.operand();
4227   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
4228   if (!operandType) return failure();
4229 
4230   auto sliceSizes = adaptor.slice_sizes();
4231   Type elementTy = operandType.getElementType();
4232   inferredReturnShapes.emplace_back(sliceSizes.getValues<int64_t>(), elementTy);
4233   return success();
4234 }
4235 
4236 //===----------------------------------------------------------------------===//
4237 // RealDynamicSliceOp
4238 //===----------------------------------------------------------------------===//
4239 // Verifies that operand rank matches start_indices/limit_indices/strides size
4240 LogicalResult RealDynamicSliceOp::verify() {
4241   auto inputType = operand().getType().dyn_cast<RankedTensorType>();
4242   // If operand is unranked, there is very little to verify statically.
4243   if (!inputType) return success();
4244   int inputRank = inputType.getRank();
4245 
4246   auto startType = start_indices().getType().cast<RankedTensorType>();
4247   auto limitType = limit_indices().getType().cast<RankedTensorType>();
4248   auto stridesType = strides().getType().cast<RankedTensorType>();
4249 
4250   if (inputRank != startType.getNumElements()) {
4251     return emitOpError() << "has mismatched number of operand rank ("
4252                          << inputRank << ") and start_indices size ("
4253                          << startType.getNumElements() << ")";
4254   }
4255 
4256   if (inputRank != limitType.getNumElements()) {
4257     return emitOpError() << "has mismatched number of operand rank ("
4258                          << inputRank << ") and limit_indices size ("
4259                          << limitType.getNumElements() << ")";
4260   }
4261 
4262   if (inputRank != stridesType.getNumElements()) {
4263     return emitOpError() << "has mismatched number of operand rank ("
4264                          << inputRank << ") and strides size ("
4265                          << stridesType.getNumElements() << ")";
4266   }
4267 
4268   return success();
4269 }
4270 
4271 namespace {
4272 // Canonicalizes RealDynamicSlice ops that can be replaced with Slice ops.
4273 // This canonicalization applies when the `begin`, `limit`, and `stride`
4274 // inputs are compile-time constants and thus can be made into attributes.
4275 struct RealDynamicSliceIsStatic : public OpRewritePattern<RealDynamicSliceOp> {
4276   using OpRewritePattern<RealDynamicSliceOp>::OpRewritePattern;
4277 
4278   LogicalResult matchAndRewrite(RealDynamicSliceOp realDynamicSlice,
4279                                 PatternRewriter& rewriter) const override {
4280     Location loc = realDynamicSlice.getLoc();
4281     Value input = realDynamicSlice.operand();
4282     Value output = realDynamicSlice.result();
4283     auto inputTy = input.getType().dyn_cast<RankedTensorType>();
4284     auto outputTy = output.getType().dyn_cast<RankedTensorType>();
4285 
4286     if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
4287         !outputTy.hasStaticShape()) {
4288       return failure();
4289     }
4290 
4291     int64_t inputRank = inputTy.getRank();
4292 
4293     auto startVal = realDynamicSlice.start_indices();
4294     auto limitVal = realDynamicSlice.limit_indices();
4295     auto strideVal = realDynamicSlice.strides();
4296     auto startOp = startVal.getDefiningOp<mlir::arith::ConstantOp>();
4297     auto limitOp = limitVal.getDefiningOp<mlir::arith::ConstantOp>();
4298     auto strideOp = strideVal.getDefiningOp<mlir::arith::ConstantOp>();
4299     if (!startOp || !limitOp || !strideOp) return failure();
4300 
4301     auto startAttr =
4302         startOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
4303     auto limitAttr =
4304         limitOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
4305     auto strideAttr =
4306         strideOp.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
4307     if (!startAttr || !limitAttr || !strideAttr) return failure();
4308 
4309     SmallVector<int64_t, 4> tempStartIndices;
4310     SmallVector<int64_t, 4> tempLimitIndices;
4311     SmallVector<int64_t, 4> tempStride;
4312     for (int64_t dimIdx = 0; dimIdx < inputRank; dimIdx++) {
4313       int64_t start = startAttr.getValues<IntegerAttr>()[dimIdx].getInt();
4314       tempStartIndices.push_back(start);
4315       int64_t limit = limitAttr.getValues<IntegerAttr>()[dimIdx].getInt();
4316       tempLimitIndices.push_back(limit);
4317       int64_t stride = strideAttr.getValues<IntegerAttr>()[dimIdx].getInt();
4318       tempStride.push_back(stride);
4319     }
4320 
4321     DenseIntElementsAttr sliceStartIndices =
4322         rewriter.getI64TensorAttr(tempStartIndices);
4323     DenseIntElementsAttr sliceLimitIndices =
4324         rewriter.getI64TensorAttr(tempLimitIndices);
4325     DenseIntElementsAttr sliceStrides = rewriter.getI64TensorAttr(tempStride);
4326     auto result = rewriter.create<SliceOp>(loc, input, sliceStartIndices,
4327                                            sliceLimitIndices, sliceStrides);
4328     rewriter.replaceOp(realDynamicSlice, {result});
4329     return success();
4330   }
4331 };
4332 }  // namespace
4333 
4334 void RealDynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
4335                                                      MLIRContext* context) {
4336   results.add<RealDynamicSliceIsStatic, RealDSliceToSlice>(context);
4337 }
4338 
4339 LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
4340     OpBuilder& builder, ValueRange operands,
4341     SmallVectorImpl<Value>& reifiedReturnShapes) {
4342   RealDynamicSliceOp::Adaptor adaptor(operands);
4343   Value operand = adaptor.operand();
4344   Value startIndices = adaptor.start_indices();
4345   Value limitIndices = adaptor.limit_indices();
4346   Value strides = adaptor.strides();
4347 
4348   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
4349   // Unranked types are not supported at the moment.
4350   if (!operandType) return failure();
4351 
4352   Location loc = this->getLoc();
4353   SmallVector<Value, 4> shapeValues;
4354   shapeValues.reserve(operandType.getRank());
4355   Type shapeScalarType =
4356       startIndices.getType().cast<ShapedType>().getElementType();
4357   Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
4358   one = maybeCastTo(builder, loc, one, shapeScalarType);
4359   for (const auto& element : llvm::enumerate(operandType.getShape())) {
4360     Value offset = builder.create<arith::ConstantIndexOp>(loc, element.index());
4361     Value valueStart =
4362         builder.create<tensor::ExtractOp>(loc, startIndices, offset);
4363     Value valueLimit =
4364         builder.create<tensor::ExtractOp>(loc, limitIndices, offset);
4365     Value valueStride = builder.create<tensor::ExtractOp>(loc, strides, offset);
4366     // size = (limit - start + stride - 1) / stride
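    // For example (illustrative): start = 1, limit = 10, stride = 3 selects
    // indices {1, 4, 7}, and indeed (10 - 1 + 3 - 1) / 3 == 3.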
4367     shapeValues.push_back(builder.create<arith::DivSIOp>(
4368         loc,
4369         builder.create<arith::SubIOp>(
4370             loc,
4371             builder.create<arith::AddIOp>(
4372                 loc, valueStride,
4373                 builder.create<arith::SubIOp>(loc, valueLimit, valueStart)),
4374             one),
4375         valueStride));
4376   }
4377 
4378   reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
4379       loc,
4380       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
4381                             shapeScalarType),
4382       shapeValues));
4383   return success();
4384 }
4385 
4386 //===----------------------------------------------------------------------===//
4387 // InfeedOp
4388 //===----------------------------------------------------------------------===//
4389 
4390 // Checks that the result type is of the form `zero_or_more_type(s),
4391 // mhlo::token`
4392 LogicalResult InfeedOp::verify() {
4393   auto resultTypes = getResultTypes();
4394   if (resultTypes.empty())
4395     return emitOpError()
4396            << "result is expected to be at least of size 1, but got "
4397            << resultTypes.size();
4398 
4399   if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
4400     return emitOpError() << "last element of result types is expected to "
4401                             "be of token type, but got "
4402                          << resultTypes[resultTypes.size() - 1];
4403 
4404   // Verify layout attribute
4405   constexpr char kLayoutAttr[] = "layout";
4406   if (!getOperation()->hasAttr(kLayoutAttr)) return success();
4407 
4408   mlir::ArrayAttr layout =
4409       getOperation()->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
4410   if (!layout)
4411     return emitOpError() << "layout-attribute expected to be of array-type.";
4412 
4413   if (layout.size() != resultTypes.size() - 1) {
4414     return emitOpError() << "layout-attribute size must be "
4415                          << resultTypes.size() - 1
4416                          << " (which is the number of "
4417                             "op-results - 1 (for token result)), but got "
4418                          << layout.size();
4419   }
4420 
4421   for (auto childLayout : layout) {
4422     mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast<mlir::ArrayAttr>();
4423     if (!childLayoutArr) {
4424       return emitOpError() << "layout-attribute expected to have "
4425                               "elements of type array, but got "
4426                            << childLayout;
4427     }
4428 
4429     for (auto i : childLayoutArr) {
4430       mlir::IntegerAttr attr = i.dyn_cast<mlir::IntegerAttr>();
4431       if (!attr) {
4432         return emitOpError() << "layout-attribute's leaf elements are "
4433                                 "expected to be of type integer, but got "
4434                              << i;
4435       }
4436     }
4437   }
4438 
4439   return success();
4440 }
4441 
4442 //===----------------------------------------------------------------------===//
4443 // MapOp
4444 //===----------------------------------------------------------------------===//
4445 
4446 LogicalResult MapOp::verify() {
4447   // Checks that the number of `operands` matches the arity of the map
4448   // `computation` region.
4449   auto& computationBlock = computation().front();
4450   auto computationArgs = computationBlock.getArguments();
4451   if (operands().size() != computationArgs.size())
4452     return emitOpError() << "expects number of operands to match the arity "
4453                             "of map computation, but got: "
4454                          << operands().size() << " and "
4455                          << computationArgs.size();
4456 
4457   // The parameters of computation should all be scalars and match the element
4458   // type of operands.
4459   for (const auto& indexedArg : llvm::enumerate(computationArgs)) {
4460     auto argType = indexedArg.value().getType().dyn_cast<TensorType>();
4461     if (!argType || argType.getRank() != 0)
4462       return emitOpError()
4463              << "computation arguments must be 0-rank tensor, but got: arg #"
4464              << indexedArg.index() << " of type "
4465              << indexedArg.value().getType();
4466     auto operandElemTy = operands()[indexedArg.index()]
4467                              .getType()
4468                              .cast<TensorType>()
4469                              .getElementType();
4470     if (argType.getElementType() != operandElemTy) {
4471       return emitOpError()
4472              << "element type of operands and computation arguments must "
4473                 "match, but got: "
4474              << operandElemTy << " and " << argType.getElementType();
4475     }
4476   }
4477 
4478   // Mapped computation must return single output
4479   auto computationOutputs = computationBlock.getTerminator()->getOperands();
4480   if (computationOutputs.size() != 1)
4481     return emitOpError() << "computation must return single output, but got: "
4482                          << computationOutputs.size();
4483 
4484   // The output of computation must be scalar and have the same element type
4485   // as op result.
4486   auto computationOutputType =
4487       computationOutputs[0].getType().dyn_cast<TensorType>();
4488   if (!computationOutputType || computationOutputType.getRank() != 0)
4489     return emitOpError() << "computation must return 0-rank tensor, but got: "
4490                          << computationOutputs[0].getType();
4491 
4492   auto resultType = getType().cast<TensorType>();
4493   if (computationOutputType.getElementType() != resultType.getElementType())
4494     return emitOpError() << "element type of result and computation output "
4495                             "must match, but got: "
4496                          << resultType.getElementType() << " and "
4497                          << computationOutputType.getElementType();
4498 
4499   // Checks that the requested map dimension numbers are monotonically
4500   // increasing.
4501   DenseIntElementsAttr dimensions = this->dimensions();
4502   for (const auto& indexedValue :
4503        llvm::enumerate(dimensions.getValues<int64_t>())) {
4504     if (indexedValue.value() != static_cast<int64_t>(indexedValue.index()))
4505       return emitOpError() << "requires monotonically increasing dimension "
4506                               "numbers, but got: "
4507                            << dimensions;
4508   }
4509 
4510   // Checks that number of dimensions of operands matches the size of
4511   // `dimensions` since we currently only support mapping across all
4512   // dimensions: i.e., scalar map functions.
4513   auto operandType = operands()[0].getType().cast<TensorType>();
4514   if (operandType.hasRank()) {
4515     if (dimensions.size() !=
4516         static_cast<int64_t>(operandType.getShape().size()))
4517       return emitOpError()
4518              << "applied to a subset of dimensions currently not supported: "
4519                 "operand dimensions = "
4520              << operandType.getShape().size()
4521              << ", requested map dimensions size = " << dimensions.size();
4522   }
4523 
4524   return success();
4525 }
4526 
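// Folds a map whose body immediately returns one of its block arguments to
// the corresponding operand. For example (illustrative), a map over (%a, %b)
// whose computation is just "mhlo.return"(%arg1) folds to %b.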
4527 OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) {
4528   mlir::Block& bb = computation().front();
4529   mlir::Operation& frontOp = bb.front();
4530 
4531   auto retOp = mlir::dyn_cast<ReturnOp>(frontOp);
4532   if (!retOp) return nullptr;
4533   if (retOp.results().size() != 1) return nullptr;
4534 
4535   for (mlir::BlockArgument barg : bb.getArguments()) {
4536     if (barg == retOp.results()[0]) return getOperands()[barg.getArgNumber()];
4537   }
4538   return nullptr;
4539 }
4540 
4541 LogicalResult MapOp::reifyReturnTypeShapes(
4542     OpBuilder& builder, ValueRange operands,
4543     SmallVectorImpl<Value>& reifiedReturnShapes) {
4544   return deriveShapeFromOperand(&builder, getOperation(), operands.front(),
4545                                 &reifiedReturnShapes);
4546 }
4547 
4548 //===----------------------------------------------------------------------===//
4549 // RecvOp
4550 //===----------------------------------------------------------------------===//
4551 
4552 // Checks that the result type is of the form `zero_or_more_type(s),
4553 // mhlo::token`
4554 LogicalResult RecvOp::verify() {
4555   auto resultTypes = getResultTypes();
4556   if (resultTypes.empty())
4557     return emitOpError()
4558            << "result is expected to be at least of size 1, but got "
4559            << resultTypes.size();
4560   if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
4561     return emitOpError() << "last element of result types is expected to "
4562                             "be of token type, but got "
4563                          << resultTypes[resultTypes.size() - 1];
4564   return success();
4565 }
4566 
4567 //===----------------------------------------------------------------------===//
4568 // CopyOp
4569 //===----------------------------------------------------------------------===//
4570 
4571 OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }
4572 
4573 //===----------------------------------------------------------------------===//
4574 // ReduceWindowOp
4575 //===----------------------------------------------------------------------===//
4576 
4577 namespace {
4578 // Infer the return-type of ReduceWindowOp.
4579 SmallVector<TensorType> inferReduceWindowOpReturnType(
4580     ArrayRef<TensorType> inputTypes, ArrayRef<TensorType> initTypes,
4581     const ArrayRef<WindowDimension> window) {
4582   SmallVector<TensorType> outputTypes;
4583   for (size_t i = 0; i < inputTypes.size(); ++i) {
4584     if (!inputTypes[i].hasRank()) {
4585       outputTypes.push_back(
4586           UnrankedTensorType::get(initTypes[i].getElementType()));
4587       continue;
4588     }
4589 
4590     outputTypes.push_back(RankedTensorType::get(
4591         inferWindowOutputShape(inputTypes[i].getShape(), window),
4592         initTypes[i].getElementType()));
4593   }
4594 
4595   return outputTypes;
4596 }
4597 }  // namespace
4598 
4599 // We intend to verify the following properties
4600 //  P1. The sizes of 'inputs' and 'init_values' must be at least 1.
4601 //  P2. All `inputs` need to have compatible shapes.
4602 //  P3. size-of(window_dimension) == rank-of(input),
4603 //        where input is an element of 'inputs'.
4604 //  P4. Verify and collect the window attributes.
4605 //  P5. Verify the inner block defining the reducer function.
4606 //  P6. Verify the return type.
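// An illustrative, hypothetical well-formed example (shapes chosen only for
// exposition): a 2x2 add-reduction window over a 4x4 input with unit strides:
//   %0 = "mhlo.reduce_window"(%input, %init) ({
//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):
//     %sum = "mhlo.add"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     "mhlo.return"(%sum) : (tensor<f32>) -> ()
//   }) {window_dimensions = dense<2> : tensor<2xi64>,
//       window_strides = dense<1> : tensor<2xi64>}
//       : (tensor<4x4xf32>, tensor<f32>) -> tensor<3x3xf32>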
4607 LogicalResult ReduceWindowOp::verify() {
4608   // P1.
4609   // Note that the ODS ensures that there is an even number of operands; check
4610   // that this number is not zero.
4611   if (getOperands().empty())
4612     return emitOpError() << "expects the size of operands to be >= 2.";
4613 
4614   // Collect the input and init-value operands. Note that the operand-type is
4615   // enforced as "TensorType" by ODS.
4616   int64_t numInputs = getNumOperands() / 2;
4617   auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
4618       getOperandTypes(),
4619       [](Type t) -> TensorType { return t.cast<TensorType>(); }));
4620   ArrayRef<TensorType> inputTypes(operandTensorTypes.begin(),
4621                                   operandTensorTypes.begin() + numInputs);
4622   ArrayRef<TensorType> initTypes(operandTensorTypes.begin() + numInputs,
4623                                  operandTensorTypes.end());
4624 
4625   // P2.
4626   if (failed(verifyCompatibleShapes(operands().getTypes())))
4627     return emitOpError() << "requires same shape for all inputs";
4628 
4629   // P3.
4630   SmallVector<int64_t> windowDims =
4631       convertDenseIntAttr(this->window_dimensions());
4632   for (const auto inputType : inputTypes) {
4633     if (!inputType.hasRank()) continue;
4634     if (inputType.getRank() != static_cast<int64_t>(windowDims.size()))
4635       return emitOpError()
4636              << "expects window-dimensions size == input rank, but got "
4637                 "window-dimensions size: "
4638              << windowDims.size() << " and input: " << inputType
4639              << " with rank = " << inputType.getRank() << ".";
4640   }
4641 
4642   // P4.
4643   auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
4644   if (failed(paddingOrErr)) return failure();
4645   SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;
4646 
4647   auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
4648       windowDims, convertDenseIntAttr(window_strides()), padding,
4649       /*lhs_dilation=*/convertDenseIntAttr(base_dilations()),
4650       /*rhs_dilation=*/convertDenseIntAttr(this->window_dilations()), getLoc());
4651   if (failed(windowOrErr)) return failure();
4652 
4653   // P5.
4654   bool allInputsUnranked =
4655       llvm::all_of(inputTypes, [](TensorType t) { return !t.hasRank(); });
4656 
4657   Block& block = body().front();
4658   SmallVector<TensorType> accumulatorSubshapes;
4659   if (failed(verifyReducerShape(this->getLoc(), block, inputTypes, initTypes,
4660                                 numInputs, windowDims, allInputsUnranked,
4661                                 accumulatorSubshapes)))
4662     return failure();
4663 
4664   // P6.
4665   if (numInputs != getNumResults())
4666     return emitOpError() << "expects " << numInputs
4667                          << " result values, but got " << getNumResults()
4668                          << ".";
4669 
4670   // The result-type is enforced as "TensorType" by ODS.
4671   auto resultTensorTypes = llvm::to_vector<4>(llvm::map_range(
4672       getResultTypes(),
4673       [](Type t) -> TensorType { return t.cast<TensorType>(); }));
4674 
4675   // Check that the element types of the results match the ones derived from
4676   // the reducer block. It is already ensured that |accumulator_subshapes| ==
4677   // num_inputs == num_of_results.
4678   for (int64_t shapeIdx = 0;
4679        shapeIdx < static_cast<int64_t>(accumulatorSubshapes.size());
4680        shapeIdx++) {
4681     if (accumulatorSubshapes[shapeIdx].getElementType() !=
4682         resultTensorTypes[shapeIdx].getElementType()) {
4683       return emitError()
4684              << "expects the element-type of reduce-op's return-value at index "
4685              << shapeIdx
4686              << " to match the element-type of reducer-block's "
4687                 "corresponding return-value, but got "
4688              << resultTensorTypes[shapeIdx].getElementType() << " and "
4689              << accumulatorSubshapes[shapeIdx].getElementType() << " resp.";
4690     }
4691   }
4692 
4693   // Check that the shapes of the results match the ones derived from the
4694   // input types and window attributes.
4695   auto inferredReturnTypes = inferReduceWindowOpReturnType(
4696       inputTypes, accumulatorSubshapes, *windowOrErr);
4697 
4698   for (size_t i = 0; i < getNumResults(); i++) {
4699     if (failed(verifyCompatibleShape(resultTensorTypes[i],
4700                                      inferredReturnTypes[i]))) {
4701       return emitOpError()
4702              << "expects result at index " << i
4703              << " to have compatible shape with the corresponding "
4704                 "inferred type, but got "
4705              << resultTensorTypes[i] << " and " << inferredReturnTypes[i]
4706              << " resp.";
4707     }
4708   }
4709 
4710   return success();
4711 }
4712 
4713 // Gets the operation used for the reduction applied to the `result_index`th
4714 // result. It is expected to be a binary operation that consumes the
4715 // `result_index`th and `result_index + operands().size`th body arguments.
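// For example (illustrative), with a single input whose body computes
// "mhlo.add"(%arg0, %arg1) and returns it, getReductionOp(0) returns that
// add operation.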
4716 Operation* ReduceWindowOp::getReductionOp(int resultIndex) {
4717   auto returnOp = cast<ReturnOp>(body().front().getTerminator());
4718   Operation* computeOp = returnOp.results()[resultIndex].getDefiningOp();
4719   if (computeOp->getNumOperands() != 2) return nullptr;
4720   auto arg0 = computeOp->getOperand(0).dyn_cast<BlockArgument>();
4721   auto arg1 = computeOp->getOperand(1).dyn_cast<BlockArgument>();
4722   if (!arg0 || !arg1) return nullptr;
4723   int64_t arg0Num = arg0.getArgNumber();
4724   int64_t arg1Num = arg1.getArgNumber();
4725   int64_t otherArgIndex = resultIndex + operands().size();
4726   if (arg0Num == resultIndex && arg1Num == otherArgIndex) return computeOp;
4727   if (arg0Num == otherArgIndex && arg1Num == resultIndex &&
4728       computeOp->hasTrait<mlir::OpTrait::IsCommutative>())
4729     return computeOp;
4730   return nullptr;
4731 }
4732 
4733 //===----------------------------------------------------------------------===//
4734 // ReducePrecisionOp
4735 //===----------------------------------------------------------------------===//
4736 
4737 // The following property is already enforced by the ODS:
4738 //  P0. operand element type is float
4739 //  P1. mantissa_bits >= 0
4740 // We intend to verify the following properties
4741 //  P2. exponent_bits >= 1
4742 LogicalResult ReducePrecisionOp::verify() {
4743   if (exponent_bits() < 1) {
4744     return emitOpError() << "exponent_bits must be at least 1.";
4745   }
4746   return success();
4747 }
4748 
4749 //===----------------------------------------------------------------------===//
4750 // ReverseOp
4751 //===----------------------------------------------------------------------===//
4752 
4753 template <typename T>
4754 static Attribute foldReverseHelper(DenseElementsAttr& attr, ShapedType& type,
4755                                    DenseIntElementsAttr& dims) {
4756   int64_t numElements = attr.getNumElements();
4757   // No-op if the tensor has 0 elements.
4758   // No-op if the result of folding is too large.
4759   if (numElements == 0 || numElements > kFoldOpEltLimit) return {};
4760 
4761   SmallVector<T> result(attr.getValues<T>().begin(), attr.getValues<T>().end());
4762 
4763   size_t rank = type.getRank();
4764   SmallVector<int64_t> stride(rank + 1, numElements);
4765   for (size_t i = 0; i < rank; i++) {
4766     if (type.getDimSize(i) == 0) return {};
4767     stride[i + 1] = stride[i] / type.getDimSize(i);
4768   }
4769 
4770   for (auto dim : dims.getValues<int64_t>()) {
4771     // For example, given:
4772     //   * tensor: tensor<2x3x2xi32>
4773     //     [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9,10], [11, 12]]]
4774     //   * dim: [1]
4775     //
4776     // We're going to reverse the tensor with respect to dim as follows:
4777     //   1) Split the tensor into blocks, i.e. smaller tensors whose type is
4778     //   derived from the tensor by dropping the first `dim` dimensions, i.e.
4779     //   tensor<3x2xi32> for the running example.
4780     //   2) Split each block into windows, i.e. even smaller tensors whose type
4781     //   is derived from the block by dropping the first dimension of the
4782     //   block, i.e. tensor<2xi32> for the running example.
4783     //   3) Within each block, swap windows but don't change the order of
4784     //   elements within the windows: 0th window goes to N-1st spot, 1st window
4785     //   goes to N-2nd spot etc.
4786     //
4787     // For the running example, the result will be:
4788     //   [[[5, 6], [3, 4], [1, 2]], [[11, 12], [9, 10], [7, 8]]].
4789     //
4790     // Note how elements within windows haven't changed their order with respect
4791     // to each other and how blocks haven't changed their order with respect to
4792     // each other.
4793     int64_t numWindows = type.getDimSize(dim);
4794     int64_t windowSize = stride[dim] / numWindows;
4795 
4796     for (int64_t index = 0; index < numElements; index++) {
4797       int64_t blockNumber = index / stride[dim];
4798       int64_t windowNumber = (index % stride[dim]) / windowSize;
4799       int64_t reversedWindowNumber = numWindows - windowNumber - 1;
4800       if (windowNumber >= reversedWindowNumber) continue;
4801       int64_t reversedIndex = blockNumber * stride[dim] +
4802                               reversedWindowNumber * windowSize +
4803                               index % windowSize;
4804       std::swap(result[index], result[reversedIndex]);
4805     }
4806   }
4807   return DenseElementsAttr::get(type, result);
4808 }
4809 
4810 OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
4811   Value input = operand();
4812 
4813   // No dimensions to reverse.
4814   DenseIntElementsAttr dims = dimensions();
4815   if (dims.getNumElements() == 0) return input;
4816 
4817   // If the size of every dimension to reverse equals 1, the reverse is a
4818   // no-op, e.g. reversing dimensions {0,1} of a 1x1x2 tensor.
4819   auto shapedType = input.getType().cast<ShapedType>();
4820   if (llvm::all_of(dims.getValues<int64_t>(), [&](int64_t dim) {
4821         return shapedType.getDimSize(dim) == 1;
4822       }))
4823     return input;
4824 
4825   // If the operand is a statically shaped tensor of constants, return the
4826   // reversed tensor.
4827   DenseElementsAttr inputAttr =
4828       operands.begin()->dyn_cast_or_null<DenseElementsAttr>();
4829   if (inputAttr && shapedType.hasStaticShape()) {
4830     auto etype = shapedType.getElementType();
4831     if (etype.isa<IntegerType>())
4832       return foldReverseHelper<APInt>(inputAttr, shapedType, dims);
4833     if (etype.isa<FloatType>())
4834       return foldReverseHelper<APFloat>(inputAttr, shapedType, dims);
4835   }
4836 
4837   return {};
4838 }
4839 
4840 //===----------------------------------------------------------------------===//
4841 // ReduceOp
4842 //===----------------------------------------------------------------------===//
4843 
4844 // Returns the result type after reducing operand of the given type across the
4845 // specified dimensions.
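// E.g. reducing a tensor<4x5x6xf32> operand across dimensions = [1] drops the
// second dimension and yields tensor<4x6xf32>; an unranked operand yields an
// unranked result with the same element type.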
4846 static TensorType getReduceResultType(Type operandTy,
4847                                       DenseIntElementsAttr dimensions,
4848                                       Builder* builder) {
4849   Type elementTy = getElementTypeOrSelf(operandTy);
4850 
4851   auto rankedTy = operandTy.dyn_cast<RankedTensorType>();
4852   if (!rankedTy) return UnrankedTensorType::get(elementTy);
4853 
4854   int64_t rank = rankedTy.getRank();
4855   llvm::SmallVector<bool, 4> dimsMask(rank, false);
4856   for (int64_t dim : dimensions.getValues<int64_t>()) dimsMask[dim] = true;
4857 
4858   SmallVector<int64_t, 4> shape;
4859   for (int64_t i = 0; i < rank; ++i) {
4860     if (!dimsMask[i]) shape.push_back(rankedTy.getDimSize(i));
4861   }
4862 
4863   return RankedTensorType::get(shape, elementTy);
4864 }
4865 
4866 void ReduceOp::build(OpBuilder& builder, OperationState& state,
4867                      ValueRange operands, ValueRange initValues,
4868                      DenseIntElementsAttr dimensions) {
4869   SmallVector<Type, 1> resultTy;
4870   resultTy.reserve(operands.size());
4871 
4872   for (Value operand : operands) {
4873     resultTy.push_back(
4874         getReduceResultType(operand.getType(), dimensions, &builder));
4875   }
4876   build(builder, state, resultTy, operands, initValues, dimensions);
4877 }
4878 
4879 LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
4880                              SmallVectorImpl<OpFoldResult>& results) {
4881   // No dimensions to reduce.
4882   if (dimensions().getNumElements() == 0) {
4883     for (Value operand : this->operands()) {
4884       results.push_back(operand);
4885     }
4886     return success();
4887   }
4888 
4889   // If all values returned from the ReduceOp region are defined outside
4890   // the region, replace the ReduceOp with those values.
4891   mlir::Block& bb = this->body().front();
4892   SmallVector<Value> replacedResults;
4893   if (auto retOp = mlir::dyn_cast<ReturnOp>(bb.back())) {
4894     for (Value result : retOp.results()) {
4895       if (result.getParentRegion() == retOp->getParentRegion())
4896         return failure();
4897       replacedResults.push_back(result);
4898     }
4899 
4900     results.insert(results.end(), replacedResults.begin(),
4901                    replacedResults.end());
4902     return success();
4903   }
4904 
4905   return failure();
4906 }
4907 
4908 bool hasSameOperandAndResultTypes(Operation& op) {
4909   Type expected;
4910   if (op.getNumResults() != 0) expected = op.getResult(0).getType();
4911   if (op.getNumOperands() != 0) expected = op.getOperand(0).getType();
4912   if (!expected) return false;
4913 
4914   auto typeMatch = [&](Type actual) { return actual == expected; };
4915   return llvm::all_of(op.getOperandTypes(), typeMatch) &&
4916          llvm::all_of(op.getResultTypes(), typeMatch);
4917 }
4918 
4919 // Checks the following eligibility criteria for compact printing of
4920 // mhlo.reduce (see the example sketched below this list):
4921 // E1. The reduce-op wraps a single inner-op in the associated region.
4922 // E2. The single operation is a commutative binary-op from the mhlo dialect,
4923 //     has zero regions, and produces a single result such that the operands
4924 //     and result all have the same type.
4925 // E3. The reduce-op consists of at least one input-operand; the operand-types
4926 //     of the inner-op should be derived trivially from the element-type of
4927 //     the reduce-op's first input-operand.
4928 // E4. The arguments of the region's only basic block are forwarded perfectly
4929 //     to the inner-op's operands.
4930 // E5. The reduce-op, inner-op, block arguments, and the return-op all have
4931 //     the same location.
4932 // E6. The single operation's result is perfectly forwarded to the reduce-op's
4933 //     return.
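//
// A sketch of a reduce-op meeting these criteria, and its compact form
// (schematic syntax, not copied from a test):
//
//   %0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.add across
//            dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>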
4934 static bool isEligibleForCompactPrint(ReduceOp op) {
4935   // Check E1.
4936   auto& block = op.body().front();
4937   if (!hasSingleElement(block.without_terminator())) return false;
4938 
4939   Operation& innerOp = *block.begin();
4940 
4941   // Check E2.
4942   if (innerOp.getDialect() != op->getDialect()) return false;
4943 
4944   if (innerOp.getNumOperands() != 2 ||
4945       !innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
4946       !hasSameOperandAndResultTypes(innerOp) ||
4947       !innerOp.hasTrait<mlir::OpTrait::IsCommutative>() ||
4948       !innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
4949     return false;
4950 
4951   // Check E3.
4952   if (op.operands().empty()) return false;
4953 
4954   auto elemType =
4955       op.operands()[0].getType().cast<TensorType>().getElementType();
4956   auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
4957   if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false;
4958 
4959   // Check E4.
4960   if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false;
4961 
4962   // Check E5.
4963   auto retOp = dyn_cast<ReturnOp>(block.getTerminator());
4964   if (!retOp) return false;
4965 
4966   auto blockArgLoc = block.getArgument(0).getLoc();
4967   if (blockArgLoc != block.getArgument(1).getLoc()) return false;
4968 
4969   if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() ||
4970       blockArgLoc != op.getLoc())
4971     return false;
4972 
4973   // Check E6.
4974   return llvm::equal(innerOp.getResults(), retOp.getOperands());
4975 }
4976 
4977 void ReduceOp::print(OpAsmPrinter& p) {
4978   {
4979     // Print the pairs of operands under the form:
4980     //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
4981     StringRef comma = "";
4982     int numOperandPairs = getNumOperands() / 2;
4983     for (int opId : llvm::seq<int>(0, numOperandPairs)) {
4984       p << comma << "(" << getOperand(opId)
4985         << " init: " << getOperand(opId + numOperandPairs) << ")";
4986       comma = ", ";
4987     }
4988   }
4989 
4990   // If the reduce-op is eligible for compact printing, we emit the one-liner:
4991   //  mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
4992   // Note: We are not printing the function type of the inner reduction
4993   // operation; some simplifying assumptions (refer to
4994   // isEligibleForCompactPrint::E3) let us derive it from the reduce-op's type.
4995   if (isEligibleForCompactPrint(*this)) {
4996     Operation& innerOp = body().front().front();
4997     p << " applies ";
4998     printEscapedString(innerOp.getName().getStringRef(), p.getStream());
4999 
5000     p << " across dimensions = [";
5001     llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
5002     p << "]";
5003     p << " : ";
5004     p.printFunctionalType(*this);
5005   } else {
5006     p << " across dimensions = [";
5007     llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
5008     p << "]";
5009     p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"});
5010     p << " : ";
5011     p.printFunctionalType(*this);
5012     p.printNewline();
5013     p << " reducer";
5014     {
5015       // Print the pairs of block operands under the form:
5016       //   (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc):
5017       Block& reducer = body().front();
5018       int numOperandPairs = getNumOperands() / 2;
5019       for (int opId : llvm::seq<int>(0, numOperandPairs)) {
5020         p << "(";
5021         p.printRegionArgument(reducer.getArgument(opId));
5022         p << ", ";
5023         p.printRegionArgument(reducer.getArgument(opId + numOperandPairs));
5024         p << ") ";
5025       }
5026     }
5027     p << ' ';
5028     p.printRegion(body(), /*printEntryBlockArgs=*/false);
5029   }
5030 }
5031 
5032 ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) {
5033   llvm::SMLoc loc = parser.getCurrentLocation();
5034   Location currLocation = parser.getEncodedSourceLoc(loc);
5035 
5036   // Parse the operands of the reduce-op; this is a list of pairs of the form:
5037   //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
5038   // Each input to reduce is paired with its init value, even though in memory
5039   // they are stored with the input first and the init values after.
5040   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
5041   SmallVector<OpAsmParser::UnresolvedOperand, 2> initOperands;
5042   do {
5043     (void)parser.parseOptionalComma();
5044     if (parser.parseOptionalLParen()) break;
5045     OpAsmParser::UnresolvedOperand operand, initOperand;
5046     if (parser.parseOperand(operand) || parser.parseKeyword("init") ||
5047         parser.parseColon() || parser.parseOperand(initOperand) ||
5048         parser.parseRParen())
5049       return failure();
5050     operands.push_back(operand);
5051     initOperands.push_back(initOperand);
5052   } while (true);
5053   operands.append(initOperands);
5054 
5055   // Check if we are parsing the compact version of reduce-op:
5056   //  mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
5057   // else parse the "region-based" variant.
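  // For reference, the region-based variant looks roughly like (a sketch):
  //
  //   mhlo.reduce(%arg0 init: %arg1) across dimensions = [1]
  //       : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
  //    reducer(%lhs: tensor<f32>, %rhs: tensor<f32>)  {
  //     %0 = mhlo.add %lhs, %rhs : tensor<f32>
  //     "mhlo.return"(%0) : (tensor<f32>) -> ()
  //    }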
5058   if (failed(parser.parseOptionalKeyword("applies"))) {
5059     // Parse the inner-op dimensions, reduce-op's function-type and
5060     // optional location.
5061     SmallVector<int64_t> dimensions;
5062     auto parseDim = [&]() -> ParseResult {
5063       if (parser.parseInteger(dimensions.emplace_back())) return failure();
5064       return success();
5065     };
5066 
5067     FunctionType reduceOpFntype;
5068     if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
5069         parser.parseEqual() ||
5070         parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
5071                                        parseDim) ||
5072         parser.parseOptionalAttrDict(result.attributes) ||
5073         parser.parseColon() || parser.parseType(reduceOpFntype) ||
5074         parser.parseKeyword("reducer"))
5075       return failure();
5076     OpBuilder builder(parser.getBuilder().getContext());
5077     result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));
5078 
5079     // Parse the "reducer" region now.
5080     SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerOperands;
5081     SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerInitOperands;
5082     SmallVector<Type, 2> reducerTypes;
5083     SmallVector<Type, 2> reducerInitTypes;
5084     SmallVector<Optional<Location>, 2> reducerLocs;
5085     SmallVector<Optional<Location>, 2> reducerInitLocs;
5086     auto parseBlockOperand =
5087         [&](SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands,
5088             SmallVectorImpl<Type>& types,
5089             SmallVectorImpl<Optional<Location>>& locs) -> ParseResult {
5090       OpAsmParser::UnresolvedOperand operand;
5091       Type type;
5092       Optional<Location> loc;
5093       if (parser.parseOperand(operand, /*allowResultNumber=*/false) ||
5094           parser.parseColon() || parser.parseType(type) ||
5095           parser.parseOptionalLocationSpecifier(loc))
5096         return failure();
5097       operands.push_back(operand);
5098       types.push_back(type);
5099       locs.push_back(loc);
5100       return success();
5101     };
5102     do {
5103       if (failed(parser.parseOptionalLParen())) break;
5104       if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) ||
5105           parser.parseComma() ||
5106           parseBlockOperand(reducerInitOperands, reducerInitTypes,
5107                             reducerInitLocs) ||
5108           parser.parseRParen())
5109         return failure();
5110     } while (true);
5111     reducerOperands.append(reducerInitOperands);
5112     reducerTypes.append(reducerInitTypes);
5113     reducerLocs.append(reducerInitLocs);
5114     result.addTypes(reduceOpFntype.getResults());
5115     SmallVector<OpAsmParser::Argument> reducerArgs;
5116     createArgs(reducerOperands, reducerTypes, reducerArgs);
5117 
5118     // Derive the SSA-values for reduce-op's operands and parse the region, and
5119     // the optional trailing location.
5120     Optional<Location> trailingLoc;
5121     if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
5122                                result.operands) ||
5123         parser.parseRegion(*result.addRegion(), reducerArgs) ||
5124         parser.parseOptionalLocationSpecifier(trailingLoc)) return failure();
5125     // Set the individual block arguments.
5126     for (auto argAndLoc :
5127          llvm::zip(result.regions.front()->front().getArguments(), reducerLocs))
5128       if (std::get<1>(argAndLoc))
5129         std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value());
5130     result.location = trailingLoc.value_or(currLocation);
5131     return success();
5132   }
5133 
5134   // Parse the inner-op name and check if the contract on the inner-op
5135   // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met.
5136   FailureOr<OperationName> innerOpNameInfo = parser.parseCustomOperationName();
5137   if (failed(innerOpNameInfo)) return failure();
5138 
5139   StringRef innerOpName = innerOpNameInfo->getStringRef();
5140   Dialect* innerOpDialect = innerOpNameInfo->getDialect();
5141   if (!innerOpDialect || !innerOpDialect->getNamespace().equals("mhlo") ||
5142       !innerOpNameInfo->hasTrait<mlir::OpTrait::NOperands<2>::Impl>() ||
5143       !innerOpNameInfo->hasTrait<mlir::OpTrait::OneResult>() ||
5144       !innerOpNameInfo->hasTrait<mlir::OpTrait::IsCommutative>() ||
5145       !innerOpNameInfo->hasTrait<mlir::OpTrait::ZeroRegions>()) {
5146     parser.emitError(loc,
5147                      "expected the inner-op to be a commutative binary-op from "
5148                      "mhlo dialect, zero region, producing single result");
5149     return failure();
5150   }
5151 
5152   // Parse the inner-op dimensions, reduce-op's function-type and
5153   // optional location.
5154   SmallVector<int64_t> dimensions;
5155   auto parseDim = [&]() -> ParseResult {
5156     if (parser.parseInteger(dimensions.emplace_back())) return failure();
5157     return success();
5158   };
5159 
5160   Optional<Location> explicitLoc;
5161   FunctionType reduceOpFntype;
5162   if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
5163       parser.parseEqual() ||
5164       parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
5165       parser.parseColon() || parser.parseType(reduceOpFntype) ||
5166       parser.parseOptionalLocationSpecifier(explicitLoc))
5167     return failure();
5168 
5169   if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) {
5170     if (!reduceOpFntype) return parser.emitError(loc, "expected function type");
5171     return parser.emitError(loc,
5172                             "input types missing in reduce-op function type");
5173   }
5174 
5175   // If the location of the reduce-op is explicitly provided, use it; else
5176   // use the parser's current location.
5177   Location reduceOpLoc = explicitLoc.value_or(currLocation);
5178 
5179   // Derive the SSA-values for reduce-op's operands.
5180   if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
5181                              result.operands))
5182     return failure();
5183 
5184   // Derive the type of inner-op from that of reduce-op's input operand.
5185   auto innerOpType = RankedTensorType::get(
5186       /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0)));
5187 
5188   // Add a region for reduce-op.
5189   Region& region = *result.addRegion();
5190 
5191   // Create a basic-block inside reduce-op's region.
5192   Block& block = region.emplaceBlock();
5193   auto lhs = block.addArgument(innerOpType, reduceOpLoc);
5194   auto rhs = block.addArgument(innerOpType, reduceOpLoc);
5195 
5196   // Create and insert an "inner-op" operation in the block.
5197   OpBuilder builder(parser.getBuilder().getContext());
5198   builder.setInsertionPointToStart(&block);
5199 
5200   OperationState innerOpState(reduceOpLoc, innerOpName);
5201   innerOpState.operands.push_back(lhs);
5202   innerOpState.operands.push_back(rhs);
5203   innerOpState.addTypes(innerOpType);
5204 
5205   Operation* innerOp = builder.create(innerOpState);
5206 
5207   // Insert a return statement in the block returning the inner-op's result.
5208   builder.create<ReturnOp>(innerOp->getLoc(), innerOp->getResults());
5209 
5210   // Populate the reduce-op operation-state with result-type, location, and
5211   // dimension attribute.
5212   result.addTypes(reduceOpFntype.getResults());
5213   result.location = innerOp->getLoc();
5214   result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));
5215 
5216   return success();
5217 }
5218 
5219 LogicalResult ReduceOp::verify() {
5220   // Check that the number of operands is even and >= 2.
5221   if (getNumOperands() % 2 != 0 || getOperands().empty())
5222     return emitOpError() << "expects the size of operands to be even and >= 2";
5223 
5224   // Collect the input and init-value operands. Note that the operand-type is
5225   // enforced as "TensorType" by ODS.
5226   int64_t numInputs = getNumOperands() / 2;
5227   auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
5228       getOperandTypes(),
5229       [](Type t) -> TensorType { return t.cast<TensorType>(); }));
5230   ArrayRef<TensorType> inputArgTypes(operandTensorTypes.begin(),
5231                                      operandTensorTypes.begin() + numInputs);
5232   ArrayRef<TensorType> initValueTypes(operandTensorTypes.begin() + numInputs,
5233                                       operandTensorTypes.end());
5234 
5235   // Check for unranked tensors in input operands.
5236   int64_t rankedInputIdx = -1;
5237   for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
5238     if (inputArgTypes[inputIdx].hasRank()) {
5239       rankedInputIdx = inputIdx;
5240       break;
5241     }
5242   }
5243 
5244   bool allInputsUnranked = (rankedInputIdx == -1);
5245 
5246   // Check that all input operands have compatible shapes. The element types may
5247   // be different.
5248   if (!allInputsUnranked) {
5249     for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
5250       if (failed(mlir::verifyCompatibleShape(inputArgTypes[rankedInputIdx],
5251                                              inputArgTypes[inputIdx]))) {
5252         return emitOpError()
5253                << "expects all inputs to have compatible shapes. Shape at"
5254                << " input-index " << inputIdx
5255                << " is not compatible with shape at input-index "
5256                << rankedInputIdx;
5257       }
5258     }
5259   }
5260 
5261   // Check that
5262   //   1. the dimensions of the reduce-op are in-bounds for the given shape, and
5263   //   2. the dimensions attribute has no duplicate entries (examples below).
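  // E.g. for a rank-3 ranked input, dimensions = [0, 3] violates check 1
  // (3 is out of bounds) and dimensions = [1, 1] violates check 2 (duplicate).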
5264   DenseSet<int64_t> dimensionsToReduceSet;
5265   for (int64_t dimension : dimensions().getValues<int64_t>()) {
5266     if ((!allInputsUnranked &&
5267          dimension >= inputArgTypes[rankedInputIdx].getRank()) ||
5268         dimension < 0) {
5269       // Guard the access: rankedInputIdx is -1 when all inputs are unranked.
5270       return emitError() << "Out-of-bounds dimension " << dimension
5271                          << " for input-tensor rank: " << (allInputsUnranked ? 0 : inputArgTypes[rankedInputIdx].getRank());
5272     }
5273 
5274     if (!dimensionsToReduceSet.insert(dimension).second) {
5275       return emitError() << "Duplicate reduction dimension: " << dimension;
5276     }
5277   }
5278 
5279   // Verify the inner block defining the reducer function.
5280   SmallVector<int64_t> newDimensions;
5281   if (!allInputsUnranked) {
5282     for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank();
5283          ++inputIdx) {
5284       if (!dimensionsToReduceSet.count(inputIdx)) {
5285         newDimensions.push_back(
5286             inputArgTypes[rankedInputIdx].getDimSize(inputIdx));
5287       }
5288     }
5289   }
5290 
5291   Block& block = body().front();
5292   SmallVector<TensorType> accumulatorSubShapes;
5293   if (failed(verifyReducerShape(this->getLoc(), block, inputArgTypes,
5294                                 initValueTypes, numInputs, newDimensions,
5295                                 allInputsUnranked, accumulatorSubShapes)))
5296     return failure();
5297 
5298   // Check if the reduce-op's result-type matches the one derived from
5299   // the reducer-block and dimensions attribute.
5300   if (getResults().size() != accumulatorSubShapes.size())
5301     return emitError() << "Unexpected number of reduce-op's returned values: "
5302                        << getResults().size() << " vs "
5303                        << accumulatorSubShapes.size() << " (expected)";
5304 
5305   for (int64_t shapeIdx = 0;
5306        shapeIdx < static_cast<int64_t>(accumulatorSubShapes.size());
5307        shapeIdx++) {
5308     // The result-type is enforced as "TensorType" by ODS.
5309     auto opResultType = getResult(shapeIdx).getType().cast<TensorType>();
5310 
5311     // Check element-type.
5312     if (accumulatorSubShapes[shapeIdx].getElementType() !=
5313         opResultType.getElementType()) {
5314       return emitError()
5315              << "Unexpected element-type for reduce-op's return value at index "
5316              << shapeIdx << ": " << opResultType.getElementType() << " vs "
5317              << accumulatorSubShapes[shapeIdx].getElementType()
5318              << " (expected)";
5319     }
5320 
5321     // Check shape.
5322     if (!allInputsUnranked && opResultType.hasRank() &&
5323         failed(verifyCompatibleShape(newDimensions, opResultType.getShape()))) {
5324       Type expectedResultType = RankedTensorType::get(
5325           newDimensions, accumulatorSubShapes[shapeIdx].getElementType());
5326       return emitError()
5327              << "Unexpected type for reduce-op's return value at index "
5328              << shapeIdx << ": " << opResultType << " vs " << expectedResultType
5329              << " (expected)";
5330     }
5331   }
5332 
5333   return success();
5334 }
5335 
5336 // Enable constant folding to occur within the region of the ReduceOp
5337 // by replacing block argument uses with constants if:
5338 //  1. All the ReduceOp operands are splat constants.
5339 //  2. The ReduceOp region consists of a single logical AND or logical OR.
5340 // The pattern leverages the idempotent property of the AND and OR operators
5341 // to determine the value of a reduction on splat constants. Other boolean
5342 // operators do not have this property, and need separate patterns to resolve
5343 // reductions of their splat constants.
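// A sketch of the rewrite (schematic syntax): given splat boolean inputs,
//
//   %t = mhlo.constant dense<true> : tensor<4xi1>
//   %f = mhlo.constant dense<false> : tensor<i1>
//   ... = mhlo.reduce(%t init: %f) applies mhlo.or across dimensions = [0] ...
//
// each block argument is replaced with a splat constant of the argument's own
// type, after which the inner mhlo.or can fold entirely within the region.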
5344 struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> {
5345   using OpRewritePattern<ReduceOp>::OpRewritePattern;
5346 
5347   LogicalResult matchAndRewrite(ReduceOp op,
5348                                 PatternRewriter& rewriter) const override {
5349     mlir::Block& bb = op.body().front();
5350 
5351     // Ensure only a compute op and return op exist and the
5352     // compute op is an AND or OR op.
5353     if (bb.getOperations().size() != 2) return failure();
5354     if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure();
5355 
5356     // Ensure all operands are splat constants.
5357     SmallVector<DenseElementsAttr, 4> bargCstAttrs;
5358     for (auto inpAndBarg : llvm::zip(op.getOperands(), bb.getArguments())) {
5359       Value inp = std::get<0>(inpAndBarg);
5360       BlockArgument barg = std::get<1>(inpAndBarg);
5361       ConstantOp cst = inp.getDefiningOp<ConstantOp>();
5362       if (!cst) return failure();
5363 
5364       auto cstAttr = cst.value().dyn_cast_or_null<DenseElementsAttr>();
5365       if (!cstAttr || !cstAttr.isSplat()) {
5366         return rewriter.notifyMatchFailure(op, "Must be splat constant.");
5367       }
5368 
5369       auto bargShapedType = barg.getType().dyn_cast<ShapedType>();
5370       if (!bargShapedType) return failure();
5371 
5372       auto bargCstAttr = DenseElementsAttr::get(
5373           bargShapedType, cstAttr.getSplatValue<mlir::Attribute>());
5374       bargCstAttrs.push_back(bargCstAttr);
5375     }
5376 
5377     // Create new splat constants to replace block arguments.
5378     for (BlockArgument barg : bb.getArguments()) {
5379       int argIdx = barg.getArgNumber();
5380       mhlo::ConstantOp newCst = rewriter.create<mhlo::ConstantOp>(
5381           bb.front().getLoc(), barg.getType(), bargCstAttrs[argIdx]);
5382       barg.replaceAllUsesWith(newCst);
5383     }
5384     return success();
5385   }
5386 };
5387 
5388 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet& results,
5389                                            MLIRContext* context) {
5390   results.add<LowerBoolSplatConstantsIntoRegion>(context);
5391 }
5392 
5393 LogicalResult ReduceOp::reifyReturnTypeShapes(
5394     OpBuilder& builder, ValueRange operands,
5395     SmallVectorImpl<Value>& reifiedReturnShapes) {
5396   ReduceOp::Adaptor adaptor(operands);
5397   auto inputs = adaptor.operands();
5398 
5399   auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
5400   // Unranked types are not supported at the moment.
5401   if (!operandType) return failure();
5402 
5403   Location loc = this->getLoc();
5404   SmallVector<Value, 4> shapeValues;
5405   SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
5406   shapeValues.reserve(operandType.getRank());
5407   Type shapeScalarType = builder.getIndexType();
5408   auto toShapeScalarType = [&](Value v) {
5409     return maybeCastTo(builder, loc, v, shapeScalarType);
5410   };
5411 
5412   for (const auto& element : llvm::enumerate(operandType.getShape())) {
5413     int64_t idx = element.index();
5414     auto* it = std::find(dimensions.begin(), dimensions.end(), idx);
5415     if (it != dimensions.end()) {
5416       continue;
5417     }
5418     Value valueDim = toShapeScalarType(
5419         builder.create<tensor::DimOp>(loc, inputs[0], element.index()));
5420     shapeValues.push_back(valueDim);
5421   }
5422 
5423   Value outputShape = builder.create<tensor::FromElementsOp>(
5424       loc,
5425       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
5426                             shapeScalarType),
5427       shapeValues);
5428   for (size_t i = 0; i < inputs.size(); ++i) {
5429     reifiedReturnShapes.push_back(outputShape);
5430   }
5431 
5432   return success();
5433 }
5434 
5435 //===----------------------------------------------------------------------===//
5436 // RngBitGeneratorOp
5437 //===----------------------------------------------------------------------===//
5438 
5439 // Verify that the input state has the same shape as the output state.
5440 LogicalResult RngBitGeneratorOp::verify() {
5441   auto initialShape = initial_state().getType().dyn_cast<RankedTensorType>();
5442   auto outputShape = output_state().getType().dyn_cast<RankedTensorType>();
5443   if (initialShape && outputShape &&
5444       initialShape.getShape() != outputShape.getShape())
5445     return emitOpError() << "output state shape must match initial state "
5446                             "shape. Got: " << initialShape << " and " << outputShape;
5447   return success();
5448 }
5449 
5450 //===----------------------------------------------------------------------===//
5451 // RngOp
5452 //===----------------------------------------------------------------------===//
5453 
5454 LogicalResult RngOp::verify() {
5455   auto dist = rng_distribution();
5456   if (dist == RngDistribution::UNIFORM) {
5457     return success();
5458   }
5459   auto muTy = a().getType().cast<TensorType>().getElementType();
5460   auto sigmaTy = b().getType().cast<TensorType>().getElementType();
5461   if (muTy.isa<FloatType>() && sigmaTy.isa<FloatType>()) {
5462     return success();
5463   }
5464   return emitOpError() << "mu and sigma must be floats";
5465 }
5466 
5467 LogicalResult RngOp::inferReturnTypeComponents(
5468     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
5469     DictionaryAttr attributes, RegionRange regions,
5470     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5471   return rngInferReturnTypeComponents(context, location, operands, attributes,
5472                                       regions, inferredReturnShapes);
5473 }
5474 
5475 LogicalResult RngOp::reifyReturnTypeShapes(
5476     OpBuilder& builder, ValueRange operands,
5477     SmallVectorImpl<Value>& reifiedReturnShapes) {
5478   RngOp::Adaptor adaptor(operands);
5479   reifiedReturnShapes.push_back(
5480       castToIndexTensor(builder, getLoc(), adaptor.shape()));
5481   return success();
5482 }
5483 
5484 //===----------------------------------------------------------------------===//
5485 // XlaRngGetAndUpdateStateOp
5486 //===----------------------------------------------------------------------===//
5487 
5488 LogicalResult XlaRngGetAndUpdateStateOp::verify() {
5489   auto resultTy = getType().dyn_cast<RankedTensorType>();
5490   if (!resultTy) return emitOpError() << "Output is not ranked.";
5491   if (!resultTy.hasStaticShape())
5492     return emitOpError() << "Output is not statically shaped.";
5493   auto rank = resultTy.getRank();
5494   if (rank != 1)
5495     return emitOpError() << "Output is of rank " << rank << " instead of 1";
5496   auto extent = resultTy.getDimSize(0);
5497   if (extent != 2)
5498     return emitOpError() << "Output size is " << extent << " instead of 2";
5499 
5500   return success();
5501 }
5502 
5503 LogicalResult XlaRngGetAndUpdateStateOp::inferReturnTypes(
5504     MLIRContext* ctx, Optional<Location>, ValueRange, DictionaryAttr,
5505     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
5506   inferredReturnTypes.push_back(mlir::RankedTensorType::get(
5507       {2}, mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned)));
5508   return success();
5509 }
5510 
5511 //===----------------------------------------------------------------------===//
5512 // SelectOp
5513 //===----------------------------------------------------------------------===//
5514 
5515 LogicalResult SelectOp::verify() {
5516   // The operands 'on_true' and 'on_false' should have compatible types, i.e.,
5517   //   (a) have the same element type, and
5518   //   (b) have compatible shapes (i.e. the same shape and/or at least one
5519   //       dynamic shape)
5520   if (!compatibleShapeAndElementType(on_true().getType(), on_false().getType()))
5521     return emitOpError()
5522            << "requires compatible types for non-predicate operands";
5523 
5524   // The predicate, if not-scalar, should have the same shape as the remaining
5525   // operands.
5526   auto predTy = pred().getType().dyn_cast<RankedTensorType>();
5527   bool predMayBeScalar = !predTy || predTy.getRank() == 0;
5528   if (predMayBeScalar) return success();
5529 
5530   if (failed(verifyCompatibleShape(pred().getType(), on_true().getType())))
5531     return emitOpError() << "requires the same shape for all operands";
5532 
5533   return success();
5534 }
5535 
5536 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
5537   if (on_true() == on_false()) {
5538     return on_true();
5539   }
5540 
5541   auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
5542   if (!predicate) {
5543     return {};
5544   }
5545 
5546   auto predicateTy = predicate.getType().cast<ShapedType>();
5547   if (!predicateTy.getElementType().isInteger(1)) {
5548     return {};
5549   }
5550 
5551   if (predicate.isSplat()) {
5552     return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
5553                                                            : on_false();
5554   }
5555 
5556   return {};
5557 }
5558 
5559 // Simplify select(not(%pred), true_value, false_value) => select(%pred,
5560 // false_value, true_value)
5561 static LogicalResult selectCanonicalization(SelectOp selectOp,
5562                                             PatternRewriter& rewriter) {
5563   auto notOp = selectOp.pred().getDefiningOp<NotOp>();
5564   if (!notOp) {
5565     return failure();
5566   }
5567   std::array<Value, 3> newOperands = {notOp.operand(), selectOp.on_false(),
5568                                       selectOp.on_true()};
5569   rewriter.updateRootInPlace(
5570       selectOp, [&]() { selectOp.getOperation()->setOperands(newOperands); });
5571   return success();
5572 }
5573 
5574 void SelectOp::getCanonicalizationPatterns(RewritePatternSet& results,
5575                                            MLIRContext* /*context*/) {
5576   results.add(&selectCanonicalization);
5577 }
5578 
5579 // Makes it such that a SelectOp that is a non-root operation in a DRR infers
5580 // the return type based on operand type.
5581 LogicalResult SelectOp::inferReturnTypeComponents(
5582     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
5583     DictionaryAttr attributes, RegionRange,
5584     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5585   SelectOp::Adaptor op(operands, attributes);
5586   auto trueType = op.on_true().getType().cast<TensorType>();
5587   auto falseType = op.on_false().getType().cast<TensorType>();
5588 
5589   // The output shape should be the most general of the operand shapes at each
5590   // dimension.
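  // E.g. joining on_true : tensor<2x?xf32> with on_false : tensor<?x3xf32>
  // infers tensor<?x?xf32>, since neither dimension matches statically.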
5591   ShapedTypeComponents& outputType = inferredReturnShapes.emplace_back();
5592   if (trueType == falseType || !trueType.hasRank()) {
5593     outputType = ShapedTypeComponents(trueType.cast<ShapedType>());
5594   } else if (!falseType.hasRank()) {
5595     outputType = ShapedTypeComponents(falseType.cast<ShapedType>());
5596   } else {
5597     assert(trueType.getRank() == falseType.getRank());
5598     llvm::SmallVector<int64_t, 4> dims;
5599     dims.reserve(trueType.getRank());
5600     for (auto dim : llvm::zip(trueType.getShape(), falseType.getShape())) {
5601       dims.push_back(std::get<0>(dim) == std::get<1>(dim)
5602                          ? std::get<0>(dim)
5603                          : ShapedType::kDynamicSize);
5604     }
5605     outputType = ShapedTypeComponents(dims, trueType.getElementType());
5606   }
5607   return success();
5608 }
5609 
5610 LogicalResult SelectOp::reifyReturnTypeShapes(
5611     OpBuilder& builder, ValueRange operands,
5612     SmallVectorImpl<Value>& reifiedReturnShapes) {
5613   // For `hlo.select`, the first operand may be a scalar.
5614   return deriveShapeFromOperand(&builder, getOperation(), operands[1],
5615                                 &reifiedReturnShapes);
5616 }
5617 
5618 //===----------------------------------------------------------------------===//
5619 // SetDimensionSizeOp
5620 //===----------------------------------------------------------------------===//
5621 
5622 LogicalResult SetDimensionSizeOp::verify() {
5623   if (auto size = this->size().getType().dyn_cast<RankedTensorType>()) {
5624     if (size.getRank() != 0)
5625       return emitOpError() << "size operand should be of rank-0";
5626   }
5627 
5628   return verifyDimAttr(*this);
5629 }
5630 
5631 OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
5632   DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
5633   if (input) return input;
5634 
5635   DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
5636   if (!size || !size.isSplat()) return {};
5637 
5638   auto ty = getType().dyn_cast<RankedTensorType>();
5639   if (!ty) return {};
5640 
5641   int64_t dimSize = ty.getDimSize(dimension());
5642   if (dimSize == size.getSplatValue<IntegerAttr>().getInt()) return operand();
5643   return {};
5644 }
5645 
5646 // TODO(b/238903565): Switch to inferReturnTypeComponents after adding support
5647 // for the encoding upstream.
5648 LogicalResult SetDimensionSizeOp::inferReturnTypes(
5649     MLIRContext* context, Optional<Location> location, ValueRange operands,
5650     DictionaryAttr attributes, RegionRange regions,
5651     SmallVectorImpl<Type>& inferredReturnTypes) {
5652   Location loc = location.value_or(UnknownLoc::get(context));
5653 
5654   SetDimensionSizeOp::Adaptor adaptor(operands, attributes, regions);
5655   if (failed(adaptor.verify(loc))) return failure();
5656 
5657   auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
5658   if (!inputType) {
5659     inferredReturnTypes.push_back(adaptor.operand().getType());
5660     return success();
5661   }
5662 
5663   int64_t dim = adaptor.dimension();
5664   int64_t rank = inputType.getRank();
5665   if (dim < 0 || dim >= rank) {
5666     return mlir::emitError(loc) << "expects dimension to be in range [0, "
5667                                 << rank << "); got: [" << dim << "].";
5668   }
5669 
5670   auto shape = llvm::to_vector<4>(inputType.getShape());
5671   llvm::SmallVector<int64_t, 4> bounds(rank, ShapedType::kDynamicSize);
5672   if (auto encoding =
5673           inputType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>())
5674     bounds = llvm::to_vector<4>(encoding.getBounds());
5675 
5676   // TODO(hinsu): Handle the case when the size operand is a constant.
5677   if (shape[dim] != ShapedType::kDynamicSize) bounds[dim] = shape[dim];
5678   shape[dim] = ShapedType::kDynamicSize;
5679 
5680   auto extensions = TypeExtensionsAttr::get(context, bounds);
5681   auto resultType =
5682       RankedTensorType::get(shape, inputType.getElementType(), extensions);
5683   inferredReturnTypes.push_back(resultType);
5684   return success();
5685 }
5686 
5687 //===----------------------------------------------------------------------===//
5688 // PadOp
5689 //===----------------------------------------------------------------------===//
5690 
5691 LogicalResult PadOp::inferReturnTypeComponents(
5692     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
5693     DictionaryAttr attributes, RegionRange regions,
5694     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5695   PadOp::Adaptor adaptor(operands, attributes, regions);
5696   auto inputType = adaptor.operand().getType().cast<RankedTensorType>();
5697   auto padType = adaptor.padding_value().getType().cast<RankedTensorType>();
5698 
5699   if (padType.getRank() != 0) {
5700     return emitOptionalError(
5701         location, llvm::formatv("padding value type should be a rank-0 "
5702                                 "tensor, is rank {0}",
5703                                 padType.getRank()));
5704   }
5705 
5706   const auto& paddingLow = adaptor.edge_padding_low();
5707   if (paddingLow.getType().getNumElements() != inputType.getRank()) {
5708     return emitOptionalError(
5709         location,
5710         llvm::formatv(
5711             "edge_padding_low length ({0}) must match operand rank ({1})",
5712             paddingLow.getType().getNumElements(), inputType.getRank()));
5713   }
5714 
5715   const auto& paddingHigh = adaptor.edge_padding_high();
5716   if (paddingHigh.getType().getNumElements() != inputType.getRank()) {
5717     return emitOptionalError(
5718         location,
5719         llvm::formatv(
5720             "edge_padding_high length ({0}) must match operand rank ({1})",
5721             paddingHigh.getType().getNumElements(), inputType.getRank()));
5722   }
5723 
5724   const auto& paddingInterior = adaptor.interior_padding();
5725   if (paddingInterior.getType().getNumElements() != inputType.getRank()) {
5726     return emitOptionalError(
5727         location,
5728         llvm::formatv(
5729             "interior_padding length ({0}) must match operand rank ({1})",
5730             paddingInterior.getType().getNumElements(), inputType.getRank()));
5731   }
5732 
5733   auto inputShape = inputType.getShape();
5734   SmallVector<int64_t> resultShape;
5735   for (int i = 0, e = inputShape.size(); i < e; i++) {
5736     if (isDynamicDimSize(inputShape[i])) {
5737       resultShape.push_back(ShapedType::kDynamicSize);
5738       continue;
5739     }
5740 
5741     int64_t paddingLowVal = paddingLow.getValues<APInt>()[i].getSExtValue();
5742     int64_t paddingHighVal = paddingHigh.getValues<APInt>()[i].getSExtValue();
5743     int64_t paddingInteriorVal =
5744         paddingInterior.getValues<APInt>()[i].getSExtValue();
5745     if (paddingInteriorVal < 0) {
5746       return emitOptionalError(
5747           location, llvm::formatv("Interior padding cannot be negative: {0}",
5748                                   paddingInteriorVal));
5749     }
5750     int64_t expectedOutput =
5751         inputShape[i] + paddingLowVal + paddingHighVal +
5752         std::max<int64_t>(inputShape[i] - 1, 0LL) * paddingInteriorVal;
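    // E.g. a size-3 dimension with low = 1, high = 2, interior = 2 yields
    // 3 + 1 + 2 + (3 - 1) * 2 = 10 output elements.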
5753     if (expectedOutput < 0) {
5754       return emitOptionalError(
5755           location,
5756           llvm::formatv("Padding result in negative size for dimension {0}",
5757                         i));
5758     }
5759     resultShape.push_back(expectedOutput);
5760   }
5761   inferredReturnShapes.emplace_back(resultShape, inputType.getElementType());
5762 
5763   return success();
5764 }
5765 
5766 template <typename T>
5767 OpFoldResult padOpFoldHelper(DenseElementsAttr input, DenseElementsAttr padding,
5768                              RankedTensorType returnType,
5769                              DenseIntElementsAttr edgePaddingLow,
5770                              DenseIntElementsAttr /*edgePaddingHigh*/,
5771                              DenseIntElementsAttr interiorPadding) {
5772   // Prevent folding if the result is too large.
5773   if (returnType.getNumElements() > kFoldOpEltLimit) return {};
5774 
5775   // Fill the full result tensor with the padding value.
5776   llvm::SmallVector<T, 4> result(returnType.getNumElements(),
5777                                  padding.getValues<T>()[0]);
5778 
5779   auto nextIndex = [](llvm::SmallVector<uint64_t, 8>& index,
5780                       llvm::ArrayRef<int64_t> shape) {
5781     for (int64_t i = index.size() - 1; i >= 0; --i) {
5782       ++index[i];
5783       if (static_cast<int64_t>(index[i]) < shape[i]) return;
5784       index[i] = 0;
5785     }
5786   };
5787 
5788   // Iterate over all elements of the input tensor and copy it to the correct
5789   // location in the output tensor.
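  // E.g. (a sketch) folding a pad of tensor<2xf32> with low = 1, high = 1,
  // interior = 1 into tensor<5xf32>: input index [i] lands at result index
  // 1 + 2 * i, i.e. indices 1 and 3; every other slot keeps the pad value.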
5790   llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
5791   uint64_t numElements = input.getNumElements();
5792   for (uint64_t operandIdx = 0; operandIdx < numElements; operandIdx++) {
5793     uint64_t resultIdx = 0;
5794     uint64_t idxMultiplier = 1;
5795     for (int64_t i = index.size() - 1; i >= 0; --i) {
5796       resultIdx += (edgePaddingLow.getValues<int64_t>()[i] +
5797                     index[i] * (interiorPadding.getValues<int64_t>()[i] + 1)) *
5798                    idxMultiplier;
5799       idxMultiplier *= returnType.getDimSize(i);
5800     }
5801     result[resultIdx] = input.getValues<T>()[index];
5802     nextIndex(index, input.getType().getShape());
5803   }
5804   return DenseElementsAttr::get(returnType, result);
5805 }
5806 
5807 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
5808   // If all padding is zero then it is an identity pad.
5809   auto isZero = [](const APInt& i) { return i == 0; };
5810   if (llvm::all_of(edge_padding_low().getValues<APInt>(), isZero) &&
5811       llvm::all_of(edge_padding_high().getValues<APInt>(), isZero) &&
5812       llvm::all_of(interior_padding().getValues<APInt>(), isZero))
5813     return operand();
5814 
5815   // If any padding is negative then it isn't supported by the folder (yet).
5816   auto isNegative = [](const APInt& i) { return i.slt(0); };
5817   if (llvm::any_of(edge_padding_low().getValues<APInt>(), isNegative) ||
5818       llvm::any_of(edge_padding_high().getValues<APInt>(), isNegative) ||
5819       llvm::any_of(interior_padding().getValues<APInt>(), isNegative))
5820     return {};
5821 
5822   DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
5823   DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
5824   RankedTensorType returnType = getType().dyn_cast_or_null<RankedTensorType>();
5825   if (!input || !input.getType().hasRank() || !padding || !returnType ||
5826       !returnType.hasStaticShape())
5827     return {};
5828 
5829   if (returnType.getElementType().isa<IntegerType>())
5830     return padOpFoldHelper<APInt>(input, padding, returnType,
5831                                   edge_padding_low(), edge_padding_high(),
5832                                   interior_padding());
5833   if (returnType.getElementType().isa<FloatType>())
5834     return padOpFoldHelper<APFloat>(input, padding, returnType,
5835                                     edge_padding_low(), edge_padding_high(),
5836                                     interior_padding());
5837   if (ComplexType complex =
5838           returnType.getElementType().dyn_cast_or_null<ComplexType>()) {
5839     // TODO(atondwal): Allow int types in HLO_complex
5840     if (complex.getElementType().isa<FloatType>())
5841       return padOpFoldHelper<std::complex<APFloat>>(
5842           input, padding, returnType, edge_padding_low(), edge_padding_high(),
5843           interior_padding());
5844   }
5845   return {};
5846 }
5847 
5848 LogicalResult PadOp::reifyReturnTypeShapes(
5849     OpBuilder& builder, ValueRange operands,
5850     SmallVectorImpl<Value>& reifiedReturnShapes) {
5851   PadOp::Adaptor adaptor(operands, this->getOperation()->getAttrDictionary());
5852   auto loc = this->getLoc();
5853   Value operand = adaptor.operand();
5854   auto operandTy = operand.getType().cast<RankedTensorType>();
5855 
5856   llvm::SmallVector<int32_t> padHigh;
5857   llvm::SmallVector<int32_t> padLow;
5858   llvm::SmallVector<int32_t> padInterior;
5859 
5860   auto padHighAttr = adaptor.edge_padding_high();
5861   auto padLowAttr = adaptor.edge_padding_low();
5862   auto padInteriorAttr = adaptor.interior_padding();
5863 
5864   padHigh.reserve(padHighAttr.getNumElements());
5865   padLow.reserve(padLowAttr.getNumElements());
5866   padInterior.reserve(padInteriorAttr.getNumElements());
5867 
5868   for (const APInt& val : padHighAttr.getValues<APInt>())
5869     padHigh.push_back(val.getSExtValue());
5870 
5871   for (const APInt& val : padLowAttr.getValues<APInt>())
5872     padLow.push_back(val.getSExtValue());
5873 
5874   for (const APInt& val : padInteriorAttr.getValues<APInt>())
5875     padInterior.push_back(val.getSExtValue());
5876 
5877   Value one = builder.create<arith::ConstantIndexOp>(loc, 1).getResult();
5878   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0).getResult();
5879 
5880   llvm::SmallVector<Value> dimensions;
5881   dimensions.reserve(operandTy.getRank());
5882   for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
5883     Value padEdge =
5884         builder.create<arith::ConstantIndexOp>(loc, padHigh[i] + padLow[i]);
5885 
5886     // First we grab the initial interior size.
5887     Value dim = builder.create<tensor::DimOp>(loc, operand, i).getResult();
5888 
5889     // Compute the interior of the tensor and determine padding size.
5890     if (padInterior[i] > 0) {
5891       Value padInter =
5892           builder.create<arith::ConstantIndexOp>(loc, padInterior[i])
5893               .getResult();
5894       Value interior = builder.create<arith::SubIOp>(loc, dim, one).getResult();
5895       interior = builder.create<arith::MaxSIOp>(loc, interior, zero);
5896       interior = builder.create<arith::MulIOp>(loc, interior, padInter);
5897       dim = builder.create<arith::AddIOp>(loc, dim, interior).getResult();
5898     }
5899 
5900     // Then we add the padding on the edge of the tensor.
5901     dim = builder.create<arith::AddIOp>(loc, dim, padEdge).getResult();
5902     dimensions.push_back(dim);
5903   }
5904 
5905   Value dimensionTensor =
5906       builder.create<tensor::FromElementsOp>(loc, dimensions).getResult();
5907   reifiedReturnShapes.push_back(dimensionTensor);
5908   return success();
5909 }
5910 
5911 // If the input tensor has a dimension of length 0, its contents are
5912 // irrelevant; instead we can broadcast the pad value to the output size
5913 // rather than pad the input tensor.
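// E.g. (a sketch) padding a tensor<0x2xf32> operand into a statically shaped
// tensor<3x4xf32> result can be rewritten as a broadcast of the scalar padding
// value: the operand contributes no elements, so every element of the result
// equals the pad value.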
5914 struct PadEmptyTensor : public OpRewritePattern<PadOp> {
5915   using OpRewritePattern<PadOp>::OpRewritePattern;
5916 
5917   LogicalResult matchAndRewrite(PadOp op,
5918                                 PatternRewriter& rewriter) const override {
5919     auto operand = op.operand();
5920     auto padVal = op.padding_value();
5921 
5922     auto operandTy = operand.getType().cast<RankedTensorType>();
5923     auto resultTy = op.getType().cast<RankedTensorType>();
5924 
5925     if (llvm::all_of(operandTy.getShape(), [](int64_t d) { return d != 0; })) {
5926       return failure();
5927     }
5928 
5929     if (resultTy.hasStaticShape()) {
5930       auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64));
5931       auto dims =
5932           DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
5933       rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(op, resultTy, padVal,
5934                                                           dims);
5935       return success();
5936     }
5937 
5938     llvm::SmallVector<Value> reifiedShapes;
5939     if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(),
5940                                         reifiedShapes)))
5941       return failure();
5942 
5943     auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64));
5944     auto broadcastDims =
5945         DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
5946     rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
5947         op, op.getType(), padVal, reifiedShapes.front(), broadcastDims);
5948 
5949     return success();  // The op was rewritten, so the match succeeded.
5950   }
5951 };
5952 
5953 void PadOp::getCanonicalizationPatterns(RewritePatternSet& results,
5954                                         MLIRContext* context) {
5955   results.add<PadEmptyTensor>(context);
5956 }
5957 
5958 //===----------------------------------------------------------------------===//
5959 // DynamicPadOp
5960 //===----------------------------------------------------------------------===//
5961 
5962 // If the input tensor has a dimension of length 0, its contents are
5963 // irrelevant; we can broadcast the pad value to the output shape instead of
5964 // padding the input tensor.
5965 struct DynamicPadEmptyTensor : public OpRewritePattern<DynamicPadOp> {
5966   using OpRewritePattern<DynamicPadOp>::OpRewritePattern;
5967 
5968   LogicalResult matchAndRewrite(DynamicPadOp op,
5969                                 PatternRewriter& rewriter) const override {
5971     auto operand = op.operand();
5972     auto padVal = op.padding_value();
5973 
5974     auto operandTy = operand.getType().cast<RankedTensorType>();
5975 
5976     if (llvm::all_of(operandTy.getShape(), [](int64_t d) { return d != 0; })) {
5977       return failure();
5978     }
5979 
5980     llvm::SmallVector<Value> reifiedShapes;
5981     if (failed(op.reifyReturnTypeShapes(rewriter, op->getOperands(),
5982                                         reifiedShapes)))
5983       return failure();
5984 
5985     auto dimsType = RankedTensorType::get({0}, rewriter.getIntegerType(64));
5986     auto broadcastDims =
5987         DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
5988     rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
5989         op, op.getType(), padVal, reifiedShapes.front(), broadcastDims);
5990 
5991     return success();
5992   }
5993 };
5994 
5995 void DynamicPadOp::getCanonicalizationPatterns(RewritePatternSet& results,
5996                                                MLIRContext* context) {
5997   results.add<DPadToPad, DynamicPadEmptyTensor>(context);
5998 }
5999 
6000 LogicalResult DynamicPadOp::verify() {
6001   auto inputType = operand().getType().dyn_cast<RankedTensorType>();
6002   // If operand is unranked, there is very little to verify statically.
6003   if (!inputType) return success();
6004   int inputRank = inputType.getRank();
6005 
6006   auto padType = padding_value().getType().cast<RankedTensorType>();
6007   if (padType.getRank() != 0) {
6008     return emitOpError() << "padding value type should be a rank-0";
6009   }
6010 
6011   auto paddingLowType = edge_padding_low().getType().cast<RankedTensorType>();
6012   if (paddingLowType.getNumElements() != inputRank) {
6013     return emitOpError() << "edge_padding_low length("
6014                          << paddingLowType.getNumElements()
6015                          << ") must match operand rank(" << inputRank << ").";
6016   }
6017 
6018   auto paddingHighType = edge_padding_high().getType().cast<RankedTensorType>();
6019   if (paddingHighType.getNumElements() != inputRank) {
6020     return emitOpError() << "edge_padding_high length("
6021                          << paddingHighType.getNumElements()
6022                          << ") must match operand rank(" << inputRank << ").";
6023   }
6024 
6025   auto interiorPaddingType =
6026       interior_padding().getType().cast<RankedTensorType>();
6027   if (interiorPaddingType.getNumElements() != inputRank) {
6028     return emitOpError() << "edge_padding_interior length("
6029                          << interiorPaddingType.getNumElements()
6030                          << ") must match operand rank(" << inputRank << ").";
6031   }
6032 
6033   auto outputType = getResult().getType().dyn_cast<RankedTensorType>();
6034   // If result is unranked, there is very little to verify statically.
6035   if (!outputType) return success();
6036   int outputRank = outputType.getRank();
6037   if (inputRank != outputRank) {
6038     return emitOpError() << "operand rank(" << inputRank
6039                          << ") must match result rank(" << outputRank << ").";
6040   }
6041 
6042   return success();
6043 }
6044 
6045 LogicalResult DynamicPadOp::reifyReturnTypeShapes(
6046     OpBuilder& builder, ValueRange operands,
6047     SmallVectorImpl<Value>& reifiedReturnShapes) {
6048   DynamicPadOp::Adaptor adaptor(operands);
6049   Value operand = adaptor.operand();
6050   Value edgePaddingLow = adaptor.edge_padding_low();
6051   Value edgePaddingHigh = adaptor.edge_padding_high();
6052   Value interiorPadding = adaptor.interior_padding();
6053 
6054   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
6055   // Unranked pad is not supported at the moment.
6056   if (!operandType) return failure();
6057 
6058   auto loc = this->getLoc();
6059   SmallVector<Value, 4> shapeValues;
6060   shapeValues.reserve(operandType.getRank());
6061   Type shapeScalarType =
6062       edgePaddingLow.getType().cast<ShapedType>().getElementType();
6063 
6064   auto toShapeScalarType = [&](Value v) {
6065     return maybeCastTo(builder, loc, v, shapeScalarType);
6066   };
6067 
6068   Value zero =
6069       toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 0));
6070   Value one = toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 1));
6071 
6072   for (int idx : llvm::seq<int>(0, operandType.getShape().size())) {
6073     Value valueDim =
6074         toShapeScalarType(builder.create<tensor::DimOp>(loc, operand, idx));
6075     Value offset = builder.create<arith::ConstantIndexOp>(loc, idx);
6076     Value valueLow =
6077         builder.create<tensor::ExtractOp>(loc, edgePaddingLow, offset);
6078     Value valueHigh =
6079         builder.create<tensor::ExtractOp>(loc, edgePaddingHigh, offset);
6080     Value valueInterior =
6081         builder.create<tensor::ExtractOp>(loc, interiorPadding, offset);
6082     // output_size = input_size + padding_low + padding_high + interior *
6083     // max(input_size - 1, 0)
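    // As an illustrative example (values are hypothetical): input_size = 3,
    // padding_low = 1, padding_high = 2, interior = 2 gives
    // 3 + 1 + 2 + 2 * max(3 - 1, 0) = 10.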
6084     Value valueDimLessThanOne = builder.create<arith::CmpIOp>(
6085         loc, arith::CmpIPredicate::slt, valueDim, one);
6086     Value interiorSize = builder.create<arith::MulIOp>(
6087         loc, valueInterior,
6088         builder.create<mlir::arith::SelectOp>(
6089             loc, valueDimLessThanOne, zero,
6090             builder.create<arith::SubIOp>(loc, valueDim, one)));
6091     shapeValues.push_back(builder.create<arith::AddIOp>(
6092         loc,
6093         builder.create<arith::AddIOp>(
6094             loc, builder.create<arith::AddIOp>(loc, interiorSize, valueDim),
6095             valueLow),
6096         valueHigh));
6097   }
6098 
6099   reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
6100       loc,
6101       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
6102                             shapeScalarType),
6103       shapeValues));
6104 
6105   return success();
6106 }
6107 
6108 //===----------------------------------------------------------------------===//
6109 // ReshapeOp
6110 //===----------------------------------------------------------------------===//
6111 
6112 LogicalResult ReshapeOp::verify() {
6113   // If the operand type is dynamically shaped there is nothing to verify.
6114   auto operandTy = operand().getType().dyn_cast<RankedTensorType>();
6115   if (!operandTy || !operandTy.hasStaticShape()) return success();
6116 
6117   // If the operand type is statically shaped, the number of its elements
6118   // must match that of the result type.
6119   auto resultTy = getType().cast<RankedTensorType>();
6120   assert(resultTy && resultTy.hasStaticShape() &&
6121          "result type must be statically shaped");
6122   int64_t numResultElements = resultTy.getNumElements();
6123   int64_t numOperandElements = operandTy.getNumElements();
6124   if (numResultElements != numOperandElements)
6125     return emitOpError() << "number of output elements (" << numResultElements
6126                          << ") doesn't match expected number of elements ("
6127                          << numOperandElements << ")";
6128 
6129   return success();
6130 }
6131 
6132 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
6133   if (getOperand().getType() == getType()) {
6134     return getOperand();
6135   }
6136 
6137   if (auto prevOp = getOperand().getDefiningOp<ReshapeOp>()) {
6138     setOperand(prevOp.getOperand());
6139     return getResult();
6140   }
6141 
6142   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
6143     return reshape(elements, getResult().getType().cast<ShapedType>());
6144   }
6145 
6146   return {};
6147 }
6148 
6149 void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results,
6150                                             MLIRContext* context) {
6151   results.add<IdentityBroadcastReshape, IdentityBroadcastInDimReshape,
6152               EliminateRedundantReshape, EliminateIdentityReshape>(context);
6153 }
6154 
6155 //===----------------------------------------------------------------------===//
6156 // ReplicaId Op
6157 //===----------------------------------------------------------------------===//
6158 
6159 LogicalResult ReplicaIdOp::inferReturnTypes(
6160     MLIRContext* context, Optional<Location>, ValueRange operands,
6161     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
6162   inferredReturnTypes.push_back(RankedTensorType::get(
6163       /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
6164   return success();
6165 }
6166 
6167 //===----------------------------------------------------------------------===//
6168 // AddDependency Op
6169 //===----------------------------------------------------------------------===//
6170 
6171 LogicalResult AddDependencyOp::inferReturnTypes(
6172     MLIRContext* context, Optional<Location>, ValueRange operands,
6173     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
6174   inferredReturnTypes.push_back(operands.getTypes()[0]);
6175   return success();
6176 }
6177 
6178 //===----------------------------------------------------------------------===//
6179 // If Op
6180 //===----------------------------------------------------------------------===//
6181 
6182 static LogicalResult verifyConditionalBranch(Operation* op, Region& region,
6183                                              llvm::Twine branchName) {
6184   if (region.getNumArguments() != 0)
6185     return op->emitOpError()
6186            << branchName << " must have 0 arguments, but found "
6187            << region.getNumArguments();
6188 
6189   TypeRange branchReturnTypes =
6190       region.front().getTerminator()->getOperandTypes();
6191   if (branchReturnTypes != op->getResultTypes())
6192     return op->emitOpError()
6193            << branchName << " returned types (" << branchReturnTypes
6194            << ") do not match op result types (" << op->getResultTypes() << ")";
6195 
6196   return success();
6197 }
6198 
6199 LogicalResult IfOp::verify() {
6200   if (failed(verifyConditionalBranch(*this, true_branch(),
6201                                      /*branchName=*/"true_branch"))) {
6202     return failure();
6203   }
6204 
6205   if (failed(verifyConditionalBranch(*this, false_branch(),
6206                                      /*branchName=*/"false_branch"))) {
6207     return failure();
6208   }
6209   return success();
6210 }
6211 
6212 static LogicalResult inlineIfConstantCondition(IfOp ifOp,
6213                                                PatternRewriter& rewriter) {
6214   DenseIntElementsAttr predAttr;
6215   if (!matchPattern(ifOp.pred(), m_Constant(&predAttr))) return failure();
6216 
6217   if (predAttr.getSplatValue<BoolAttr>().getValue()) {
6218     replaceOpWithRegion(rewriter, ifOp, ifOp.true_branch());
6219   } else {
6220     replaceOpWithRegion(rewriter, ifOp, ifOp.false_branch());
6221   }
6222   return success();
6223 }
6224 
6225 void IfOp::getCanonicalizationPatterns(RewritePatternSet& results,
6226                                        MLIRContext* context) {
6227   results.add(&inlineIfConstantCondition);
6228 }
6229 
6230 //===----------------------------------------------------------------------===//
6231 // Case Op
6232 //===----------------------------------------------------------------------===//
6233 
6234 LogicalResult CaseOp::verify() {
6235   auto numBranches = branches().size();
6236 
6237   for (unsigned i = 0; i < numBranches; ++i)
6238     if (failed(verifyConditionalBranch(*this, branches()[i],
6239                                        /*branchName=*/"branch " + Twine(i))))
6240       return failure();
6241 
6242   return success();
6243 }
6244 
6245 static LogicalResult inlineCaseConstantCondition(CaseOp caseOp,
6246                                                  PatternRewriter& rewriter) {
6247   DenseIntElementsAttr indexAttr;
6248   if (!matchPattern(caseOp.index(), m_Constant(&indexAttr))) {
6249     return failure();
6250   }
6251   int64_t index =
6252       indexAttr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
6253   // For an OOB index, the last branch is executed as the default branch:
6254   // https://www.tensorflow.org/xla/operation_semantics#conditional
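  // As an illustrative example (values are hypothetical): index = 7 with
  // three branches executes branch 2, the last one.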
6255   if (index < 0 || index >= caseOp.getNumRegions())
6256     index = caseOp.getNumRegions() - 1;
6257 
6258   Region& region = caseOp.getRegion(index);
6259   if (!llvm::hasSingleElement(region)) return failure();
6260   replaceOpWithRegion(rewriter, caseOp, region);
6261   return success();
6262 }
6263 
6264 void CaseOp::getCanonicalizationPatterns(RewritePatternSet& results,
6265                                          MLIRContext* context) {
6266   results.add(&inlineCaseConstantCondition);
6267 }
6268 
6269 //===----------------------------------------------------------------------===//
6270 // SqrtOp
6271 //===----------------------------------------------------------------------===//
6272 
6273 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
6274   auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6275   if (!val) return {};
6276 
6277   auto type = getElementTypeOrSelf(getType());
6278   if (!type.isF32() && !type.isF64()) return {};
6279 
6280   auto shapedType = getType().cast<ShapedType>();
6281   if (!shapedType.hasStaticShape()) return {};
6282 
6283   // Prevent folding if the result is too large.
6284   if (val.getNumElements() > kFoldOpEltLimit) return {};
6285 
6286   int bitWidth = type.getIntOrFloatBitWidth();
6287   llvm::SmallVector<APFloat, 4> values;
6288   values.reserve(val.getNumElements());
6289   for (auto it : val.getValues<APFloat>()) {
6290     double value = bitWidth == 32 ? it.convertToFloat() : it.convertToDouble();
6291     if (value < 0) return {};
6292     value = std::sqrt(value);
6293     if (bitWidth == 32)
6294       values.emplace_back(static_cast<float>(value));
6295     else
6296       values.emplace_back(value);
6297   }
6298   return DenseFPElementsAttr::get(shapedType, values);
6299 }
6300 
6301 //===----------------------------------------------------------------------===//
6302 // UnaryOps
6303 //===----------------------------------------------------------------------===//
6304 
6305 ParseResult parseUnaryOp(OpAsmParser& parser, OperationState& result) {
6306   SmallVector<OpAsmParser::UnresolvedOperand> operands;
6307   Type type;
6308   // If the operand is enclosed in parentheses, use the generic form.
6309   SMLoc loc = parser.getCurrentLocation();
6310   if (!parser.parseOptionalLParen()) {
6311     if (parser.parseOperandList(operands) || parser.parseRParen() ||
6312         parser.parseOptionalAttrDict(result.attributes) ||
6313         parser.parseColon() || parser.parseType(type))
6314       return failure();
6315     auto fnType = type.dyn_cast<FunctionType>();
6316     if (!fnType) {
6317       parser.emitError(loc, "expected function type");
6318       return failure();
6319     }
6320     if (parser.resolveOperands(operands, fnType.getInputs(), loc,
6321                                result.operands))
6322       return failure();
6323     result.addTypes(fnType.getResults());
6324     return success();
6325   }
6326   // Otherwise, use shorthand syntax.
6327   return failure(parser.parseOperandList(operands) ||
6328                  parser.parseOptionalAttrDict(result.attributes) ||
6329                  parser.parseColonType(type) ||
6330                  parser.resolveOperands(operands, type, result.operands) ||
6331                  parser.addTypeToList(type, result.types));
6332 }
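// As an illustrative sketch (op and value names are hypothetical), the two
// forms accepted above are the shorthand
//   %0 = mhlo.abs %arg0 : tensor<4xf32>
// and the generic form (printed when operand and result types differ):
//   %1 = mhlo.abs(%arg0) : (tensor<4xf32>) -> tensor<4xf32>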
6333 
6334 void printUnaryOp(Operation* op, OpAsmPrinter& p) {
6335   assert(op->getNumResults() == 1 && "op should have one result");
6336   assert(op->getNumOperands() == 1 && "op should have one input");
6337   // If not all types are the same, use generic form.
6338   auto resultType = op->getResult(0).getType();
6339   if (resultType != op->getOperandTypes()[0]) {
6340     p.printGenericOp(op, /*printOpName=*/false);
6341     return;
6342   }
6343   // Otherwise, use the shorthand syntax.
6344   p << ' ';
6345   p.printOperands(op->getOperands());
6346   p.printOptionalAttrDict(op->getAttrs());
6347   p << " : " << resultType;
6348 }
6349 
6350 template <typename Op, typename ElementType = Type, typename ValType,
6351           typename Convert>
6352 static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
6353   if (!attrs[0]) return {};
6354 
6355   DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>();
6356   if (!val) return {};
6357 
6358   ShapedType type = op->getType().template cast<ShapedType>();
6359   if (!type.hasStaticShape()) {
6360     return {};
6361   }
6362 
6363   Type etype = type.getElementType();
6364 
6365   // Only fold when the element type matches the expected element type.
6366   if (!etype.isa<ElementType>()) {
6367     return {};
6368   }
6369 
6370   // Prevent folding if the result is too large.
6371   if (val.getNumElements() > kFoldOpEltLimit) return {};
6372 
6373   SmallVector<ValType, 6> values;
6374   values.reserve(val.getNumElements());
6375   for (const auto v : val.getValues<ValType>()) {
6376     values.push_back(Convert()(v));
6377   }
6378 
6379   return DenseElementsAttr::get(type, values);
6380 }
6381 
6382 struct Round {
6383   APFloat operator()(const APFloat& f) {
6384     APFloat r = f;
6385     r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
6386     return r;
6387   }
6388 };
6389 
6390 struct RoundNearestEven {
6391   APFloat operator()(const APFloat& f) {
6392     APFloat r = f;
6393     r.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
6394     return r;
6395   }
6396 };
6397 
6398 struct LogicalNot {
6399   APInt operator()(const APInt& i) {
6400     return APInt(i.getBitWidth(), static_cast<uint64_t>(!i));
6401   }
6402 };
6403 
6404 template <typename FloatOrInt>
6405 struct Sign {
6406   APFloat compute(const APFloat& f) {
6407     if (f.isZero() || f.isNaN()) return f;
6408     double value = f.isNegative() ? -1.0 : 1.0;
6409     APFloat val(value);
6410     bool unused;
6411     val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
6412     return val;
6413   }
6414 
6415   APInt compute(const APInt& i) {
6416     APInt r = i;
6417     if (r == 0) return r;
6418     if (r.isNegative()) {
6419       return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
6420     }
6421     return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
6422   }
6423 
6424   FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
6425 };
6426 
6427 #define UNARY_FOLDER(Op, Func)                                                \
6428   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                          \
6429     if (getElementTypeOrSelf(getType()).isa<FloatType>())                     \
6430       return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
6431     if (getElementTypeOrSelf(getType()).isa<IntegerType>())                   \
6432       return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
6433     return {};                                                                \
6434   }
6435 
6436 #define UNARY_FOLDER_INT(Op, Func)                                   \
6437   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                 \
6438     if (getElementTypeOrSelf(getType()).isa<IntegerType>())          \
6439       return UnaryFolder<Op, IntegerType, APInt, Func>(this, attrs); \
6440     return {};                                                       \
6441   }
6442 
6443 #define UNARY_FOLDER_FLOAT(Op, Func)                                 \
6444   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                 \
6445     if (getElementTypeOrSelf(getType()).isa<FloatType>())            \
6446       return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
6447     return {};                                                       \
6448   }
6449 
6450 UNARY_FOLDER(NegOp, std::negate);
6451 UNARY_FOLDER(SignOp, Sign);
6452 UNARY_FOLDER_INT(NotOp, LogicalNot);
6453 UNARY_FOLDER_FLOAT(RoundNearestEvenOp, RoundNearestEven);
6454 UNARY_FOLDER_FLOAT(RoundOp, Round);
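// As an illustrative example (the constant is hypothetical), these folders
// rewrite
//   %0 = mhlo.constant dense<[1.0, -2.0]> : tensor<2xf32>
//   %1 = mhlo.negate %0 : tensor<2xf32>
// into %1 = mhlo.constant dense<[-1.0, 2.0]> : tensor<2xf32>.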
6455 
6456 #undef UNARY_FOLDER
6457 #undef UNARY_FOLDER_INT
6458 #undef UNARY_FOLDER_FLOAT
6459 
6460 //===----------------------------------------------------------------------===//
6461 // BinaryOps
6462 //===----------------------------------------------------------------------===//
6463 
6464 ParseResult parseBinaryOp(OpAsmParser& parser, OperationState& result) {
6465   SmallVector<OpAsmParser::UnresolvedOperand> operands;
6466   Type type;
6467   // If the operand list is enclosed in parentheses, use the generic form.
6468   SMLoc loc = parser.getCurrentLocation();
6469   if (!parser.parseOptionalLParen()) {
6470     if (parser.parseOperandList(operands) || parser.parseRParen() ||
6471         parser.parseOptionalAttrDict(result.attributes) ||
6472         parser.parseColon() || parser.parseType(type))
6473       return failure();
6474     auto fnType = type.dyn_cast<FunctionType>();
6475     if (!fnType) {
6476       parser.emitError(loc, "expected function type");
6477       return failure();
6478     }
6479     if (parser.resolveOperands(operands, fnType.getInputs(), loc,
6480                                result.operands))
6481       return failure();
6482     result.addTypes(fnType.getResults());
6483     return success();
6484   }
6485   // Otherwise, use shorthand syntax.
6486   return failure(parser.parseOperandList(operands) ||
6487                  parser.parseOptionalAttrDict(result.attributes) ||
6488                  parser.parseColonType(type) ||
6489                  parser.resolveOperands(operands, type, result.operands) ||
6490                  parser.addTypeToList(type, result.types));
6491 }
6492 
6493 void printBinaryOp(Operation* op, OpAsmPrinter& p) {
6494   assert(op->getNumResults() == 1 && "op should have one result");
6495   // If not all types are the same, use generic form.
6496   auto resultType = op->getResult(0).getType();
6497   if (llvm::any_of(op->getOperandTypes(),
6498                    [&](Type type) { return type != resultType; })) {
6499     p.printGenericOp(op, /*printOpName=*/false);
6500     return;
6501   }
6502   // Otherwise, use the shorthand syntax.
6503   p << ' ';
6504   p.printOperands(op->getOperands());
6505   p.printOptionalAttrDict(op->getAttrs());
6506   p << " : " << resultType;
6507 }
6508 
6509 static const APFloat& addSign(const APFloat& v, Type) { return v; }
6510 static APSInt addSign(const APInt& v, Type t) {
6511   // Add signedness information to the value, treating signless as signed.
6512   return APSInt(v, t.isUnsignedInteger());
6513 }
6514 
6515 template <typename Op, typename ElementType = Type, typename ValType,
6516           typename Convert>
6517 static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
6518   if (!attrs[0] || !attrs[1]) return {};
6519 
6520   DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
6521   DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
6522   if (!lhs || !rhs) return {};
6523 
6524   ShapedType type = op->getType().template cast<ShapedType>();
6525   if (!type.hasStaticShape()) {
6526     return {};
6527   }
6528 
6529   Type etype = type.getElementType();
6530 
6531   // Only fold when the element type matches the expected element type.
6532   if (!etype.isa<ElementType>()) {
6533     return {};
6534   }
6535 
6536   // Special case for folding splats no matter how large.
6537   // Only covers the case of both attrs being splats; operation-specific cases
6538   // like adding a zero or multiplying by one are handled elsewhere.
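  // As an illustrative example (values are hypothetical): adding two splat
  // constants dense<2> + dense<3> over tensor<1000000xi32> folds to a single
  // splat dense<5> without materializing a million elements.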
6539   SplatElementsAttr splatLhs = lhs.dyn_cast<SplatElementsAttr>();
6540   SplatElementsAttr splatRhs = rhs.dyn_cast<SplatElementsAttr>();
6541   if (splatLhs && splatRhs) {
6542     auto signedLhs = addSign(splatLhs.getSplatValue<ValType>(), etype);
6543     auto signedRhs = addSign(splatRhs.getSplatValue<ValType>(), etype);
6544     FailureOr<decltype(signedLhs)> result(Convert()(signedLhs, signedRhs));
6545     return succeeded(result) ? SplatElementsAttr::get(type, *result)
6546                              : Attribute();
6547   }
6548 
6549   // Prevent folding if the result is too large.
6550   if (lhs.getNumElements() > kFoldOpEltLimit) return {};
6551 
6552   SmallVector<ValType, 6> values;
6553   values.reserve(lhs.getNumElements());
6554   for (const auto zip :
6555        llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
6556     auto signedLhs = addSign(std::get<0>(zip), etype);
6557     auto signedRhs = addSign(std::get<1>(zip), etype);
6558     FailureOr<decltype(signedLhs)> result(Convert()(signedLhs, signedRhs));
6559     if (failed(result)) {
6560       return {};
6561     }
6562     values.push_back(std::move(*result));
6563   }
6564 
6565   return DenseElementsAttr::get(type, values);
6566 }
6567 
6568 template <typename T>
6569 struct Divide : std::divides<T> {};
6570 
6571 template <>
6572 struct Divide<APSInt> {
6573   FailureOr<APSInt> operator()(const APSInt& a, const APSInt& b) const {
6574     if (b.isZero()) return failure();
6575     return a / b;
6576   }
6577 };
6578 
6579 template <typename T>
6580 struct Remainder : std::modulus<T> {};
6581 
6582 template <>
6583 struct Remainder<APSInt> {
6584   FailureOr<APSInt> operator()(const APSInt& a, const APSInt& b) const {
6585     if (b.isZero()) return failure();
6586     return a % b;
6587   }
6588 };
6589 
6590 template <>
6591 struct Remainder<APFloat> {
6592   APFloat operator()(const APFloat& a, const APFloat& b) const {
6593     APFloat result(a);
6594     result.remainder(b);
6595     return result;
6596   }
6597 };
6598 
6599 template <typename T>
6600 struct Max {
6601   T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
6602 };
6603 
6604 template <typename T>
6605 struct Min {
6606   T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
6607 };
6608 
6609 template <typename T>
6610 struct And {
6611   T operator()(const T& a, const T& b) const { return a & b; }
6612 };
6613 
6614 template <typename T>
6615 struct Or {
6616   T operator()(const T& a, const T& b) const { return a | b; }
6617 };
6618 
6619 template <typename T>
6620 struct Xor {
6621   T operator()(const T& a, const T& b) const { return a ^ b; }
6622 };
6623 
6624 #define BINARY_FOLDER_INTERNAL(Op, Func)                                     \
6625   if (getElementTypeOrSelf(getType()).isa<FloatType>())                      \
6626     return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
6627   if (getElementTypeOrSelf(getType()).isa<IntegerType>())                    \
6628     return BinaryFolder<Op, IntegerType, APInt, Func<APSInt>>(this, attrs);  \
6629   return {};
6630 
6631 #define BINARY_FOLDER(Op, Func)                      \
6632   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
6633     BINARY_FOLDER_INTERNAL(Op, Func)                 \
6634   }
6635 
6636 // Addition, subtraction, and multiplication use the std:: versions of the
6637 // ops. Because the remaining ops behave differently for signed vs. unsigned
6638 // integers, APInt needs a special implementation; currently it replicates
6639 // signed integer behavior.
6640 BINARY_FOLDER(SubtractOp, std::minus);
6641 BINARY_FOLDER(DivOp, Divide);
6642 BINARY_FOLDER(RemOp, Remainder);
6643 BINARY_FOLDER(MaxOp, Max);
6644 BINARY_FOLDER(MinOp, Min);
6645 
6646 bool isSplatZero(SplatElementsAttr attr) {
6647   if (!attr) return false;
6648   if (attr.getElementType().isa<FloatType>()) {
6649     return attr.getSplatValue<APFloat>().isZero();
6650   }
6651   if (attr.getElementType().isa<IntegerType>()) {
6652     return attr.getSplatValue<APInt>().isZero();
6653   }
6654   return false;
6655 }
6656 
6657 OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {
6658   // Handle the special case where one operand is 0: x + 0 => x.
6659   if (attrs[0] || attrs[1]) {
6660     SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null<SplatElementsAttr>();
6661     SplatElementsAttr splatRhs = attrs[1].dyn_cast_or_null<SplatElementsAttr>();
6662     if (isSplatZero(splatLhs)) return splatRhs ? (OpFoldResult)splatRhs : rhs();
6663     if (isSplatZero(splatRhs)) return splatLhs ? (OpFoldResult)splatLhs : lhs();
6664   }
6665   if (attrs[0] && attrs[1]) {
6666     BINARY_FOLDER_INTERNAL(AddOp, std::plus)
6667   }
6668   return {};
6669 }
6670 
6671 bool isSplatOne(SplatElementsAttr attr) {
6672   if (!attr) return false;
6673   if (attr.getElementType().isa<FloatType>()) {
6674     return attr.getSplatValue<APFloat>().convertToDouble() == 1.0;
6675   }
6676   if (attr.getElementType().isa<IntegerType>()) {
6677     return attr.getSplatValue<APInt>().getSExtValue() == 1;
6678   }
6679   return false;
6680 }
6681 
6682 OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {
6683   // Handle the special case where one operand is 1: x * 1 => x.
6684   if (attrs[0] || attrs[1]) {
6685     SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null<SplatElementsAttr>();
6686     SplatElementsAttr splatRhs = attrs[1].dyn_cast_or_null<SplatElementsAttr>();
6687     if (isSplatOne(splatLhs)) return splatRhs ? (OpFoldResult)splatRhs : rhs();
6688     if (isSplatOne(splatRhs)) return splatLhs ? (OpFoldResult)splatLhs : lhs();
6689   }
6690   if (attrs[0] && attrs[1]) {
6691     BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);
6692   }
6693   return {};
6694 }
6695 
6696 //===----------------------------------------------------------------------===//
6697 // Logical Ops
6698 //===----------------------------------------------------------------------===//
6699 
6700 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
6701   if (lhs() == rhs()) return lhs();
6702 
6703   auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6704   auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
6705 
6706   if (lhsVal && lhsVal.isSplat()) {
6707     if (lhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6708       return rhs();
6709     }
6710 
6711     if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6712       return lhsVal;
6713     }
6714   }
6715 
6716   if (rhsVal && rhsVal.isSplat()) {
6717     if (rhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6718       return lhs();
6719     }
6720 
6721     if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6722       return rhsVal;
6723     }
6724   }
6725 
6726   if (!rhsVal || !lhsVal) return {};
6727   return BinaryFolder<AndOp, IntegerType, APInt, And<APSInt>>(this, operands);
6728 }
6729 
6730 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
6731   if (lhs() == rhs()) return lhs();
6732 
6733   auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6734   auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
6735 
6736   if (lhsVal && lhsVal.isSplat()) {
6737     if (lhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6738       return lhsVal;
6739     }
6740 
6741     if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6742       return rhs();
6743     }
6744   }
6745 
6746   if (rhsVal && rhsVal.isSplat()) {
6747     if (rhsVal.getSplatValue<IntegerAttr>().getValue().isAllOnesValue()) {
6748       return rhsVal;
6749     }
6750 
6751     if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6752       return lhs();
6753     }
6754   }
6755 
6756   if (!rhsVal || !lhsVal) return {};
6757   return BinaryFolder<OrOp, IntegerType, APInt, Or<APSInt>>(this, operands);
6758 }
6759 
6760 OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
6761   // Fold x^x to 0. Attributes only support static shapes.
6762   auto rType = getType().cast<ShapedType>();
6763   if (lhs() == rhs() && rType.hasStaticShape()) {
6764     Builder builder(getContext());
6765     return builder.getZeroAttr(rType);
6766   }
6767 
6768   auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
6769   auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();
6770 
6771   if (lhsVal && lhsVal.isSplat()) {
6772     if (lhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6773       return rhs();
6774     }
6775   }
6776 
6777   if (rhsVal && rhsVal.isSplat()) {
6778     if (rhsVal.getSplatValue<IntegerAttr>().getValue().isNullValue()) {
6779       return lhs();
6780     }
6781   }
6782 
6783   if (!rhsVal || !lhsVal) return {};
6784   return BinaryFolder<XorOp, IntegerType, APInt, Xor<APSInt>>(this, operands);
6785 }
6786 
6787 #undef BINARY_FOLDER_INTERNAL
6788 #undef BINARY_FOLDER
6789 
6790 //===----------------------------------------------------------------------===//
6791 // SliceOp
6792 //===----------------------------------------------------------------------===//
6793 
6794 // Returns the output dimension size of the slice result for the given
6795 // arguments, or -1 if the arguments are invalid.
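// As an illustrative example (values are hypothetical): inputDim = 10,
// start = 2, end = 9, stride = 3 gives ceil((9 - 2) / 3) = 3 output elements.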
6796 static int64_t inferSliceDim(int64_t inputDim, int64_t start, int64_t end,
6797                              int64_t stride) {
6798   if (inputDim == -1 || start < 0 || start > end || end > inputDim ||
6799       stride == 0)
6800     return -1;
6801 
6802   return llvm::divideCeil(end - start, stride);
6803 }
6804 
6805 // The following properties are already enforced by the ODS:
6806 //  type(start_indices) == type(limit_indices) == type(strides).
6807 // Verify the following properties:
6808 //  P1. Verify rank(start_indices) == 1.
6809 //  P2. Verify size(start_indices) == rank(operand).
6810 //  P3~5. Verify 0 <= start_indices[i] <= limit_indices[i] <= shape(operand)[i].
6811 //  P6. Verify stride[i] > 0.
6812 LogicalResult SliceOp::inferReturnTypes(
6813     MLIRContext* context, Optional<Location> location, ValueRange operands,
6814     DictionaryAttr attributes, RegionRange regions,
6815     SmallVectorImpl<Type>& inferredReturnTypes) {
6816   SliceOpAdaptor slice(operands, attributes);
6817   Type ty = slice.operand().getType();
6818   RankedTensorType rankedTy = ty.dyn_cast<RankedTensorType>();
6819   if (!rankedTy) {
6820     // The operand type is unranked, so the best we can infer for the result
6821     // type is an unranked tensor with the same element type as the operand
6822     // type.
6823     inferredReturnTypes.assign({ty});
6824     return success();
6825   }
6826 
6827   ShapedType attrTy = slice.start_indices().getType();
6828   // P1.
6829   // Note: ODS has type(start_indices) == type(limit_indices) == type(strides)
6830   // So this implies rank(limit_indices) == rank(strides) == 1 also.
6831   if (attrTy.getRank() != 1) {
6832     return emitOptionalError(location, "start_indices has rank ",
6833                              attrTy.getRank(), " instead of required rank 1");
6834   }
6835 
6836   // P2.
6837   int64_t rank = rankedTy.getRank();
6838   if (attrTy.getNumElements() != rank) {
6839     return emitOptionalError(
6840         location, "the number of elements in start_indices (",
6841         attrTy.getNumElements(), ") does not match the rank of the operand (",
6842         rank, ")");
6843   }
6844 
6845   SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
6846   SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
6847   SmallVector<int64_t, 4> strideVals(slice.strides().getValues<int64_t>());
6848 
6849   SmallVector<int64_t, 4> shape;
6850   shape.reserve(rank);
6851   for (int64_t i = 0, e = rank; i != e; i++) {
6852     if (isDynamicDimSize(rankedTy.getDimSize(i))) {
6853       shape.push_back(ShapedType::kDynamicSize);
6854       continue;
6855     }
6856     // P3.
6857     if (start[i] < 0)
6858       return emitOptionalError(location, "negative start index ", start[i],
6859                                " in dimension ", i);
6860     // P4.
6861     if (limit[i] > rankedTy.getDimSize(i))
6862       return emitOptionalError(location, "limit index ", limit[i],
6863                                " is larger than dimension size ",
6864                                rankedTy.getDimSize(i), " in dimension ", i);
6865     // P5.
6866     if (start[i] > limit[i])
6867       return emitOptionalError(location, "start index ", start[i],
6868                                " is larger than limit index ", limit[i],
6869                                " in dimension ", i);
6870     // P6.
6871     if (strideVals[i] <= 0)
6872       return emitOptionalError(location, "stride must be positive but got ",
6873                                strideVals[i], " in dimension ", i);
6874 
6875     shape.push_back(inferSliceDim(rankedTy.getDimSize(i), start[i], limit[i],
6876                                   strideVals[i]));
6877   }
6878   inferredReturnTypes.assign(
6879       {RankedTensorType::get(shape, rankedTy.getElementType())});
6880   return success();
6881 }
6882 
6883 template <typename I, typename E>
6884 static void sliceElements(I values, ArrayRef<int64_t> sizes,
6885                           ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
6886                           ArrayRef<int64_t> strides,
6887                           llvm::SmallVectorImpl<E>* outValues) {
6888   assert(starts.size() == limits.size());
6889   assert(starts.size() == strides.size());
6890   if (starts.empty()) return;
6891 
6892   int64_t start = starts.front();
6893   int64_t limit = limits.front();
6894   int64_t stride = strides.front();
6895   if (starts.size() == 1) {
6896     for (int i = start; i < limit; i += stride) {
6897       outValues->push_back(*(values + i));
6898     }
6899     return;
6900   }
6901 
6902   for (; start < limit; start += stride) {
6903     auto begin = values + start * sizes.front();
6904     sliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
6905                         limits.drop_front(), strides.drop_front(), outValues);
6906   }
6907 }
6908 
6909 template <typename I, typename E>
6910 static Attribute foldSlice(SliceOp* op, I values) {
6911   auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
6912   auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
6913   auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());
6914 
6915   // TODO(b/235903849): This should be op->getType().cast<ShapedType>().
6916   auto resultType = op->operand().getType().cast<ShapedType>();
6917   if (!resultType.hasStaticShape()) return {};
6918 
6919   auto shape = resultType.getShape();
6920   int64_t count = resultType.getNumElements();
6921   if (count == 0) {
6922     return DenseElementsAttr::get<E>(
6923         op->getResult().getType().cast<ShapedType>(),
6924         /*list=*/{});
6925   }
6926 
6927   // Compute the striding for each dimension.
6928   llvm::SmallVector<int64_t, 6> sizes;
6929   sizes.reserve(shape.size());
6930   for (auto v : shape) {
6931     count = count / v;
6932     sizes.push_back(count);
6933   }
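  // As an illustrative example (the shape is hypothetical): shape [2, 3, 4]
  // produces sizes [12, 4, 1], the number of elements spanned by one index
  // step in each dimension.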
6934 
6935   // Prevent folding if the result is too large.
6936   if (resultType.getNumElements() > kFoldOpEltLimit) return {};
6937 
6938   llvm::SmallVector<E, 6> outValues;
6939   outValues.reserve(resultType.getNumElements());
6940   sliceElements<I, E>(values, sizes, start, limit, stride, &outValues);
6941 
6942   return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
6943                                 outValues);
6944 }
6945 
6946 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
6947   // Check whether the SliceOp is a no-op.
6948   auto operandType = getOperand().getType().cast<ShapedType>();
6949   auto resultType = getResult().getType().cast<ShapedType>();
6950 
6951   if (operandType.hasStaticShape() && resultType.hasStaticShape() &&
6952       (operandType.getShape() == resultType.getShape())) {
6953     return getOperand();
6954   }
6955 
6956   if (operands.empty() || !operands.front()) return {};
6957 
6958   // Evaluate for statically valued inputs.
6959   DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
6960   if (!elements) return {};
6961 
6962   auto etype = elements.getType().getElementType();
6963   if (etype.isa<IntegerType>()) {
6964     return foldSlice<DenseElementsAttr::IntElementIterator, APInt>(
6965         this, elements.value_begin<APInt>());
6966   }
6967   if (etype.isa<FloatType>()) {
6968     return foldSlice<DenseElementsAttr::FloatElementIterator, APFloat>(
6969         this, elements.value_begin<APFloat>());
6970   }
6971 
6972   return {};
6973 }
6974 
6975 namespace {
6976 // In cases where a concat is fed into a slice, it is possible the concat
6977 // can be simplified or bypassed. This checks which inputs to the concat are
6978 // used by the slice, either reducing the number of concatenated values or
6979 // removing the concat entirely.
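// As an illustrative sketch (IR names are hypothetical), a slice that only
// reads elements from the first concat operand, e.g.
//   %c = "mhlo.concatenate"(%a, %b) {dimension = 0 : i64}
//       : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
//   %s = "mhlo.slice"(%c) {start_indices = dense<0> : tensor<1xi64>,
//       limit_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
//       : (tensor<5xf32>) -> tensor<2xf32>
// can be rewritten so that only %a is concatenated and sliced, after which
// later folds can remove the single-operand concat entirely.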
6980 struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
6981   using OpRewritePattern<SliceOp>::OpRewritePattern;
6982 
6983   LogicalResult matchAndRewrite(SliceOp slice,
6984                                 PatternRewriter& rewriter) const override {
6985     auto resultTy = slice.getType().cast<ShapedType>();
6986     if (!resultTy.hasStaticShape()) {
6987       return failure();
6988     }
6989 
6990     auto sliceInput = slice.operand();
6991     auto sliceInputTy = sliceInput.getType().cast<ShapedType>();
6992     auto concat = sliceInput.getDefiningOp<ConcatenateOp>();
6993     if (!concat) {
6994       return failure();
6995     }
6996 
6997     auto dimension = concat.dimension();
6998 
6999     auto start = slice.start_indices().getValues<APInt>();
7000     auto limit = slice.limit_indices().getValues<APInt>();
7001 
7002     auto sliceStart = (*(start.begin() + dimension)).getSExtValue();
7003     auto sliceLimit = (*(limit.begin() + dimension)).getSExtValue();
7004 
7005     // We need to determine what inputs from the concat affect the slice, and
7006     // how the bounds of the slice need to be updated for the minimally required
7007     // inputs.
7008     int64_t runningSize = 0;
7009     int64_t frontOffset = sliceInputTy.getShape()[dimension];
7010 
7011     auto subsetStart = concat.operand_end();
7012     auto subsetEnd = concat.operand_end();
7013     for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
7014       auto input = *it;
7015       ShapedType inputTy = input.getType().cast<ShapedType>();
7016       if (inputTy.isDynamicDim(dimension)) {
7017         return failure();
7018       }
7019       auto dimSize = inputTy.getShape()[dimension];
7020 
7021       // If this position is in the slice, it's the start of the subset, and
7022       // we need to update the start and limit values.
7023       if (runningSize + dimSize > sliceStart &&
7024           subsetStart == concat.operand_end()) {
7025         subsetStart = it;
7026         frontOffset = runningSize;
7027       }
7028 
7029       // Determine the last required offset.
7030       if (runningSize < sliceLimit) {
7031         subsetEnd = it + 1;
7032       }
7033 
7034       runningSize += dimSize;
7035     }
7036 
7037     auto subsetSize = subsetEnd - subsetStart;
7038     // We need all inputs so no optimization.
7039     if (subsetSize == concat.getNumOperands()) {
7040       return failure();
7041     }
7042 
7043     // If there's nothing to slice that means the output is an empty tensor and
7044     // there is dead code. We do nothing here and rely on other passes to clean
7045     // this up.
7046     if (subsetSize == 0) {
7047       return failure();
7048     }
7049 
7050     if (subsetSize > 1 && !concat.getResult().hasOneUse()) {
7051       return failure();
7052     }
7053 
7054     auto concatRange = OperandRange(subsetStart, subsetEnd);
7055     auto newConcat = rewriter.create<ConcatenateOp>(
7056         concat.getLoc(), concatRange, concat.dimension());
7057 
7058     llvm::SmallVector<APInt, 6> newStart(start);
7059     llvm::SmallVector<APInt, 6> newLimit(limit);
7060     newStart[dimension] -= frontOffset;
7061     newLimit[dimension] -= frontOffset;
7062 
7063     auto attrType = slice.start_indices().getType().cast<ShapedType>();
7064     auto create = rewriter.create<SliceOp>(
7065         slice.getLoc(), newConcat,
7066         DenseIntElementsAttr::get(attrType, newStart),
7067         DenseIntElementsAttr::get(attrType, newLimit), slice.strides());
7068     rewriter.replaceOp(slice, create.getResult());
7069     return success();
7070   }
7071 };
7072 }  // namespace
7073 
7074 void SliceOp::getCanonicalizationPatterns(RewritePatternSet& results,
7075                                           MLIRContext* context) {
7076   results.add<SimplifyConcatSlice>(context);
7077 }
7078 
7079 //===----------------------------------------------------------------------===//
7080 // SortOp
7081 //===----------------------------------------------------------------------===//
7082 
7083 void SortOp::build(OpBuilder& builder, OperationState& state,
7084                    ValueRange operands, int64_t dimension, bool isStable) {
7085   state.addOperands(operands);
7086   state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
7087   state.addAttribute("is_stable", builder.getBoolAttr(isStable));
7088 
7089   for (Value operand : operands) state.addTypes(operand.getType());
7090 
7091   state.addRegion();
7092 }
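// As an illustrative sketch (names are hypothetical; build() itself only adds
// an empty region), the comparator that verify() below expects for a
// two-operand sort looks like:
//   "mhlo.sort"(%keys, %values) ({
//   ^bb0(%k0: tensor<f32>, %k1: tensor<f32>, %v0: tensor<i32>, %v1: tensor<i32>):
//     %lt = "mhlo.compare"(%k0, %k1) {comparison_direction = ...}
//         : (tensor<f32>, tensor<f32>) -> tensor<i1>
//     "mhlo.return"(%lt) : (tensor<i1>) -> ()
//   }) {dimension = 0 : i64, is_stable = true}
//      : (tensor<8xf32>, tensor<8xi32>) -> (tensor<8xf32>, tensor<8xi32>)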
7093 
7094 LogicalResult SortOp::verify() {
7095   Operation::operand_range operands = this->operands();
7096   if (operands.empty()) return emitOpError("requires at least one input");
7097 
7098   // TODO(antiagainst): verify partially dynamic shapes
7099   if (llvm::all_of(operands, [](Value operand) {
7100         return operand.getType().cast<ShapedType>().hasRank();
7101       })) {
7102     ArrayRef<int64_t> inputShape =
7103         (*operands.begin()).getType().cast<ShapedType>().getShape();
7104 
7105     if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
7106           return operand.getType().cast<ShapedType>().getShape() != inputShape;
7107         }))
7108       return emitOpError("requires all inputs to have the same dimensions");
7109 
7110     int64_t rank = inputShape.size();
7111     int64_t cmpDim = dimension();
7112     if (cmpDim < -rank || cmpDim >= rank)
7113       return emitOpError("dimension attribute value must be in range [-")
7114              << rank << ", " << rank << "), but found " << cmpDim;
7115   }
7116 
7117   Block& block = comparator().front();
7118   size_t numOperands = getOperation()->getNumOperands();
7119   if (block.getNumArguments() != 2 * numOperands)
7120     return emitOpError("comparator block should have ")
7121            << 2 * numOperands << " arguments";
7122 
7123   for (const auto& indexedOperand : llvm::enumerate(operands)) {
7124     int index = indexedOperand.index();
7125     Type elementType =
7126         indexedOperand.value().getType().cast<ShapedType>().getElementType();
7127     Type tensorType = RankedTensorType::get({}, elementType);
7128     for (int i : {2 * index, 2 * index + 1}) {
7129       Type argType = block.getArgument(i).getType();
7130       if (argType != tensorType)
7131         return emitOpError("comparator block argument #")
7132                << i << " should be of type " << tensorType << " but got "
7133                << argType;
7134     }
7135   }
7136 
7137   // The mapped computation must return a single output.
7138   auto comparatorResult = block.getTerminator()->getOperands();
7139   if (comparatorResult.size() != 1)
7140     return emitOpError() << "comparator must return single output, but got: "
7141                          << comparatorResult.size();
7142 
7143   // The output of the computation must be a rank-0 tensor with element type i1.
7144   auto comparatorResultType =
7145       comparatorResult[0].getType().dyn_cast<RankedTensorType>();
7146   if (!comparatorResultType || comparatorResultType.getRank() != 0 ||
7147       !comparatorResultType.getElementType().isInteger(1))
7148     return emitOpError() << "comparator must return tensor<i1>, but got: "
7149                          << comparatorResult[0].getType();
7150 
7151   // Check the number of results and their element types.
7152   auto resultTypes = getResultTypes();
7153   if (resultTypes.size() != numOperands)
7154     return emitOpError() << "expects the number of results to be same as "
7155                             "number of operands. Got number of results = "
7156                          << resultTypes.size()
7157                          << " and number of operands = " << numOperands;
7158 
7159   for (auto it : llvm::zip(operands, getResultTypes()))
7160     if (std::get<0>(it).getType().cast<TensorType>().getElementType() !=
7161         std::get<1>(it).cast<TensorType>().getElementType())
7162       return emitOpError()
7163              << "expects the operands and results to have pairwise equal "
7164                 "element-types, but got "
7165              << std::get<0>(it).getType().cast<TensorType>().getElementType()
7166              << " vs " << std::get<1>(it).cast<TensorType>().getElementType();
7167 
7168   return success();
7169 }
7170 
7171 /// Drops an operand when its result is unused and its corresponding arguments
7172 /// in op.comparator() are unused.
7173 static LogicalResult sortDropEmptyUseArgs(SortOp op,
7174                                           PatternRewriter& rewriter) {
7175   DenseSet<unsigned> erasedArgs;
7176   unsigned numOperands = op.getNumOperands();
7177   for (unsigned i = 0; i < numOperands; ++i) {
7178     if (!op.getResult(i).use_empty()) continue;
7179     Block& block = op.comparator().front();
7180     if (!block.getArgument(i * 2).use_empty()) continue;
7181     if (!block.getArgument(i * 2 + 1).use_empty()) continue;
7182     erasedArgs.insert(i);
7183   }
7184   if (erasedArgs.empty()) return failure();
7185 
7186   SmallVector<Value> newOperands;
7187   SmallVector<unsigned> erasedBlockArgs;
7188   for (const auto& en : llvm::enumerate(op.operands())) {
7189     if (erasedArgs.contains(en.index())) {
7190       erasedBlockArgs.push_back(en.index() * 2);
7191       erasedBlockArgs.push_back(en.index() * 2 + 1);
7192     } else {
7193       newOperands.push_back(en.value());
7194     }
7195   }
7196 
7197   auto newOp = rewriter.create<SortOp>(op.getLoc(), newOperands, op.dimension(),
7198                                        op.is_stable());
7199   Region& region = newOp.comparator();
7200   rewriter.inlineRegionBefore(op.comparator(), region, region.end());
7201   region.front().eraseArguments(erasedBlockArgs);
7202 
7203   SmallVector<Value> results;
7204   for (unsigned i = 0, j = 0; i < numOperands; ++i) {
7205     if (erasedArgs.contains(i)) {
7206       results.push_back({});
7207     } else {
7208       results.push_back(newOp.getResult(j++));
7209     }
7210   }
7211   rewriter.replaceOp(op, results);
7212 
7213   return success();
7214 }
7215 
7216 /// Set the sorting dimension to the last dimension if it's not set and the rank
7217 /// is known.
7218 static LogicalResult sortOpInferDefaultDimension(SortOp op,
7219                                                  PatternRewriter& rewriter) {
7220   auto ty = op.getResultTypes()[0].dyn_cast<ShapedType>();
7221   if (!ty) {
7222     return failure();
7223   }
7224   if (static_cast<int64_t>(op.dimension()) != -1) {
7225     return failure();
7226   }
7227 
7228   IntegerAttr dim = rewriter.getI64IntegerAttr(ty.getRank() - 1);
7229   auto newOp = rewriter.create<SortOp>(op.getLoc(), op.getResultTypes(),
7230                                        op.operands(), dim, op.is_stableAttr());
7231   Region& region = newOp.comparator();
7232   rewriter.inlineRegionBefore(op.comparator(), region, region.end());
7233   rewriter.replaceOp(op, newOp.getResults());
7234 
7235   return success();
7236 }
7237 
7238 void SortOp::getCanonicalizationPatterns(RewritePatternSet& results,
7239                                          MLIRContext* /*context*/) {
7240   results.add(sortDropEmptyUseArgs);
7241   results.add(sortOpInferDefaultDimension);
7242 }
7243 
7244 //===----------------------------------------------------------------------===//
7245 // TransposeOp
7246 //===----------------------------------------------------------------------===//
7247 
7248 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
7249   if (auto elements = operands.front().dyn_cast_or_null<SplatElementsAttr>()) {
7250     return reshape(elements, getResult().getType().cast<ShapedType>());
7251   }
7252   for (const auto& it : llvm::enumerate(permutation().getValues<APInt>())) {
7253     if (it.index() != it.value()) {
7254       return {};
7255     }
7256   }
7257   return getOperand();
7258 }
7259 
7260 // transpose(transpose(X)) => transpose(X)
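// As an illustrative example (permutations are hypothetical): an inner
// transpose with permutation [1, 0] followed by an outer transpose with
// permutation [1, 0] composes to newPermutation = [0, 1], the identity, which
// TransposeOp::fold can then eliminate.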
7261 static LogicalResult eliminateRedundantTranspose(TransposeOp op,
7262                                                 PatternRewriter& rewriter) {
7263   auto tranposeOperand = op.operand().getDefiningOp<TransposeOp>();
7264   if (!tranposeOperand) {
7265     return failure();
7266   }
7267   auto operandPermutation = tranposeOperand.permutation().getValues<APInt>();
7268   auto newPermutation =
7269       op.permutation()
7270           .mapValues(op.permutation().getElementType(),
7271                      [&operandPermutation](const APInt& index) -> APInt {
7272                        return operandPermutation[index.getSExtValue()];
7273                      })
7274           .cast<DenseIntElementsAttr>();
7275   rewriter.replaceOpWithNewOp<TransposeOp>(
7276       op, op.getResult().getType(), tranposeOperand.operand(), newPermutation);
7277   return success();
7278 }
7279 
// transpose(broadcast_in_dim(X)) => broadcast_in_dim(X)
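// For example (illustrative IR; %x and the shapes are hypothetical), each
// broadcast dimension is remapped through the permutation:
//   %0 = "mhlo.broadcast_in_dim"(%x)
//       {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
//       : (tensor<4xf32>) -> tensor<3x4xf32>
//   %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>}
//   ==>
//   %1 = "mhlo.broadcast_in_dim"(%x)
//       {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
//       : (tensor<4xf32>) -> tensor<4x3xf32>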
static LogicalResult eliminateBroadcastInDimTranspose(
    TransposeOp op, PatternRewriter& rewriter) {
  auto broadcastInDimOp = op.operand().getDefiningOp<BroadcastInDimOp>();
  if (!broadcastInDimOp) {
    return failure();
  }
  DenseIntElementsAttr broadcastDimensions =
      broadcastInDimOp.broadcast_dimensions();
  DenseIntElementsAttr permutation = op.permutation();
  SmallVector<int64_t> newBroadcastDimensions;
  for (auto dimension : broadcastDimensions.getValues<int64_t>()) {
    int64_t index = 0;
    for (auto p : permutation.getValues<int64_t>()) {
      if (p == dimension) {
        newBroadcastDimensions.push_back(index);
        break;
      }
      index++;
    }
  }
  rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
      op, op->getResultTypes(), broadcastInDimOp.operand(),
      rewriter.getI64TensorAttr(newBroadcastDimensions));
  return success();
}

// Simplify Transpose: replace Transpose with Reshape if they are equivalent.
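// For example (illustrative IR; %x and the shapes are hypothetical):
//   %0 = "mhlo.transpose"(%x) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
//       : (tensor<1x4x5xf32>) -> tensor<4x1x5xf32>
//   ==>
//   %0 = "mhlo.reshape"(%x) : (tensor<1x4x5xf32>) -> tensor<4x1x5xf32>
// Ignoring unit dimensions, the remaining permutation entries are sorted, so
// the transpose moves data exactly the way a reshape would.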
static LogicalResult simplifyTranspose(TransposeOp op,
                                       PatternRewriter& rewriter) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!operandType || !resultType) {
    return failure();
  }
  // Dynamic shapes are not supported at the moment; with a dynamic shape,
  // Transpose would likely have to be replaced by DynamicReshape instead.
  if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) {
    return failure();
  }
  auto permutation = op.permutation().getValues<int64_t>();
  llvm::SmallVector<int64_t> sortedPermutation;
  for (int64_t i = 0, e = resultType.getRank(); i < e; i++) {
    if (resultType.getDimSize(i) != 1) {
      sortedPermutation.push_back(permutation[i]);
    }
  }
  if (llvm::is_sorted(sortedPermutation)) {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
    return success();
  }
  return failure();
}

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                              MLIRContext* /*context*/) {
  results.add(eliminateRedundantTranspose);
  results.add(eliminateBroadcastInDimTranspose);
  results.add(simplifyTranspose);
}

LogicalResult TransposeOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  TransposeOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();

  auto operandType = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked types are not supported at the moment.
  if (!operandType) return failure();

  Location loc = this->getLoc();
  SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
  SmallVector<Value, 4> shapeValues(permutation.size());

  Type shapeScalarType = builder.getIndexType();
  auto toShapeScalarType = [&](Value v) {
    return maybeCastTo(builder, loc, v, shapeScalarType);
  };

  for (const auto& element : llvm::enumerate(operandType.getShape())) {
    int64_t idx = element.index();
    auto* it = std::find(permutation.begin(), permutation.end(), idx);
    Value valueDim = toShapeScalarType(
        builder.createOrFold<tensor::DimOp>(loc, operand, element.index()));
    shapeValues[std::distance(permutation.begin(), it)] = valueDim;
  }

  Value outputShape = builder.create<tensor::FromElementsOp>(
      loc,
      RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
                            shapeScalarType),
      shapeValues);
  reifiedReturnShapes.push_back(outputShape);

  return success();
}

// Method for InferTypeOpInterface: infer the return type from the operand type
// and the permutation.
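// For example (hypothetical shapes): an operand of type tensor<2x3x4xf32>
// with permutation [2, 0, 1] yields the result type tensor<4x2x3xf32>.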
LogicalResult TransposeOp::inferReturnTypes(
    MLIRContext* /*context*/, Optional<Location> loc, ValueRange operands,
    DictionaryAttr attributes, RegionRange,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  auto type = operands[0].getType();
  auto rankedTy = type.dyn_cast<RankedTensorType>();
  if (!rankedTy) {
    auto shapedTy = type.dyn_cast<ShapedType>();
    inferredReturnTypes.emplace_back(shapedTy);
    return success();
  }
  auto permutation = attributes.getAs<DenseIntElementsAttr>("permutation");
  int64_t rank = rankedTy.getRank();
  if (permutation.getType().getRank() != 1)
    return emitOptionalError(loc, "TransposeOp permutation has rank ",
                             permutation.getType().getRank(),
                             " instead of rank 1");

  if (permutation.size() != rank)
    return emitOptionalError(loc, "TransposeOp operand rank ", rank,
                             " does not match permutation size ",
                             permutation.size());

  std::vector<int64_t> range(rank);
  std::iota(range.begin(), range.end(), 0);
  if (!std::is_permutation(range.begin(), range.end(), permutation.begin()))
    return emitOptionalError(loc,
                             "attribute permutation must be a permutation"
                             " of [",
                             range, "] but got ", permutation);

  SmallVector<int64_t> resultShape;
  ArrayRef<int64_t> inputShape = rankedTy.getShape();
  for (int64_t dim : permutation.getValues<int64_t>()) {
    resultShape.push_back(inputShape[dim]);
  }
  inferredReturnTypes.emplace_back(RankedTensorType::get(
      resultShape, rankedTy.getElementType(), rankedTy.getEncoding()));
  return success();
}

//===----------------------------------------------------------------------===//
// TriangularSolveOp
//===----------------------------------------------------------------------===//

LogicalResult TriangularSolveOp::verify() {
  auto aType = a().getType().dyn_cast<RankedTensorType>();

  // Skip the verifier if 'a' is an unranked tensor.
  if (!aType) return success();

  // Check that 'a' has rank >= 2.
  auto aRank = aType.getRank();
  if (aRank < 2)
    return emitOpError() << "operand 'a' must have rank >= 2, but got "
                         << aType;

  // The two minor dimensions of 'a' must have the same size.
  if (aType.getDimSize(aRank - 2) != aType.getDimSize(aRank - 1))
    return emitOpError() << "two minor dimensions of operand 'a' must have "
                            "equal size, but got "
                         << aType;

  auto bType = b().getType().dyn_cast<RankedTensorType>();
  // If 'b' is unranked, skip the remaining checks.
  if (!bType) return success();

  // Check that 'a' and 'b' have the same rank.
  auto bRank = bType.getRank();
  if (aRank != bRank)
    return emitOpError() << "operands must have equal rank, but got " << aType
                         << " and " << bType;

  // The shared dimension of 'a' and 'b' must match.
  if (aType.getDimSize(aRank - 1) !=
      bType.getDimSize(bRank - (left_side() ? 2 : 1)))
    return emitOpError() << "shared dimension of operands 'a' and 'b' does "
                            "not match, but got "
                         << aType << " and " << bType;

  // The leading batch dimensions of 'a' and 'b' must be equal.
  auto aBatchDims = aType.getShape().drop_back(2);
  auto bBatchDims = bType.getShape().drop_back(2);
  if (aBatchDims != bBatchDims)
    return emitOpError()
           << "leading batch dimensions of the operands must be same, but got "
           << aType << " and " << bType;

  // The result and operand 'b' must have the same shape.
  auto resultType = getType().dyn_cast<RankedTensorType>();
  if (!resultType) return success();
  if (resultType != bType)
    return emitOpError()
           << "result and operand 'b' must have same shape, but got "
           << resultType << " and " << bType;
  return success();
}

//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//

LogicalResult GetTupleElementOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands,
    DictionaryAttr attributes, RegionRange,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  auto tupleType = operands[0].getType().dyn_cast<TupleType>();
  if (!tupleType) return failure();

  auto indexAttr = attributes.get("index").cast<IntegerAttr>();
  auto index = indexAttr.getInt();
  if (index < 0 || index >= static_cast<int64_t>(tupleType.size()))
    return failure();

  inferredReturnTypes.push_back(tupleType.getType(index));
  return success();
}

//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//

LogicalResult TupleOp::inferReturnTypes(
    MLIRContext* context, Optional<Location>, ValueRange operands,
    DictionaryAttr attributes, RegionRange,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands)));
  return success();
}

//===----------------------------------------------------------------------===//
// UnaryEinsumOp
//===----------------------------------------------------------------------===//

void UnaryEinsumOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                                MLIRContext* context) {
  results.add<UnaryEinsumToEinsum>(context);
}

//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//

void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
                      Value rhs, ComparisonDirection comparisonDirection,
                      ComparisonType compareType) {
  build(builder, result, lhs, rhs,
        ComparisonDirectionAttr::get(builder.getContext(), comparisonDirection),
        ComparisonTypeAttr::get(builder.getContext(), compareType));
}

LogicalResult CompareOp::inferReturnTypeComponents(
    mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>,
    ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) {
  ShapedTypeComponents& components =
      inferredReturnTypes.emplace_back(IntegerType::get(ctx, /*width=*/1));
  auto argTy = operands.front().getType().cast<TensorType>();
  if (argTy.hasRank()) {
    components =
        ShapedTypeComponents(argTy.getShape(), components.getElementType());
  }
  return success();
}

LogicalResult CompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return deriveShapeFromOperand(&builder, getOperation(), operands.front(),
                                &reifiedReturnShapes);
}

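// Comparator functors used by CompareFolder below. The APInt specializations
// are spelled out because APInt does not overload the relational operators;
// signed comparisons are used here.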
template <typename T>
struct Less : std::less<T> {};

template <>
struct Less<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
};

template <typename T>
struct LessEqual : std::less_equal<T> {};

template <>
struct LessEqual<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
};

template <typename T>
struct Greater : std::greater<T> {};

template <>
struct Greater<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
};

template <typename T>
struct GreaterEqual : std::greater_equal<T> {};

template <>
struct GreaterEqual<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
};

template <typename Op, typename ElementType, typename SrcType, typename Convert>
static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
  if (!attrs[0] || !attrs[1]) return {};

  DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  ShapedType operandType =
      op.getOperand(0).getType().template cast<ShapedType>();
  if (!operandType.hasStaticShape()) {
    return {};
  }

  if (!operandType.getElementType().isa<ElementType>()) {
    return {};
  }

  // Prevent folding if the result is too large.
  if (lhs.getNumElements() > kFoldOpEltLimit) return {};

  SmallVector<bool, 6> values;
  values.reserve(lhs.getNumElements());
  for (const auto zip :
       llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
    values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
  }

  auto resultTy = op.getType().cast<ShapedType>();
  return DenseElementsAttr::get(resultTy, values);
}

OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
  auto resultTy = getType().cast<ShapedType>();
  if (!resultTy.hasStaticShape()) return {};

  auto direction = comparison_direction();
  auto lhsTy = getElementTypeOrSelf(lhs());
  if (lhs() == rhs() && !lhsTy.isa<FloatType>() &&
      (!lhsTy.isa<ComplexType>() ||
       !lhsTy.cast<ComplexType>().getElementType().isa<FloatType>())) {
    if (direction == ComparisonDirection::LE ||
        direction == ComparisonDirection::EQ ||
        direction == ComparisonDirection::GE) {
      return DenseIntElementsAttr::get(resultTy, {true});
    }
    return DenseIntElementsAttr::get(resultTy, {false});
  }

  auto opElType = lhs().getType().cast<ShapedType>().getElementType();
  // Fold tensor<*xi1> != false to just return tensor<*xi1>.
  if (direction == ComparisonDirection::NE && opElType.isInteger(1)) {
    DenseIntElementsAttr cstAttr;
    if (matchPattern(lhs(), m_Constant(&cstAttr))) {
      if (cstAttr.isSplat() && !cstAttr.getSplatValue<bool>()) {
        return rhs();
      }
    }

    if (matchPattern(rhs(), m_Constant(&cstAttr))) {
      if (cstAttr.isSplat() && !cstAttr.getSplatValue<bool>()) {
        return lhs();
      }
    }
  }

  // Fold tensor<*xi1> == true to just return tensor<*xi1>.
  if (direction == ComparisonDirection::EQ && opElType.isInteger(1)) {
    DenseIntElementsAttr cstAttr;
    if (matchPattern(lhs(), m_Constant(&cstAttr))) {
      if (cstAttr.isSplat() && cstAttr.getSplatValue<bool>()) {
        return rhs();
      }
    }

    if (matchPattern(rhs(), m_Constant(&cstAttr))) {
      if (cstAttr.isSplat() && cstAttr.getSplatValue<bool>()) {
        return lhs();
      }
    }
  }

  if (!operands[0] || !operands[1]) {
    return {};
  }

#define COMPARE_FOLDER(Op, comparison, Func)                                \
  if (direction == comparison) {                                            \
    if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
            *this, operands))                                               \
      return folded;                                                        \
    if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>(   \
            *this, operands))                                               \
      return folded;                                                        \
  }

  COMPARE_FOLDER(CompareOp, ComparisonDirection::EQ, std::equal_to);
  COMPARE_FOLDER(CompareOp, ComparisonDirection::NE, std::not_equal_to);
  COMPARE_FOLDER(CompareOp, ComparisonDirection::LT, Less);
  COMPARE_FOLDER(CompareOp, ComparisonDirection::LE, LessEqual);
  COMPARE_FOLDER(CompareOp, ComparisonDirection::GT, Greater);
  COMPARE_FOLDER(CompareOp, ComparisonDirection::GE, GreaterEqual);
#undef COMPARE_FOLDER

  return {};
}

//===----------------------------------------------------------------------===//
// SelectAndScatterOp
//===----------------------------------------------------------------------===//

namespace {
// Infer the return-type of SelectAndScatterOp.
TensorType inferSelectAndScatterOpReturnType(
    TensorType operandType, const ArrayRef<WindowDimension> window) {
  if (!operandType.hasRank())
    return UnrankedTensorType::get(operandType.getElementType());

  return RankedTensorType::get(
      inferWindowOutputShape(operandType.getShape(), window),
      operandType.getElementType());
}
}  // namespace

//  We intend to verify the following properties:
//   P1. Check if the select function has a proper shape of (T,T) -> PRED, where
//        T is a 0-D tensor with element-type same as 'operand' element-type.
//   P2. Verify the scatter-computation type.
//   P3. size-of(window_dimension) == rank-of(input),
//         where input is an element of 'inputs'.
//   P4. Verify and collect the window attributes.
//   P5. Verify the return type matches the operand-type.
//   P6. Check if the result type of window operation matches the source type.
LogicalResult SelectAndScatterOp::verify() {
  auto operandType = operand().getType().cast<TensorType>();
  auto initValueType = init_value().getType().cast<TensorType>();
  auto sourceType = source().getType().cast<TensorType>();
  auto resultType = getResult().getType().cast<TensorType>();

  // P1.
  Block& selectBlock = select().front();

  if (selectBlock.getArguments().size() != 2)
    return emitOpError()
           << "expects the select-region to take 2 parameters, but takes "
           << selectBlock.getArguments().size();

  Type expectedSelectArgType =
      RankedTensorType::get({}, operandType.getElementType());
  for (const auto& selectArgIt : llvm::enumerate(selectBlock.getArguments()))
    if (!compatibleShapeAndElementType(expectedSelectArgType,
                                       selectArgIt.value().getType(),
                                       /*ignoreFpPrecision=*/true))
      return emitOpError()
             << "expects the type of select-region's parameter at index "
             << selectArgIt.index() << " to be " << expectedSelectArgType
             << ", but got " << selectArgIt.value().getType();

  auto selectResult = selectBlock.getTerminator()->getOperands();
  if (selectResult.size() != 1)
    return emitOpError()
           << "expects select-region to return single value, but got: "
           << selectResult.size();

  auto selectResultType = selectResult[0].getType().dyn_cast<TensorType>();
  if (!selectResultType || !selectResultType.getElementType().isInteger(1) ||
      (selectResultType.hasRank() &&
       selectResultType.cast<RankedTensorType>().getRank() != 0))
    return emitOpError() << "expects the return-type of select-region to be "
                            "tensor<i1>, but got: "
                         << selectResult[0].getType();

  // P2.
  Block& scatterBlock = scatter().front();
  SmallVector<TensorType> accumulatorSubshapes;
  if (failed(verifyReducerShape(
          this->getLoc(), scatterBlock,
          {RankedTensorType::get({}, sourceType.getElementType())},
          {initValueType},
          /*numInputs=*/1, /*allowedDimensions=*/{},
          /*allInputsUnranked=*/false, accumulatorSubshapes)))
    return failure();

  // P3.
  SmallVector<int64_t> windowDims =
      convertDenseIntAttr(this->window_dimensions());
  if (operandType.hasRank()) {
    if (operandType.getRank() != static_cast<int64_t>(windowDims.size()))
      return emitOpError()
             << "expects window-dimensions size == operand rank, but got "
                "window-dimensions size: "
             << windowDims.size() << " and operand-type: " << operandType
             << " with rank = " << operandType.getRank() << ".";
  }

  // P4.
  auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
  if (failed(paddingOrErr)) return failure();
  SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;

  auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
      windowDims, convertDenseIntAttr(window_strides()), padding,
      /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, getLoc());
  if (failed(windowOrErr)) return failure();

  // P5.
  if (!compatibleShapeAndElementType(operandType, resultType))
    return emitOpError()
           << "expects the return-type to match the operand-type, but got "
           << resultType << " and " << operandType << " resp.";

  // P6.
  auto windowResultType =
      inferSelectAndScatterOpReturnType(operandType, *windowOrErr);

  if (!compatibleShapeAndElementType(windowResultType, sourceType,
                                     /*ignoreFpPrecision=*/true))
    return emitOpError() << "expects source-type to be " << windowResultType
                         << ", but got " << sourceType;

  return success();
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

/*
 * We intend to verify the following properties:
 * P1. The 'update_window_dims' must be valid indices of 'updates' tensor.
 * P2. The 'inserted_window_dims' must be valid indices of 'operand' tensor.
 * P3. Check if the rank-of('operand') == size-of('update_window_dims') +
 *     size-of('inserted_window_dims')
 * P4. size-of('scatter_dims_to_operand_dims') =
 *         'scatter_indices'['index_vector_dim'] &
 *     'scatter_dims_to_operand_dims' must be valid indices of 'operand' tensor.
 */
LogicalResult validateScatterDimensionNumbers(
    ShapedType operandType, ArrayRef<int64_t> scatterIndicesShape,
    ShapedType updateType, bool operandTypeRanked,
    bool scatterIndicesTypeRanked, bool updatesTypeRanked,
    ScatterDimensionNumbersAttr dimNumbers, Location loc) {
  // P1.
  auto updateWindowDims = to_vector(dimNumbers.getUpdateWindowDims());
  if (!llvm::is_sorted(updateWindowDims))
    return mlir::emitError(loc)
           << "Expects update_window_dims to be sorted; got: ["
           << updateWindowDims << "].";

  if (hasDuplicates(updateWindowDims))
    return mlir::emitError(loc)
           << "Expects update_window_dims to not repeat; got: ["
           << updateWindowDims << "].";

  if (updatesTypeRanked) {
    for (int64_t windowDim : updateWindowDims) {
      if (windowDim < 0 || windowDim >= updateType.getRank()) {
        return mlir::emitError(loc)
               << "Expects each element of update_window_dims to be in range "
                  "[0, rank-of('updates')) i.e. [0, "
               << updateType.getRank() << "), got: " << windowDim << ".";
      }
    }
  }

  // P2.
  auto insertedWindowDims = to_vector(dimNumbers.getInsertedWindowDims());
  if (!llvm::is_sorted(insertedWindowDims))
    return mlir::emitError(loc)
           << "Expects inserted_window_dims to be sorted; got: ["
           << insertedWindowDims << "].";

  if (hasDuplicates(insertedWindowDims))
    return mlir::emitError(loc)
           << "Expects inserted_window_dims to not repeat; got: ["
           << insertedWindowDims << "].";

  if (operandTypeRanked) {
    for (int64_t insertedDim : insertedWindowDims) {
      if (insertedDim < 0 || insertedDim >= operandType.getRank()) {
        return mlir::emitError(loc)
               << "Expects each element of inserted_window_dims to be in range "
                  "[0, rank-of('operand')) i.e. [0, "
               << operandType.getRank() << "), got: " << insertedDim << ".";
      }
    }
  }

  // P3.
  if (operandTypeRanked) {
    auto windowSize = updateWindowDims.size() + insertedWindowDims.size();
    if (operandType.getRank() != static_cast<int64_t>(windowSize))
      return mlir::emitError(loc)
             << "Expects rank-of operand to match "
                "size-of('update_window_dims') + "
                "size-of('inserted_window_dims') i.e. "
             << windowSize << " but got " << operandType.getRank() << ".";
  }

  // P4.
  auto scatterDimsToOperandDims =
      to_vector(dimNumbers.getScatterDimsToOperandDims());
  auto indexVectorDim = dimNumbers.getIndexVectorDim();
  if (scatterIndicesTypeRanked) {
    if (!isDynamicDimSize(scatterIndicesShape[indexVectorDim]) &&
        static_cast<int64_t>(scatterDimsToOperandDims.size()) !=
            scatterIndicesShape[dimNumbers.getIndexVectorDim()])
      return mlir::emitError(loc)
             << "Scatter op has " << scatterDimsToOperandDims.size()
             << " elements in scatter_dims_to_operand_dims and the bound of "
                "dimension index_vector_dim="
             << dimNumbers.getIndexVectorDim() << " of scatter_indices is "
             << scatterIndicesShape[dimNumbers.getIndexVectorDim()]
             << ". These two numbers must be equal.";
  }

  if (operandTypeRanked) {
    for (int64_t i = 0;
         i < static_cast<int64_t>(scatterDimsToOperandDims.size()); ++i) {
      int64_t scatterDimToOperandDim = scatterDimsToOperandDims[i];
      if (scatterDimToOperandDim < 0 ||
          scatterDimToOperandDim >= operandType.getRank())
        return mlir::emitError(loc)
               << "Invalid scatter_dims_to_operand_dims mapping; domain is [0, "
               << operandType.getRank() << "), got: " << i << "->"
               << scatterDimToOperandDim << ".";
    }
  }

  if (hasDuplicates(scatterDimsToOperandDims))
    return mlir::emitError(loc)
           << "Expects scatter_dims_to_operand_dims to not repeat; got: ["
           << scatterDimsToOperandDims << "].";

  return success();
}
/*
 * We intend to verify the following properties:
 *  P0. scatter_indices argument must be an integral tensor. Enforced by ODS.
 *  P1. Scatter index leaf dimension must be within
 *      [0, rank(scatter_indices) + 1).
 *  P2. Verify reducer shape.
 *  P3. rank-of('updates[i]') == size-of('update_window_dims') +
 *      rank-of('scatter_indices') - 1, where 'scatter_indices' is expanded by a
 *      trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices')
 *      for all values of `i`.
 *  P4. Validate the scatter-dimensions-numbers.
 *  P5. Validate the bounds of each of the 'updates' w.r.t the operands.
 *  P6. Validate the bounds of each of the 'updates' w.r.t the
 *      'scatter_indices'.
 *  P7. Check return types.
 */
LogicalResult ScatterOp::verify() {
  // Get the first operand and update, since variadic Scatter is not yet
  // implemented.
  auto numOperands = operands().size();
  auto scatterIndicesType = scatter_indices().getType().dyn_cast<TensorType>();

  SmallVector<TensorType, 1> operandTypes =
      llvm::to_vector(llvm::map_range(operands().getTypes(), [](Type type) {
        return type.cast<TensorType>();
      }));
  SmallVector<TensorType, 1> updatesTypes = llvm::to_vector(llvm::map_range(
      updates().getTypes(), [](Type type) { return type.cast<TensorType>(); }));
  bool allOperandTypesRanked =
      llvm::all_of(operands().getTypes(),
                   [](Type type) { return type.isa<RankedTensorType>(); });
  bool scatterIndicesTypeRanked = scatterIndicesType.isa<RankedTensorType>();

  // P1.
  int64_t indexVectorDim = scatter_dimension_numbers().getIndexVectorDim();
  if (scatterIndicesTypeRanked) {
    if (indexVectorDim > scatterIndicesType.getRank() || indexVectorDim < 0)
      return emitOpError()
             << "expects scatter index leaf dimension to be within [0, "
                "rank(scatter_indices) + 1)."
                " rank(scatter_indices) is "
             << scatterIndicesType.getRank()
             << " and scatter index leaf dimension is " << indexVectorDim
             << ".";
  }

  // P2.
  Block& block = update_computation().front();
  SmallVector<TensorType> accumulatorSubshapes;
  SmallVector<TensorType> inputTypes, initValueTypes;
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    inputTypes.push_back(operandTypes[i]);
    initValueTypes.push_back(
        RankedTensorType::get({}, updatesTypes[i].getElementType()));
  }
  if (failed(verifyReducerShape(
          this->getLoc(), block, inputTypes, initValueTypes, numOperands,
          /*allowedDimensions=*/{},
          /*allInputsUnranked=*/!allOperandTypesRanked, accumulatorSubshapes)))
    return failure();

  // P3.
  auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims();
  SmallVector<int64_t> expandedScatterIndicesShape;
  if (scatterIndicesTypeRanked) {
    expandedScatterIndicesShape =
        llvm::to_vector(scatterIndicesType.getShape());
    if (static_cast<int64_t>(expandedScatterIndicesShape.size()) ==
        indexVectorDim)
      expandedScatterIndicesShape.push_back(1);
  }

  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (scatterIndicesTypeRanked && updatesTypes[i].isa<RankedTensorType>()) {
      int64_t expectedUpdatesRank =
          expandedScatterIndicesShape.size() - 1 + updateWindowDims.size();
      if (updatesTypes[i].getRank() != expectedUpdatesRank)
        return emitOpError()
               << "expects updates tensor to be of rank "
               << expectedUpdatesRank
               << " ( == rank-of('scatter_indices') - 1 + "
                  "size-of('update_window_dims'), where 'scatter_indices' is "
                  "expanded by a trailing 1 dimension if 'index_vector_dim' == "
                  "rank-of('scatter_indices')), but got "
               << updatesTypes[i].getRank() << ".";
    }
  }

  // P4.
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (failed(validateScatterDimensionNumbers(
            operandTypes[i], expandedScatterIndicesShape, updatesTypes[i],
            operandTypes[i].isa<RankedTensorType>(), scatterIndicesTypeRanked,
            updatesTypes[i].isa<RankedTensorType>(),
            scatter_dimension_numbers(), getLoc())))
      return failure();
  }

  // P5.
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (updatesTypes[i].isa<RankedTensorType>()) {
      auto updatesShape = updatesTypes[i].getShape();
      if (operandTypes[i].isa<RankedTensorType>()) {
        auto operandShape = operandTypes[i].getShape();
        auto insertedWindowDims =
            scatter_dimension_numbers().getInsertedWindowDims();

        int64_t insertedDimsSeen = 0;
        SmallVector<int64_t> maxUpdateSliceSizes;
        const auto dimensionsSize = operandTypes[i].getRank();
        maxUpdateSliceSizes.reserve(dimensionsSize);
        for (int i = 0; i < dimensionsSize; ++i) {
          if (insertedDimsSeen <
                  static_cast<int64_t>(insertedWindowDims.size()) &&
              insertedWindowDims[insertedDimsSeen] == i) {
            ++insertedDimsSeen;
          } else {
            maxUpdateSliceSizes.push_back(operandShape[i]);
          }
        }

        for (int64_t i = 0; i < static_cast<int64_t>(updateWindowDims.size());
             ++i) {
          auto updateWindowDim = updateWindowDims[i];

          if (isDynamicDimSize(updatesShape[updateWindowDim]) ||
              isDynamicDimSize(maxUpdateSliceSizes[i]))
            continue;

          if (updatesShape[updateWindowDim] > maxUpdateSliceSizes[i]) {
            return emitOpError()
                   << "expects bounds of the window dimensions of "
                      "updates to not exceed the "
                      "bounds of the corresponding dimensions of "
                      "operand. For dimension "
                   << updateWindowDim << ", updates bound is "
                   << updatesShape[updateWindowDim] << ", operand bound is "
                   << maxUpdateSliceSizes[i] << ".";
          }
        }
      }

      // P6.
      if (scatterIndicesTypeRanked) {
        int64_t scatterDimsSeen = 0;
        for (int64_t i = 0; i < static_cast<int64_t>(updatesShape.size());
             ++i) {
          bool isUpdateWindowDim = std::binary_search(
              updateWindowDims.begin(), updateWindowDims.end(), i);

          if (isUpdateWindowDim) continue;
          if (scatterDimsSeen == indexVectorDim) ++scatterDimsSeen;

          if (!isDynamicDimSize(updatesShape[i]) &&
              !isDynamicDimSize(expandedScatterIndicesShape[scatterDimsSeen]) &&
              (updatesShape[i] !=
               expandedScatterIndicesShape[scatterDimsSeen])) {
            return emitOpError()
                   << "expects bounds of the scatter dimensions of "
                      "updates to be same as the "
                      "bounds of the corresponding dimensions of "
                      "scatter indices. For "
                      "scatter dimension "
                   << i << ", updates bound is " << updatesShape[i]
                   << ", scatter_indices "
                      "bound is "
                   << expandedScatterIndicesShape[scatterDimsSeen] << ".";
          }
          ++scatterDimsSeen;
        }
      }
    }
  }

  // P7.
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (!compatibleShapeAndElementType(operandTypes[i], getResult(i).getType()))
      return emitOpError()
             << "expects the return type to be same as the operand type: "
             << operandTypes[i] << ", but got " << getResult(i).getType()
             << ".";
  }

  return success();
}

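// Evaluates an mhlo region on the given constant `inputs` by folding each
// operation in turn; returns the operands of the terminating ReturnOp as
// attributes, or an empty vector if any operation fails to fold.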
llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
                                                   ArrayRef<Attribute> inputs) {
  if (region.getNumArguments() != inputs.size()) return {};

  llvm::DenseMap<Value, Attribute> values;
  values.reserve(region.getNumArguments());
  for (auto it : llvm::zip(region.getArguments(), inputs)) {
    values.try_emplace(std::get<0>(it), std::get<1>(it));
  }

  for (auto& op : region.getOps()) {
    llvm::SmallVector<Attribute, 4> inputs;
    for (auto& operand : op.getOpOperands()) {
      inputs.push_back(values.lookup(operand.get()));
    }
    if (isa<ReturnOp>(op)) return inputs;

    llvm::SmallVector<OpFoldResult, 4> results;
    if (failed(op.fold(inputs, results))) return {};
    for (auto it : llvm::zip(op.getResults(), results)) {
      if (!std::get<1>(it).is<Attribute>()) return {};
      values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
    }
  }
  return {};
}

LogicalResult ScatterOp::fold(
    ArrayRef<Attribute> args,
    llvm::SmallVectorImpl<OpFoldResult>& foldResults) {
  // Variadic Scatter is not yet implemented.
  if (operands().size() != 1 || updates().size() != 1) return failure();
  auto index = args[1].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!index) return failure();

  auto baseType = operands().getTypes()[0].dyn_cast<RankedTensorType>();
  auto updateType = updates().getTypes()[0].dyn_cast<RankedTensorType>();
  auto indexType = index.getType().cast<RankedTensorType>();
  if (!baseType || !indexType || !updateType) return failure();

  // TODO(b/228310289): Work around canonicalization crash for complex types.
  // Remove after upstream MLIR has been fixed.
  if (baseType.getElementType().isa<ComplexType>()) return failure();

  // Catch a trivial full replacement of base with update; this does not
  // require them to be constant, just that we know the type.
  if (updateType == baseType && updateType.hasStaticShape() &&
      baseType.hasStaticShape() && index.isSplat() &&
      index.getSplatValue<uint32_t>() == 0 &&
      llvm::hasSingleElement(update_computation().front())) {
    foldResults.push_back(updates()[0]);
    return success();
  }
  auto base = args[0].dyn_cast_or_null<DenseElementsAttr>();
  auto update = args[2].dyn_cast_or_null<DenseElementsAttr>();
  if (!base || !update) return failure();

  // Add the virtual trailing dimension of size 1 if indexVectorDim equals
  // the rank of indexType.
  const int64_t indexVectorDim =
      scatter_dimension_numbers().getIndexVectorDim();
  if (indexVectorDim == indexType.getRank()) {
    auto indexShape = indexType.getShape().vec();
    indexShape.push_back(1);
    indexType = RankedTensorType::get(indexShape, indexType.getElementType());
    index = reshape(index, indexType).cast<DenseIntElementsAttr>();
  }

  // Increments the multi-dimensional index vector based on the limits of each
  // dimension specified by shape; returns false if the index rolled over, and
  // true otherwise.
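  // For example (hypothetical values): with shape [2, 3], the index [0, 2]
  // advances to [1, 0] and returns true, while [1, 2] rolls over to [0, 0]
  // and returns false.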
  auto nextIndex = [](llvm::SmallVector<uint64_t, 8>& index,
                      llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < static_cast<unsigned long>(shape[i])) return true;
      index[i] = 0;
    }
    return false;
  };

  // Prevent folding if the result is too large.
  if (base.getNumElements() > kFoldOpEltLimit) return failure();

  // Iterate over all elements of the update tensor, then find the
  // corresponding value in the indices tensor to determine which location we
  // have to update in the base/result tensor.
  llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
  llvm::SmallVector<uint64_t, 8> updateIndex(updateType.getRank(), 0);
  llvm::SmallVector<uint64_t, 8> indexIndex;
  indexIndex.reserve(indexType.getRank());
  llvm::SmallVector<int64_t, 8> baseIndex;
  baseIndex.reserve(baseType.getRank());
  do {
    // Compute the index for the slice of the indices tensor for this update
    // value.
    indexIndex.clear();
    if (indexVectorDim == 0) indexIndex.push_back(0);
    for (int64_t i = 0; i < static_cast<int64_t>(updateIndex.size()); ++i) {
      if (llvm::count(scatter_dimension_numbers().getUpdateWindowDims(), i) ==
          0)
        indexIndex.push_back(updateIndex[i]);
      if (static_cast<int64_t>(indexIndex.size()) == indexVectorDim)
        indexIndex.push_back(0);
    }

    // Compute the index for the given update value in the base tensor.
    baseIndex.assign(baseType.getRank(), 0);
    uint64_t indexCount = indexType.getShape()[indexVectorDim];
    for (uint64_t i = 0; i < indexCount; ++i) {
      uint64_t operandDim =
          scatter_dimension_numbers().getScatterDimsToOperandDims()[i];
      indexIndex[indexVectorDim] = i;
      baseIndex[operandDim] +=
          index.getValues<APInt>()[indexIndex].getSExtValue();
    }
    uint64_t updateWindowDimIndex = 0;
    auto insertedWindowDims =
        scatter_dimension_numbers().getInsertedWindowDims();
    auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims();
    for (uint64_t i = 0; i < baseIndex.size(); ++i) {
      if (llvm::count(insertedWindowDims, i)) continue;
      baseIndex[i] += updateIndex[updateWindowDims[updateWindowDimIndex]];
      updateWindowDimIndex++;
    }

    // Compute the linear index for the index into the base tensor.
    int64_t linearBaseIndex = 0;
    int64_t linearBaseIndexMultiplier = 1;
    for (int64_t i = baseIndex.size() - 1; i >= 0; --i) {
      // Out-of-bounds indices have backend-specific behaviour, so avoid
      // folding them.
      if (baseIndex[i] < 0 || baseIndex[i] >= baseType.getShape()[i])
        return failure();
      linearBaseIndex += baseIndex[i] * linearBaseIndexMultiplier;
      linearBaseIndexMultiplier *= baseType.getShape()[i];
    }

    // Evaluate the update computation and replace the value with the newly
    // computed attribute in the base tensor.
    auto lhs = DenseElementsAttr::get(
        RankedTensorType::get({}, baseType.getElementType()),
        results[linearBaseIndex]);
    auto rhs = DenseElementsAttr::get(
        RankedTensorType::get({}, baseType.getElementType()),
        update.getValues<Attribute>()[updateIndex]);
    auto newValue = evaluateMhloRegion(update_computation(), {lhs, rhs});
    if (newValue.size() != 1 || !newValue[0]) return failure();
    results[linearBaseIndex] =
        newValue[0].cast<DenseElementsAttr>().getValues<Attribute>()[0];
  } while (nextIndex(updateIndex, updateType.getShape()));

  foldResults.push_back(DenseElementsAttr::get(baseType, results));
  return success();
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

LogicalResult WhileOp::verify() {
  if (getNumOperands() != cond().front().getNumArguments())
    return emitOpError() << "mismatch in operand count (" << getNumOperands()
                         << ") vs the condition block argument count ("
                         << cond().front().getNumArguments() << ")";
  if (getNumOperands() != body().front().getNumArguments())
    return emitOpError() << "mismatch in operand count (" << getNumOperands()
                         << ") vs the body block argument count ("
                         << body().front().getNumArguments() << ")";
  for (const auto& enumeratedOperands : llvm::enumerate(
           llvm::zip(getOperandTypes(), cond().front().getArgumentTypes(),
                     body().front().getArgumentTypes()))) {
    int argCount = enumeratedOperands.index();
    const auto& operands = enumeratedOperands.value();
    Type operandType = std::get<0>(operands);
    Type condType = std::get<1>(operands);
    Type bodyType = std::get<2>(operands);
    if (operandType != condType)
      return emitOpError() << "type mismatch between operand #" << argCount
                           << " and the matching condition block argument: "
                           << operandType << " vs " << condType;
    if (operandType != bodyType)
      return emitOpError() << "type mismatch between operand #" << argCount
                           << " and the matching body block argument: "
                           << operandType << " vs " << bodyType;
  }
  // Check the return type for the condition block.
  {
    auto condReturnOp = cast<ReturnOp>(cond().front().back());
    if (condReturnOp->getNumOperands() != 1)
      return condReturnOp.emitOpError()
             << "expects a single operand for while condition body return, got "
             << condReturnOp->getNumOperands();
    auto operandType =
        condReturnOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
    if (!operandType || operandType.getRank() != 0 ||
        !operandType.getElementType().isInteger(1))
      return condReturnOp.emitOpError()
             << "expects a zero-ranked tensor of i1, got "
             << condReturnOp->getOperand(0).getType();
  }
  // Check the return type for the body block.
  {
    auto bodyReturnOp = cast<ReturnOp>(body().front().back());
    if (bodyReturnOp->getNumOperands() != getNumOperands())
      return bodyReturnOp.emitOpError()
             << "expects body to return as many values as the operands ("
             << getNumOperands() << "), got " << bodyReturnOp->getNumOperands();
    for (const auto& enumeratedOperandTypes : llvm::enumerate(
             llvm::zip(bodyReturnOp->getOperandTypes(), getOperandTypes()))) {
      Type operandType = std::get<0>(enumeratedOperandTypes.value());
      Type returnType = std::get<1>(enumeratedOperandTypes.value());
      if (operandType != returnType)
        return bodyReturnOp.emitOpError()
               << "type mismatch between operand #"
               << enumeratedOperandTypes.index()
               << " and the enclosing WhileOp returned value: " << operandType
               << " vs " << returnType;
    }
  }
  return success();
}

/// Print a `while` op.
///
/// op ::= `mhlo.while` `(` assignment-list `)` `:` types attribute-dict
///         `cond` region
///         `do` region
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
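///
/// For example (illustrative; exact whitespace comes from the printer, and
/// %iterArg/%init are hypothetical names):
///
///   mhlo.while(%iterArg = %init) : tensor<i32>
///    cond {
///      ...
///    } do {
///      ...
///    }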
void WhileOp::print(OpAsmPrinter& p) {
  p << '(';
  llvm::interleaveComma(llvm::zip(getBody()->getArguments(), getOperands()), p,
                        [&](auto zip) {
                          p.printOperand(std::get<0>(zip));
                          p << " = ";
                          p.printOperand(std::get<1>(zip));
                        });
  p << ")";
  if (getNumOperands()) {
    p << " : ";
    llvm::interleaveComma(getOperandTypes(), p);
  }
  p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
  p.printNewline();
  p << " cond ";
  p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false);
  p << " do ";
  p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false);
}

ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  // Parse the operands of the while: these are of the form:
  //   %iter_arg = %init_val
  // where %iter_arg is the name of the block argument in the cond/body blocks
  // and %init_val is the actual operand.
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SmallVector<OpAsmParser::UnresolvedOperand> iterArgs;
  if (parser.parseLParen()) return failure();
  do {
    if (succeeded(parser.parseOptionalRParen())) break;
    OpAsmParser::UnresolvedOperand operand, iterArg;
    if (parser.parseOperand(iterArg) || parser.parseEqual() ||
        parser.parseOperand(operand))
      return failure();
    iterArgs.push_back(iterArg);
    operands.push_back(operand);
    if (succeeded(parser.parseOptionalRParen())) break;
    if (failed(parser.parseComma())) return failure();
  } while (true);
  if (!operands.empty()) {
    if (parser.parseColon() || parser.parseTypeList(result.types))
      return failure();
  }

  SmallVector<OpAsmParser::Argument> args;
  createArgs(iterArgs, result.types, args);
  if (parser.resolveOperands(operands, result.types, loc, result.operands) ||
      parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
      parser.parseKeyword("cond") ||
      parser.parseRegion(*result.addRegion(), args) ||
      parser.parseKeyword("do") ||
      parser.parseRegion(*result.addRegion(), args))
    return failure();
  return success();
}

LogicalResult WhileOp::fold(ArrayRef<Attribute> /*operands*/,
                            SmallVectorImpl<OpFoldResult>& results) {
  DenseIntElementsAttr condValue;
  auto condReturnOp = cast<ReturnOp>(cond().front().back());
  if (!matchPattern(condReturnOp.getOperand(0), m_Constant(&condValue)))
    return failure();
  if (condValue.getSplatValue<BoolAttr>().getValue())
    return failure();  // TODO(mhlo): this is an infinite loop, should we fold?

  results.append(getOperands().begin(), getOperands().end());
  return success();
}

static LogicalResult whileCanonicalization(WhileOp whileOp,
                                           PatternRewriter& rewriter) {
  // Turn loop-invariant values into implicit captures.
  // Check that at least one value is forwarded from one iteration to the next
  // or that one of the yielded values is already an implicit capture;
  // otherwise there is nothing to do here.
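  // For example (hypothetical names): if the body's return yields the block
  // argument %arg1 unchanged, the corresponding loop-carried value is
  // invariant; its uses in both regions and the matching while result are
  // replaced by the while operand, and the argument is removed below.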
8413   Block* cond = whileOp.getBody(0);
8414   Block* body = whileOp.getBody(1);
8415   auto bodyReturnOp = cast<ReturnOp>(body->getTerminator());
8416   if (!llvm::any_of(llvm::zip(whileOp->getOperands(), body->getArguments(),
8417                               bodyReturnOp->getOperands()),
8418                     [&](auto zip) {
8419                       return (std::get<0>(zip) == std::get<2>(zip) ||
8420                               std::get<1>(zip) == std::get<2>(zip));
8421                     }))
8422     return rewriter.notifyMatchFailure(whileOp, "no loop invariant found");
8423 
8424   SmallVector<Value> newOperands, resultsToReplace;
8425   SmallVector<unsigned> invariantArgIdxs;
8426   for (const auto& enumeratedOperands : llvm::enumerate(llvm::zip(
8427            whileOp.getOperands(), cond->getArguments(), body->getArguments(),
8428            bodyReturnOp->getOperands(), whileOp->getResults()))) {
8429     const auto& operands = enumeratedOperands.value();
8430     Value whileOperand = std::get<0>(operands);
8431     BlockArgument condBlockArg = std::get<1>(operands);
8432     BlockArgument bodyBlockArg = std::get<2>(operands);
8433     Value bodyReturnOperand = std::get<3>(operands);
8434     Value whileResult = std::get<4>(operands);
8435 
8436     bool forwarded = (whileOperand == bodyReturnOperand ||
8437                       bodyBlockArg == bodyReturnOperand);
8438     if (forwarded) {
8439       invariantArgIdxs.push_back(enumeratedOperands.index());
8440       condBlockArg.replaceAllUsesWith(whileOperand);
8441       bodyBlockArg.replaceAllUsesWith(whileOperand);
8442       whileResult.replaceAllUsesWith(whileOperand);
8443       continue;
8444     }
8445     newOperands.push_back(whileOperand);
8446     resultsToReplace.push_back(whileResult);
8447   }
8448   cond->eraseArguments(invariantArgIdxs);
8449   body->eraseArguments(invariantArgIdxs);
8450   for (int idx : llvm::reverse(invariantArgIdxs))
8451     bodyReturnOp->eraseOperand(idx);
8452 
8453   WhileOp newWhileOp = rewriter.create<WhileOp>(
8454       whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands);
8455   newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0));
8456   newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1));
8457   for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults()))
8458     std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
8459   rewriter.eraseOp(whileOp);
8460   return success();
8461 }
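// Hedged before/after sketch of the canonicalization above (hand-written
// IR; names and asm details illustrative). If position 0 is forwarded
// unchanged, e.g.
//   %r:2 = mhlo.while(%a, %b) cond {...} do { ... mhlo.return %arg0, %v }
// then %a is loop-invariant: it is dropped from the operand, block
// argument, and result lists, and all of its uses inside the regions are
// rewritten to capture %a implicitly:
//   %r = mhlo.while(%b) cond {...} do { ... mhlo.return %v }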
8462 
8463 void WhileOp::getCanonicalizationPatterns(RewritePatternSet& results,
8464                                           MLIRContext* context) {
8465   results.add(&whileCanonicalization);
8466 }
8467 
8468 LogicalResult UniformDequantizeOp::inferReturnTypeComponents(
8469     MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
8470     DictionaryAttr attributes, RegionRange regions,
8471     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
8472   UniformDequantizeOp::Adaptor adaptor(operands, attributes, regions);
8473   auto operandType = (*operands.begin()).getType().cast<ShapedType>();
8474   // Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType;
8475   auto quantType = operandType.getElementType().cast<quant::QuantizedType>();
8476   auto shape = operandType.dyn_cast<ShapedType>().getShape();
8477   inferredReturnShapes.emplace_back(shape, quantType.getExpressedType());
8478   return success();
8479 }
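// Illustrative example (quant dialect syntax): an operand of type
//   tensor<4x!quant.uniform<i8:f32, 5.000000e-01>>
// infers the return shape components (4 x f32), i.e. the operand shape with
// the quantized type's expressed type as element type.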
8480 
8481 using mlir::hlo::parseWindowAttributes;
8482 using mlir::hlo::printWindowAttributes;
8483 
8484 }  // namespace mhlo
8485 }  // namespace mlir
8486 
8487 #define GET_OP_CLASSES
8488 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
8489 
8490 namespace mlir {
8491 namespace mhlo {
8492 
8493 //===----------------------------------------------------------------------===//
8494 // mhlo Dialect Interfaces
8495 //===----------------------------------------------------------------------===//
8496 
8497 namespace {
8498 struct HLOInlinerInterface : public DialectInlinerInterface {
8499   using DialectInlinerInterface::DialectInlinerInterface;
8500 
8501   // Allow all call operations to be inlined.
8502   bool isLegalToInline(Operation* call, Operation* callable,
8503                        bool wouldBeCloned) const final {
8504     return true;
8505   }
8506   // We don't have any special restrictions on what can be inlined into
8507   // destination regions (e.g. while/conditional bodies). Always allow it.
8508   bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
8509                        BlockAndValueMapping& valueMapping) const final {
8510     return true;
8511   }
8512   // Operations in mhlo dialect are always legal to inline since they are
8513   // pure.
8514   bool isLegalToInline(Operation*, Region*, bool,
8515                        BlockAndValueMapping&) const final {
8516     return true;
8517   }
8518 };
8519 }  // end anonymous namespace
8520 
8521 //===----------------------------------------------------------------------===//
8522 // mhlo Dialect Constructor
8523 //===----------------------------------------------------------------------===//
8524 
8525 MhloDialect::MhloDialect(MLIRContext* context)
8526     : Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
8527   addOperations<
8528 #define GET_OP_LIST
8529 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
8530       >();
8531   addInterfaces<HLOInlinerInterface>();
8532   addTypes<TokenType>();
8533   addAttributes<
8534 #define GET_ATTRDEF_LIST
8535 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.cc.inc"
8536       >();
8537   context->loadDialect<tensor::TensorDialect>();
8538 }
8539 
8540 Type MhloDialect::parseType(DialectAsmParser& parser) const {
8541   StringRef dataType;
8542   if (parser.parseKeyword(&dataType)) return Type();
8543 
8544   if (dataType == "token") return TokenType::get(getContext());
8545   parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << dataType;
8546   return nullptr;
8547 }
8548 
8549 void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
8550   if (type.isa<TokenType>()) {
8551     os << "token";
8552     return;
8553   }
8554   os << "<unknown mhlo type>";
8555 }
8556 
8557 // Entry point for Attribute parsing, TableGen generated code will handle the
8558 // dispatch to the individual classes.
8559 Attribute MhloDialect::parseAttribute(DialectAsmParser& parser,
8560                                       Type type) const {
8561   StringRef attrTag;
8562   Attribute attr;
8563   auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
8564   if (parseResult.hasValue()) return attr;
8565   parser.emitError(parser.getNameLoc(), "unknown mhlo attribute");
8566   return Attribute();
8567 }
8568 
8569 // Entry point for Attribute printing, TableGen generated code will handle the
8570 // dispatch to the individual classes.
8571 void MhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
8572   LogicalResult result = generatedAttributePrinter(attr, os);
8573   (void)result;
8574   assert(succeeded(result));
8575 }
8576 
8577 /// Helpers for attributes parsing.
8578 
8579 static ParseResult parseDims(AsmParser& parser, SmallVector<int64_t>& dims) {
8580   dims.clear();
8581   if (parser.parseLSquare()) return failure();
8582   while (failed(parser.parseOptionalRSquare())) {
8583     dims.emplace_back();
8584     if (parser.parseInteger(dims.back())) return failure();
8585     (void)parser.parseOptionalComma();
8586   }
8587   return success();
8588 }
8589 
8590 static ParseResult parseDimsWithMinimumElements(AsmParser& parser,
8591                                                 SmallVector<int64_t>& dims,
8592                                                 int minElements) {
8593   if (failed(parseDims(parser, dims))) return failure();
8594   if (static_cast<int64_t>(dims.size()) < minElements)
8595     return parser.emitError(parser.getCurrentLocation())
8596            << "expected at least " << minElements << " element(s), found "
8597            << dims.size();
8598   return success();
8599 }
8600 
8601 /// Parse a custom attribute that resembles a struct of the form
8602 /// <
8603 ///   foo = something_parsed_by_custom_parser,
8604 ///   bar = something_parsed_by_different_custom_parser,
8605 ///   baz something_parsed_by_another_custom_parser
8606 /// >
8607 /// The optional `parseEqual` array can be used to denote whether a '='
8608 /// follows the keyword (see baz in the example above) for each field. If
8609 /// not provided, all fields must be followed by a '='.
8610 static ParseResult parseStruct(
8611     AsmParser& parser, ArrayRef<StringRef> keywords,
8612     ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs,
8613     ArrayRef<bool> parseEqual = {}) {
8614   assert(keywords.size() == parseFuncs.size());
8615   assert(parseEqual.empty() || parseEqual.size() == keywords.size());
8616   SmallVector<bool> seen(keywords.size(), false);
8617   while (failed(parser.parseOptionalGreater())) {
8618     bool foundOne = false;
8619     for (const auto& it : llvm::enumerate(keywords)) {
8620       size_t index = it.index();
8621       StringRef keyword = it.value();
8622       if (succeeded(parser.parseOptionalKeyword(keyword))) {
8623         if (seen[index]) {
8624           return parser.emitError(parser.getCurrentLocation())
8625                  << "duplicated `" << keyword << "` entry";
8626         }
8627         if (parseEqual.empty() || parseEqual[index]) {
8628           if (failed(parser.parseEqual())) return failure();
8629         }
8630         if (failed(parseFuncs[index]())) return failure();
8631         if (failed(parser.parseOptionalComma())) return parser.parseGreater();
8632         seen[index] = true;
8633         foundOne = true;
8634       }
8635     }
8636     if (!foundOne) {
8637       auto parseError = parser.emitError(parser.getCurrentLocation())
8638                         << "expected one of: ";
8639       llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
8640         parseError << '`' << kw << '`';
8641       });
8642       return parseError;
8643     }
8644   }
8645   return success();
8646 }
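// Hedged usage sketch (names illustrative; mirrors the attribute parsers
// below). After the caller has consumed the leading `<`, parsing
// `foo = [1, 2], n = 3>` in any field order could look like:
//   SmallVector<int64_t> foo;
//   int64_t n = 0;
//   if (failed(parseStruct(parser, {"foo", "n"},
//                          {[&]() { return parseDims(parser, foo); },
//                           [&]() { return parser.parseInteger(n); }})))
//     return failure();
// Duplicate or unknown keywords produce the errors emitted above, and the
// trailing `>` is consumed by parseStruct itself.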
8647 
8648 // Helpers to print an optional array or integer field, to simplify writing
8649 // attribute printers.
8650 template <typename T>
8651 static void printField(AsmPrinter& printer, StringRef name, T field,
8652                        StringRef& separator) {
8653   if (field != 0) {
8654     printer << separator << name << " = " << field;
8655     separator = ", ";
8656   }
8657 }
8658 template <typename T>
8659 static void printField(AsmPrinter& printer, StringRef name, ArrayRef<T> field,
8660                        StringRef& separator) {
8661   if (!field.empty()) {
8662     printer << separator << name << " = [";
8663     llvm::interleaveComma(field, printer);
8664     printer << "]";
8665     separator = ", ";
8666   }
8667 }
8668 template <typename... Ts>
8669 static void printStruct(AsmPrinter& printer, StringRef name,
8670                         Ts... printFields) {
8671   printer << "<";
8672   StringRef separator = "";
8673   // Fold expression to print each entry in the parameter pack.
8674   // TODO(mhlo-team): this can be simplified when TF moves to C++17.
8675   using unused = int[];
8676   (void)unused{0, (printField(printer, std::get<0>(printFields),
8677                               std::get<1>(printFields), separator),
8678                    0)...};
8679   printer << ">";
8680 }
8681 
8682 // Custom printer and parser for ScatterDimensionNumbersAttr.
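// Hedged example of the custom form handled below (constructed from the
// field names; zero/empty fields are elided by printField):
//   #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0],
//                 scatter_dims_to_operand_dims = [0], index_vector_dim = 1>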
8683 void ScatterDimensionNumbersAttr::print(AsmPrinter& printer) const {
8684   printStruct(printer, "scatter",
8685               std::make_pair("update_window_dims", getUpdateWindowDims()),
8686               std::make_pair("inserted_window_dims", getInsertedWindowDims()),
8687               std::make_pair("scatter_dims_to_operand_dims",
8688                              getScatterDimsToOperandDims()),
8689               std::make_pair("index_vector_dim", getIndexVectorDim()));
8690 }
8691 Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
8692   if (failed(parser.parseLess())) return {};
8693   SmallVector<int64_t> updateWindowDims;
8694   SmallVector<int64_t> insertedWindowDims;
8695   SmallVector<int64_t> scatterDimsToOperandDims;
8696   int64_t indexVectorDim = 0;
8697 
8698   if (failed(parseStruct(
8699           parser,
8700           {"update_window_dims", "inserted_window_dims",
8701            "scatter_dims_to_operand_dims", "index_vector_dim"},
8702           {[&]() { return parseDims(parser, updateWindowDims); },
8703            [&]() { return parseDims(parser, insertedWindowDims); },
8704            [&]() { return parseDims(parser, scatterDimsToOperandDims); },
8705            [&]() { return parser.parseInteger(indexVectorDim); }}))) {
8706     parser.emitError(parser.getCurrentLocation())
8707         << "failed parsing scatter dimension numbers attribute";
8708     return {};
8709   }
8710 
8711   return ScatterDimensionNumbersAttr::get(
8712       parser.getContext(), updateWindowDims, insertedWindowDims,
8713       scatterDimsToOperandDims, indexVectorDim);
8714 }
8715 
8716 // Custom printer and parser for GatherDimensionNumbersAttr.
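// Hedged example, by analogy with the scatter form above (constructed from
// the field names):
//   #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0],
//                start_index_map = [0], index_vector_dim = 1>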
8717 void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const {
8718   printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()),
8719               std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()),
8720               std::make_pair("start_index_map", getStartIndexMap()),
8721               std::make_pair("index_vector_dim", getIndexVectorDim()));
8722 }
8723 
8724 Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
8725   if (failed(parser.parseLess())) return {};
8726 
8727   SmallVector<int64_t> offsetDims;
8728   SmallVector<int64_t> collapsedSliceDims;
8729   SmallVector<int64_t> startIndexMap;
8730   int64_t indexVectorDim = 0;
8731 
8732   if (failed(parseStruct(
8733           parser,
8734           {"offset_dims", "collapsed_slice_dims", "start_index_map",
8735            "index_vector_dim"},
8736           {[&]() { return parseDims(parser, offsetDims); },
8737            [&]() { return parseDims(parser, collapsedSliceDims); },
8738            [&]() { return parseDims(parser, startIndexMap); },
8739            [&]() { return parser.parseInteger(indexVectorDim); }}))) {
8740     parser.emitError(parser.getCurrentLocation())
8741         << "failed parsing gather dimension numbers attribute";
8742     return {};
8743   }
8744 
8745   return GatherDimensionNumbersAttr::get(parser.getContext(), offsetDims,
8746                                          collapsedSliceDims, startIndexMap,
8747                                          indexVectorDim);
8748 }
8749 
8750 // Custom printer and parser for DotDimensionNumbersAttr.
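// Hedged example (constructed from the field names; a batched matmul of
// tensors with a shared leading batch dimension):
//   #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0],
//             lhs_contracting_dimensions = [2],
//             rhs_contracting_dimensions = [1]>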
8751 void DotDimensionNumbersAttr::print(AsmPrinter& printer) const {
8752   printStruct(
8753       printer, "dot",
8754       std::make_pair("lhs_batching_dimensions", getLhsBatchingDimensions()),
8755       std::make_pair("rhs_batching_dimensions", getRhsBatchingDimensions()),
8756       std::make_pair("lhs_contracting_dimensions",
8757                      getLhsContractingDimensions()),
8758       std::make_pair("rhs_contracting_dimensions",
8759                      getRhsContractingDimensions()));
8760 }
8761 
8762 Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
8763   if (failed(parser.parseLess())) return {};
8764 
8765   SmallVector<int64_t> lhsBatchingDimensions;
8766   SmallVector<int64_t> rhsBatchingDimensions;
8767   SmallVector<int64_t> lhsContractingDimensions;
8768   SmallVector<int64_t> rhsContractingDimensions;
8769 
8770   if (failed(parseStruct(
8771           parser,
8772           {"lhs_batching_dimensions", "rhs_batching_dimensions",
8773            "lhs_contracting_dimensions", "rhs_contracting_dimensions"},
8774           {[&]() { return parseDims(parser, lhsBatchingDimensions); },
8775            [&]() { return parseDims(parser, rhsBatchingDimensions); },
8776            [&]() { return parseDims(parser, lhsContractingDimensions); },
8777            [&]() { return parseDims(parser, rhsContractingDimensions); }}))) {
8778     parser.emitError(parser.getCurrentLocation())
8779         << "failed parsing dot dimension numbers attribute";
8780     return {};
8781   }
8782   return DotDimensionNumbersAttr::get(
8783       parser.getContext(), lhsBatchingDimensions, rhsBatchingDimensions,
8784       lhsContractingDimensions, rhsContractingDimensions);
8785 }
8786 
8787 namespace {
8788 enum NonSpatialDim : int64_t {
8789   IOBatch = -1,    // Input or output batch dimension
8790   IOFeature = -2,  // Input or output feature dimension
8791   KIFeature = -3,  // Kernel input feature dimension
8792   KOFeature = -4,  // Kernel output feature dimensions.
8793 };
8794 
8795 struct DenseMapInfoNonSpatialDim {
8796   static inline NonSpatialDim getEmptyKey() {
8797     return NonSpatialDim(DenseMapInfo<int64_t>::getEmptyKey());
8798   }
8799 
8800   static inline NonSpatialDim getTombstoneKey() {
8801     return NonSpatialDim(DenseMapInfo<int64_t>::getTombstoneKey());
8802   }
8803 
8804   static unsigned getHashValue(const NonSpatialDim& key) {
8805     return DenseMapInfo<int64_t>::getHashValue(key);
8806   }
8807 
8808   static bool isEqual(const NonSpatialDim& lhs, const NonSpatialDim& rhs) {
8809     return lhs == rhs;
8810   }
8811 };
8812 
8813 char nonSpatialDimToString(NonSpatialDim dim) {
8814   switch (dim) {
8815     case IOBatch:
8816       return 'b';
8817     case IOFeature:
8818       return 'f';
8819     case KIFeature:
8820       return 'i';
8821     case KOFeature:
8822       return 'o';
8823   }
8824   llvm_unreachable("Unknown NonSpatialDim");
8825 }
8826 }  // namespace
8827 
8828 // Custom printer and parser for convolution attribute.
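// Hedged example of the compressed form (derived from printDim below; an
// NHWC input with an HWIO kernel):
//   [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
// where b/f mark the batch/feature dimensions, i/o the kernel input/output
// feature dimensions, integers are spatial dimensions, and `?` marks an
// unknown dimension.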
8829 void printConvolutionDimensions(AsmPrinter& p, ConvDimensionNumbersAttr dnums) {
8830   // TODO(b/202040055): we should check the attribute invariants and print the
8831   // "raw" form if they are violated; otherwise we'll crash here.
8832   constexpr int64_t kUnknownDim = std::numeric_limits<int64_t>::min();
8833   auto printDim =
8834       [&](ArrayRef<int64_t> spatialDims,
8835           ArrayRef<std::pair<int64_t, NonSpatialDim>> nonSpatialDims) {
8836         int64_t numDims = 0;
8837         if (!spatialDims.empty()) {
8838           numDims =
8839               *std::max_element(spatialDims.begin(), spatialDims.end()) + 1;
8840         }
8841         for (const auto& dim : nonSpatialDims) {
8842           numDims = std::max(numDims, dim.first + 1);
8843         }
8844 
8845         llvm::SmallVector<int64_t> dims(numDims, kUnknownDim);
8846         // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
8847         // spatial dimension index.
8848         for (const std::pair<int64_t, NonSpatialDim>& nonSpatialDim :
8849              nonSpatialDims) {
8850           dims[nonSpatialDim.first] = nonSpatialDim.second;
8851         }
8852         for (const auto& spatialDim : llvm::enumerate(spatialDims)) {
8853           dims[spatialDim.value()] = static_cast<int64_t>(spatialDim.index());
8854         }
8855 
8856         // Each set of dimension numbers is printed as a comma-separated
8857         // list surrounded by square brackets, e.g., [b, 0, 1, 2, f]
8858         p << '[';
8859         llvm::interleaveComma(dims, p, [&](int64_t dim) {
8860           if (dim == kUnknownDim) {
8861             p << "?";
8862           } else if (dim >= 0) {
8863             p << dim;
8864           } else {
8865             p << nonSpatialDimToString(static_cast<NonSpatialDim>(dim));
8866           }
8867         });
8868         p << ']';
8869       };
8870 
8871   printDim(dnums.getInputSpatialDimensions(),
8872            {{dnums.getInputBatchDimension(), IOBatch},
8873             {dnums.getInputFeatureDimension(), IOFeature}});
8874   p << "x";
8875   printDim(dnums.getKernelSpatialDimensions(),
8876            {{dnums.getKernelInputFeatureDimension(), KIFeature},
8877             {dnums.getKernelOutputFeatureDimension(), KOFeature}});
8878   p << "->";
8879   printDim(dnums.getOutputSpatialDimensions(),
8880            {{dnums.getOutputBatchDimension(), IOBatch},
8881             {dnums.getOutputFeatureDimension(), IOFeature}});
8882 }
8883 
8884 void printConvolutionDimensions(AsmPrinter& p, Operation*,
8885                                 ConvDimensionNumbersAttr dnums) {
8886   printConvolutionDimensions(p, dnums);
8887 }
8888 
8889 // Custom printer and parser for ConvDimensionNumbersAttr.
8890 void ConvDimensionNumbersAttr::print(AsmPrinter& printer) const {
8891   printer << "<";
8892   printConvolutionDimensions(printer, *this);
8893   printer << ">";
8894 }
8895 
8896 // If the attribute is written as `#mhlo.conv<raw ...>`, we parse it as a
8897 // struct instead of the compressed format. This enables writing tests covering
8898 // impossible/invalid internal representations for the attribute.
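// Hedged example of the raw form (keywords match the parseStruct call
// below; values illustrative):
//   #mhlo.conv<raw input_batch_dimension = 0, input_feature_dimension = 3,
//              input_spatial_dimensions = [1, 2], ...>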
8899 static ParseResult parseConvolutionDimensionsRaw(
8900     AsmParser& parser, ConvDimensionNumbersAttr& dnums) {
8901   int64_t inputBatchDimension = 0;
8902   int64_t inputFeatureDimension = 0;
8903   SmallVector<int64_t> inputSpatialDimensions;
8904   int64_t kernelInputFeatureDimension = 0;
8905   int64_t kernelOutputFeatureDimension = 0;
8906   SmallVector<int64_t> kernelSpatialDimensions;
8907   int64_t outBatchDimension = 0;
8908   int64_t outputFeatureDimension = 0;
8909   SmallVector<int64_t> outputSpatialDimensions;
8910   if (failed(parseStruct(
8911           parser,
8912           {"input_batch_dimension", "input_feature_dimension",
8913            "input_spatial_dimensions", "kernel_input_feature_dimension",
8914            "kernel_output_feature_dimension", "kernel_spatial_dimensions",
8915            "output_batch_dimension", "output_feature_dimension",
8916            "output_spatial_dimensions"},
8917           {
8918               [&]() { return parser.parseInteger(inputBatchDimension); },
8919               [&]() { return parser.parseInteger(inputFeatureDimension); },
8920               [&]() { return parseDims(parser, inputSpatialDimensions); },
8921               [&]() {
8922                 return parser.parseInteger(kernelInputFeatureDimension);
8923               },
8924               [&]() {
8925                 return parser.parseInteger(kernelOutputFeatureDimension);
8926               },
8927               [&]() { return parseDims(parser, kernelSpatialDimensions); },
8928               [&]() { return parser.parseInteger(outBatchDimension); },
8929               [&]() { return parser.parseInteger(outputFeatureDimension); },
8930               [&]() { return parseDims(parser, outputSpatialDimensions); },
8931           }))) {
8932     parser.emitError(parser.getCurrentLocation())
8933         << "failed parsing dot dimension numbers attribute";
8934     return failure();
8935   }
8936   dnums = ConvDimensionNumbersAttr::get(
8937       parser.getBuilder().getContext(), inputBatchDimension,
8938       inputFeatureDimension, inputSpatialDimensions,
8939       kernelInputFeatureDimension, kernelOutputFeatureDimension,
8940       kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
8941       outputSpatialDimensions);
8942   return success();
8943 }
8944 
8945 ParseResult parseConvolutionDimensions(AsmParser& parser,
8946                                        ConvDimensionNumbersAttr& dnums) {
8947   // Parsing a single set of dim numbers gives the spatial dimensions as a
8948   // single ArrayRef<int64_t> and a list of non-spatial dimensions as
8949   // IntegerAttrs (indexed by the NonSpatialDim enum).
8950   using parse_dim_result_t =
8951       std::pair<llvm::SmallVector<int64_t>,
8952                 llvm::SmallDenseMap<NonSpatialDim, int64_t, 4,
8953                                     DenseMapInfoNonSpatialDim>>;
8954 
8955   // Note that allowedNonSpatialDims is an ordered set (as opposed to an
8956   // unordered set) because it's used to print the list of allowed non-spatial
8957   // dims in error messages, and ordering keeps those messages deterministic.
8958   auto parseDims =
8959       [&](std::set<NonSpatialDim, std::greater<>> allowedNonSpatialDims,
8960           parse_dim_result_t& parsedDims) -> ParseResult {
8961     auto& spatialDims = std::get<0>(parsedDims);
8962     auto& nonSpatialDims = std::get<1>(parsedDims);
8963     spatialDims.clear();
8964     nonSpatialDims.clear();
8965 
8966     // Parse the starting [
8967     if (parser.parseLSquare()) {
8968       return failure();
8969     }
8970 
8971     llvm::SmallDenseMap<int64_t, int64_t> spatialDimsMap;
8972     constexpr int64_t kInvalidDimension = -1;
8973     // Keep track of the maximum spatial dimension parsed as we expect to see
8974     // all the dimensions from 0 to maximum dimension parsed.
8975     int64_t maxParsedSpatialDim = kInvalidDimension;
8976 
8977     int64_t index = 0;
8978     do {
8979       int64_t spatialDim;
8980       auto dimLocation = parser.getCurrentLocation();
8981       OptionalParseResult parseResult = parser.parseOptionalInteger(spatialDim);
8982       if (parseResult.hasValue()) {
8983         if (parseResult.getValue().failed()) {
8984           return failure();
8985         }
8986         // We were successful in parsing an integer. Check if it is a valid
8987         // dimension (non-negative and no duplicate) and add its index to the
8988         // spatial dims map.
8989         if (spatialDim < 0)
8990           return parser.emitError(dimLocation)
8991                  << "Unexpected dimension " << spatialDim;
8992         if (!spatialDimsMap
8993                  .insert(std::pair<int64_t, int64_t>(spatialDim, index))
8994                  .second)
8995           return parser.emitError(dimLocation)
8996                  << "Duplicate entries for spatial dimension " << spatialDim;
8997         maxParsedSpatialDim = std::max(spatialDim, maxParsedSpatialDim);
8998       } else if (!parser.parseOptionalQuestion()) {
8999         // Do nothing other than increment `index` at the bottom of the loop;
9000         // '?' means "unknown dimension", and it's not represented in the
9001         // return value of this function.
9002       } else {
9003         // We did not parse an integer or question mark. We expect a keyword
9004         // token.
9005         StringRef keyword;
9006         if (parser.parseKeyword(&keyword)) {
9007           return failure();
9008         }
9009         if (keyword.size() != 1 || allowedNonSpatialDims.empty()) {
9010           return parser.emitError(dimLocation, "Unexpected keyword ")
9011                  << keyword;
9012         }
9013         // Check if the keyword matches one of the allowed non-spatial dims.
9014         // If so, add it to the non_spatial dims and remove it from the
9015         // allowed set so that it won't be allowed again.
9016         bool isAllowed = false;
9017         for (NonSpatialDim allowed : allowedNonSpatialDims) {
9018           if (keyword[0] == nonSpatialDimToString(allowed)) {
9019             nonSpatialDims.insert({allowed, index});
9020             allowedNonSpatialDims.erase(allowed);
9021             isAllowed = true;
9022             break;
9023           }
9024         }
9025 
9026         if (!isAllowed) {
9027           mlir::InFlightDiagnostic diag =
9028               parser.emitError(dimLocation, "Unexpected dimension ");
9029           diag << keyword << ", expecting ";
9030           llvm::interleaveComma(
9031               allowedNonSpatialDims, diag,
9032               [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
9033           return diag;
9034         }
9035       }
9036       index++;
9037     } while (parser.parseOptionalComma().succeeded());
9038 
9039     // Make sure all expected non-spatial dimensions are parsed.
9040     if (!allowedNonSpatialDims.empty()) {
9041       mlir::InFlightDiagnostic diag =
9042           parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
9043       llvm::interleaveComma(
9044           allowedNonSpatialDims, diag,
9045           [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
9046       diag << " not specified";
9047       return diag;
9048     }
9049 
9050     // parse ending ]
9051     if (parser.parseRSquare()) {
9052       return failure();
9053     }
9054 
9055     // Number of expected spatial dimensions is one more than the maximum parsed
9056     // spatial dimension. For example, if we parse [0, 3, 2, b, i, 1], then the
9057     // maximum parsed spatial dimension is 3 and the number of expected spatial
9058     // dimensions is 4.
9059     int64_t numSpatialDimensions = maxParsedSpatialDim + 1;
9060     spatialDims.resize(numSpatialDimensions);
9061     // Store spatial dimensions in a vector which maps spatial dim (vector
9062     // index) -> index in the tensor dimensions. For example, for parsed
9063     // dimension numbers [0, 3, 2, b, i, 1] the spatial dimension vector would
9064     // be [0, 5, 2, 1].
9065     //
9066     // Get all the unspecified spatial dimensions to throw a more descriptive
9067     // error later.
9068     llvm::SmallVector<int64_t> unspecifiedSpatialDims;
9069     constexpr int kPrintUnspecifiedDimsMax = 10;
9070     for (int dim = 0; dim < numSpatialDimensions; ++dim) {
9071       auto it = spatialDimsMap.find(dim);
9072       if (it == spatialDimsMap.end()) {
9073         // Have an upper bound on the number of unspecified dimensions to print
9074         // in the error message.
9075         if (unspecifiedSpatialDims.size() < kPrintUnspecifiedDimsMax)
9076           unspecifiedSpatialDims.push_back(dim);
9077         continue;
9078       }
9079       spatialDims[dim] = it->second;
9080     }
9081 
9082     // Verify that we got all spatial dimensions between 0 and maximum parsed
9083     // spatial dimension.
9084     if (!unspecifiedSpatialDims.empty()) {
9085       mlir::InFlightDiagnostic diag = parser.emitError(
9086           parser.getCurrentLocation(), "Expected spatial dimensions ");
9087       llvm::interleaveComma(unspecifiedSpatialDims, diag);
9088       diag << " not specified";
9089       return diag;
9090     }
9091 
9092     return success();
9093   };
9094 
9095   parse_dim_result_t parsedDims;
9096   if (parseDims({IOBatch, IOFeature}, parsedDims)) {
9097     return failure();
9098   }
9099   llvm::SmallVector<int64_t> inputSpatialDimensions = parsedDims.first;
9100   int64_t inputBatchDimension = parsedDims.second[IOBatch];
9101   int64_t inputFeatureDimension = parsedDims.second[IOFeature];
9102   if (parser.parseKeyword("x")) return failure();
9103   if (parseDims({KIFeature, KOFeature}, parsedDims)) {
9104     return failure();
9105   }
9106   llvm::SmallVector<int64_t> kernelSpatialDimensions = parsedDims.first;
9107   int64_t kernelInputFeatureDimension = parsedDims.second[KIFeature];
9108   int64_t kernelOutputFeatureDimension = parsedDims.second[KOFeature];
9109   if (parser.parseArrow()) {
9110     return failure();
9111   }
9112   if (parseDims({IOBatch, IOFeature}, parsedDims)) {
9113     return failure();
9114   }
9115   llvm::SmallVector<int64_t> outputSpatialDimensions = parsedDims.first;
9116   const int64_t outBatchDimension = parsedDims.second[IOBatch];
9117   const int64_t outputFeatureDimension = parsedDims.second[IOFeature];
9118   dnums = ConvDimensionNumbersAttr::get(
9119       parser.getBuilder().getContext(), inputBatchDimension,
9120       inputFeatureDimension, inputSpatialDimensions,
9121       kernelInputFeatureDimension, kernelOutputFeatureDimension,
9122       kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
9123       outputSpatialDimensions);
9124 
9125   return success();
9126 }
9127 
9128 Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
9129   if (failed(parser.parseLess())) return {};
9130   ConvDimensionNumbersAttr dnums;
9131   if (succeeded(parser.parseOptionalKeyword("raw"))) {
9132     if (failed(parseConvolutionDimensionsRaw(parser, dnums))) return {};
9133     return dnums;
9134   }
9135   if (failed(parseConvolutionDimensions(parser, dnums))) return {};
9136   if (failed(parser.parseGreater())) return {};
9137   return dnums;
9138 }
9139 
9140 // Custom printer and parser for ArgResultAliasAttr.
9141 constexpr char kMustAlias[] = "must_alias";
9142 constexpr char kResult[] = "result_index";
9143 constexpr char kArgTupleIndices[] = "tuple_indices";
9144 
9145 void ArgResultAliasAttr::print(AsmPrinter& printer) const {
9146   printer << "<";
9147 
9148   // The attribute can have empty tuple indices. Only print argument tuple
9149   // indices if they are non-empty.
9150   if (!getArgTupleIndices().empty())
9151     printer << kArgTupleIndices << " = [" << getArgTupleIndices() << "], ";
9152 
9153   // Print the result index followed by any result tuple indices if present.
9154   printer << kResult << " = [";
9155   printer << getResultIndex();
9156   if (!getResultTupleIndices().empty()) {
9157     printer << ", " << getResultTupleIndices();
9158   }
9159   printer << "]";
9160 
9161   // Print the "must_alias" keyword if this is a must alias, otherwise skip.
9162   if (getIsMustAlias()) printer << ", " << kMustAlias;
9163 
9164   printer << ">";
9165 }
9166 
9167 Attribute ArgResultAliasAttr::parse(AsmParser& parser, Type type) {
9168   if (failed(parser.parseLess())) return {};
9169   llvm::SmallVector<int64_t> argTupleIndices;
9170   // The first element of result indices holds the aliased result index and the
9171   // remaining elements are the result tuple indices.
9172   llvm::SmallVector<int64_t> resultIndices;
9173   bool isMustAlias = false;
9174 
9175   // This conveys to parseStruct that keyword "must_alias" (3rd field) is not
9176   // followed by a "=", but other fields are.
9177   llvm::SmallVector<bool, 3> parseEqual = {true, true, false};
9178 
9179   if (failed(parseStruct(parser, {kArgTupleIndices, kResult, kMustAlias},
9180                          {[&]() { return parseDims(parser, argTupleIndices); },
9181                           [&]() {
9182                             // Since the first element is the index of result,
9183                             // at least one element is expected.
9184                             return parseDimsWithMinimumElements(
9185                                 parser, resultIndices, /*minElements=*/1);
9186                           },
9187                           [&]() {
9188                             // always succeeds if the keyword "must_alias" was
9189                             // parsed
9190                             isMustAlias = true;
9191                             return success();
9192                           }},
9193                          parseEqual))) {
9194     parser.emitError(parser.getCurrentLocation())
9195         << "failed parsing argument-result alias attribute";
9196     return {};
9197   }
9198 
9199   int64_t resultIndex = resultIndices[0];
9200   auto resultTupleIndices =
9201       ArrayRef<int64_t>{resultIndices.begin() + 1, resultIndices.end()};
9202 
9203   return ArgResultAliasAttr::get(parser.getContext(), argTupleIndices,
9204                                  resultIndex, resultTupleIndices, isMustAlias);
9205 }
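// Hedged example of the syntax accepted above (the attribute mnemonic is
// assumed; fields follow the printer): an argument tuple element [1][2]
// aliasing element [3] of result 0, as a mandatory alias:
//   #mhlo.result_alias<tuple_indices = [1, 2], result_index = [0, 3],
//                      must_alias>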
9206 
9207 // Returns the element type pointed to by `indices` within `type`. If the
9208 // indices are invalid, returns a null Type.
9209 static Type getTypeFromTupleIndices(Type type, ArrayRef<int64_t> indices) {
9210   Type current = type;
9211   for (auto index : indices) {
9212     TupleType tupleType = current.dyn_cast<TupleType>();
9213     if (!tupleType || index >= static_cast<int64_t>(tupleType.size()))
9214       return {};
9215     current = tupleType.getType(index);
9216   }
9217   return current;
9218 }
9219 
9220 static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
9221                                               ArgResultAliasAttr aliasAttr,
9222                                               unsigned argIndex,
9223                                               Operation* op) {
9224   // The attribute can only be applied to function-like operations.
9225   if (!isa<mlir::FunctionOpInterface>(op))
9226     return op->emitOpError() << "attribute " << attrName
9227                              << " can only be used on function-like operations";
9228 
9229   // Verify there are no negative indices.
9230   auto tupleIndices = llvm::concat<const int64_t>(
9231       aliasAttr.getArgTupleIndices(), aliasAttr.getResultTupleIndices());
9232   if (llvm::any_of(tupleIndices, [](const int64_t val) { return val < 0; }) ||
9233       aliasAttr.getResultIndex() < 0)
9234     return op->emitOpError()
9235            << "attribute " << attrName
9236            << " expects all argument and result indices to be >= 0";
9237 
9238   // Verify that the result index is not out of range. Since the attribute is a
9239   // function argument attribute, the argument index is always correct when this
9240   // verifier is called.
9241   FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
9242   ArrayRef<Type> argTypes = funcOp.getArgumentTypes();
9243   ArrayRef<Type> resultTypes = funcOp.getResultTypes();
9244   if (aliasAttr.getResultIndex() >= static_cast<int64_t>(resultTypes.size()))
9245     return op->emitOpError()
9246            << "attribute " << attrName
9247            << " result index is out of range, must be <" << resultTypes.size();
9248 
9249   // Verify that argument and result types pointed to by the indices are valid
9250   // and compatible.
9251   Type argType = getTypeFromTupleIndices(argTypes[argIndex],
9252                                          aliasAttr.getArgTupleIndices());
9253   if (!argType)
9254     return op->emitOpError()
9255            << "attribute " << attrName << " argument tuple indices are invalid";
9256   Type resultType =
9257       getTypeFromTupleIndices(resultTypes[aliasAttr.getResultIndex()],
9258                               aliasAttr.getResultTupleIndices());
9259   if (!resultType)
9260     return op->emitOpError()
9261            << "attribute " << attrName << " result tuple indices are invalid";
9262 
9263   if (failed(mlir::verifyCompatibleShape(argType, resultType)) ||
9264       getElementTypeOrSelf(argType) != getElementTypeOrSelf(resultType))
9265     return op->emitOpError() << "attribute " << attrName
9266                              << " aliases do not have compatible types, "
9267                              << argType << " vs. " << resultType;
9268   return success();
9269 }
9270 
9271 //===----------------------------------------------------------------------===//
9272 // Type utilities
9273 //===----------------------------------------------------------------------===//
9274 
9275 Type getExpressedTypeOrSelf(Type type) {
9276   auto quantType = type.dyn_cast<quant::QuantizedType>();
9277   return quantType ? quantType.getExpressedType() : type;
9278 }
9279 
9280 LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
9281   if (failed(verifyCompatibleShape(type1, type2))) return failure();
9282 
9283   // Verify shapes against bounds
9284   auto isCompatible = [](ArrayRef<int64_t> shape,
9285                          TypeExtensionsAttr extensionAttr) {
9286     if (shape.empty() || !extensionAttr) return true;
9287     auto bounds = extensionAttr.getBounds();
9288     for (auto [dim_size, bound] : llvm::zip(shape, bounds))  // NOLINT
9289       if (bound != ShapedType::kDynamicSize && bound < dim_size) return false;
9290     return true;
9291   };
9292 
9293   RankedTensorType rankedType1 = type1.dyn_cast<RankedTensorType>();
9294   RankedTensorType rankedType2 = type2.dyn_cast<RankedTensorType>();
9295   if (rankedType1 && rankedType2) {
9296     TypeExtensionsAttr extensionAttr1 =
9297         rankedType1.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>();
9298     TypeExtensionsAttr extensionAttr2 =
9299         rankedType2.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>();
9300     return LogicalResult::success(
9301         isCompatible(rankedType1.getShape(), extensionAttr2) &&
9302         isCompatible(rankedType2.getShape(), extensionAttr1));
9303   }
9304   return success();
9305 }
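// Illustrative example (assuming the TypeExtensionsAttr encoding used
// above): tensor<?xf32, #mhlo.type_extensions<bounds = [4]>> is compatible
// with tensor<3xf32> but not with tensor<5xf32>, since the static size 5
// exceeds the bound 4.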
9306 
9307 bool isCompatibleForMhloTypeInference(Type tp1, Type tp2) {
9308   // Dynamism: We don't require shapes to be the same, we only require them
9309   // to be compatible, which means that:
9310   //   1) At least one of the shapes is unranked.
9311   //   2) Or both shapes have the same rank and their dimensions are compatible,
9312   //     i.e. for each pair of corresponding dimensions:
9313   //       2.1) At least one of the dimensions is dynamic,
9314   //       2.2) Or both dimensions are equal.
9315   // These relaxed rules simplify the implementation of type inference, allowing
9316   // ops with partially inferred types to pass verification.
9317   auto stp1 = tp1.dyn_cast<ShapedType>();
9318   auto stp2 = tp2.dyn_cast<ShapedType>();
9319   if (stp1 && stp2) {
9320     return succeeded(verifyCompatibleShapeWithBounds(stp1, stp2)) &&
9321            isCompatibleForMhloTypeInference(stp1.getElementType(),
9322                                             stp2.getElementType());
9323   }
9324 
9325   // Quantization: In the most general case, we allow any combination of
9326   // quantized/non-quantized across any combination of operands/results,
9327   // and some differences in quantization parameters across operands/results.
9328   // Individual ops may introduce additional constraints.
9329   auto qtp1 = tp1.dyn_cast<quant::QuantizedType>();
9330   auto qtp2 = tp2.dyn_cast<quant::QuantizedType>();
9331   if (qtp1 && qtp2) {
9332     if (qtp1.getStorageType() != qtp2.getStorageType() ||
9333         qtp1.getStorageTypeMin() != qtp2.getStorageTypeMin() ||
9334         qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax())
9335       return false;
9336   }
9337   auto etp1 = getExpressedTypeOrSelf(tp1);
9338   auto etp2 = getExpressedTypeOrSelf(tp2);
9339 
9340   // Sparsity: In the most general case, we allow any combination of
9341   // sparsity/denseness across any combination of operands/results, as well as
9342   // differences in sparsity encodings for operands and results.
9343   // Individual ops may introduce additional constraints.
9344   // No additional code is needed to check this because of how sparsity is
9345   // currently implemented.
9346 
9347   // Default case: Unless dynamism, quantization and/or sparsity are involved,
9348   // the types are required to be exactly equal.
9349   return etp1 == etp2;
9350 }
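// Illustrative examples of the rules above: tensor<2xf32> is compatible
// with tensor<?xf32> (dynamic dimension) and with tensor<*xf32> (unranked),
// but not with tensor<3xf32> (mismatched static sizes) or tensor<2xi32>
// (mismatched element types).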
9351 
9352 //===----------------------------------------------------------------------===//
9353 // Builder utilities
9354 //===----------------------------------------------------------------------===//
9355 
9356 // Builds the region `body` for mhlo.sort's comparator: for each type in
9357 // `elementTypes`, creates two block arguments, one for lhs and one for rhs,
9358 // and generates a mhlo.compare op to compare them with the given `direction`.
9359 //
9360 // Note that this currently only compares the first pair of block
9361 // arguments.
9362 static void buildSortComparisonBody(llvm::ArrayRef<Type> elementTypes,
9363                                     ComparisonDirection direction,
9364                                     llvm::Optional<StringRef> compareType,
9365                                     Region* body, OpBuilder* builder) {
9366   OpBuilder::InsertionGuard insertionPointGuard(*builder);
9367 
9368   Location loc = body->getLoc();
9369   Block* block = builder->createBlock(body);
9370   // Add two arguments for each element type.
9371   for (Type elementType : elementTypes) {
9372     TensorType tensorType = RankedTensorType::get({}, elementType);
9373     block->addArguments({tensorType, tensorType},
9374                         SmallVector<Location, 2>(2, loc));
9375   }
9376 
9377   ComparisonType typeAttr;
9378   if (compareType)
9379     typeAttr = symbolizeComparisonType(*compareType).value();
9380   else
9381     typeAttr = ComparisonType::NOTYPE;
9382   Value compare = builder->create<mhlo::CompareOp>(
9383       loc, block->getArgument(0), block->getArgument(1), direction, typeAttr);
9384 
9385   builder->create<mhlo::ReturnOp>(loc, compare);
9386 }
9387 
9388 SortOp createSortOp(PatternRewriter* rewriter, const Location& loc,
9389                     const llvm::ArrayRef<Value>& operands,
9390                     const llvm::ArrayRef<Type>& elementTypes, int64_t dimension,
9391                     bool isStable, ComparisonDirection direction) {
9392   assert(!operands.empty() && "No operands to sort");
9393   // Create the sort op.
9394   auto sortOp =
9395       rewriter->create<mhlo::SortOp>(loc, operands, dimension, isStable);
9396 
9397   // Use TOTALORDER comparison type instead of the default comparison if the
9398   // element type is of type float.
9399   llvm::Optional<StringRef> compareType = llvm::None;
9400   for (auto const& elementType : elementTypes)
9401     if (elementType.isa<FloatType>()) {
9402       compareType.emplace("TOTALORDER");
9403       break;
9404     }
9405   buildSortComparisonBody(elementTypes, direction, compareType,
9406                           &sortOp.comparator(), rewriter);
9407   return sortOp;
9408 }
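// Hedged sketch of the comparator built for a single f32 operand
// (hand-written IR; exact asm format may differ):
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//     %0 = mhlo.compare LT, %lhs, %rhs  // compare_type = TOTALORDER
//     mhlo.return %0 : tensor<i1>
// TOTALORDER is selected above whenever any element type is a FloatType.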
9409 
9410 //===----------------------------------------------------------------------===//
9411 // Shape inference
9412 //===----------------------------------------------------------------------===//
9413 
9414 LogicalResult deriveShapeFromOperand(
9415     OpBuilder* builder, Operation* op, Value operand,
9416     SmallVectorImpl<Value>* reifiedReturnShapes) {
9417   auto shapedTy = operand.getType().dyn_cast<ShapedType>();
9418   if (!shapedTy) {
9419     op->emitOpError() << "operand is not a shaped type";
9420     return failure();
9421   }
9422   reifiedReturnShapes->assign(
9423       {builder->create<shape::ShapeOfOp>(op->getLoc(), operand)});
9424   return success();
9425 }
9426 
9427 //===----------------------------------------------------------------------===//
9428 // MHLO Dialect Hooks
9429 //===----------------------------------------------------------------------===//
9430 
9431 Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
9432                                             Type type, Location loc) {
9433   auto elementsAttr = value.dyn_cast<ElementsAttr>();
9434   // HLO dialect constants only support ElementsAttr unlike standard dialect
9435   // constant which supports all attributes.
9436   if (!elementsAttr) return nullptr;
9437   // HLO dialect constants require the type of value and result to match.
9438   if (type != elementsAttr.getType()) return nullptr;
9439 
9440   return builder.create<mhlo::ConstantOp>(loc, type, elementsAttr);
9441 }
9442 
9443 LogicalResult MhloDialect::verifyRegionArgAttribute(Operation* op,
9444                                                     unsigned /*regionIndex*/,
9445                                                     unsigned argIndex,
9446                                                     NamedAttribute attr) {
9447   if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
9448     if (failed(
9449             verifyArgResultAliasAttr(attr.getName(), aliasAttr, argIndex, op)))
9450       return failure();
9451   }
9452   return success();
9453 }
9454 
9455 LogicalResult MhloDialect::verifyOperationAttribute(Operation* op,
9456                                                     NamedAttribute attr) {
9457   if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
9458     if (!isa<mlir::FunctionOpInterface>(op))
9459       return op->emitOpError()
9460              << "attribute " << attr.getName()
9461              << " can only be used on function-like operations";
9462   }
9463   return success();
9464 }
9465 
9466 }  // namespace mhlo
9467 }  // namespace mlir
9468