1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
17
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "mlir-hlo/utils/broadcast_utils.h"
22 #include "mlir/Dialect/Complex/IR/Complex.h"
23 #include "mlir/Dialect/Traits.h"
24 #include "mlir/IR/Attributes.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/Diagnostics.h"
28 #include "mlir/IR/Location.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/IR/TypeUtilities.h"
31 #include "mlir/IR/Value.h"
32 #include "mlir/Interfaces/InferTypeOpInterface.h"
33
34 namespace mlir {
35 namespace chlo {
36
getConstantLikeMaxFiniteValue(OpBuilder & b,Location loc,Value val)37 Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
38 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
39 return getConstantLike(
40 b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
41 }
42
getConstantLikeInfValue(OpBuilder & b,Location loc,Value val,bool negative)43 Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
44 bool negative) {
45 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
46 return getConstantLike(
47 b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
48 }
49
getConstantLikeSmallestFiniteValue(OpBuilder & b,Location loc,Value val)50 Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc,
51 Value val) {
52 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
53 return getConstantLike(
54 b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
55 }
56
getConstantLike(OpBuilder & b,Location loc,const APFloat & constant,Value val)57 Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
58 Value val) {
59 Type ty = getElementTypeOrSelf(val.getType());
60 return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
61 }
62
63 //===----------------------------------------------------------------------===//
64 // CompatibleOperandsAndResultType
65 //===----------------------------------------------------------------------===//
66
67 // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
68 // support quantization or sparsity.
69 #define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op) \
70 LogicalResult Op::inferReturnTypeComponents( \
71 MLIRContext* context, Optional<Location> location, \
72 ValueShapeRange operands, DictionaryAttr attributes, \
73 RegionRange regions, \
74 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { \
75 return inferReturnTypeComponentsFromOperands(context, location, operands, \
76 attributes, regions, \
77 inferredReturnShapes); \
78 }
79
80 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcosOp)
81 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcoshOp)
82 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinOp)
83 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinhOp)
84 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanOp)
85 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanhOp)
86 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(BesselI1eOp)
87 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ConjOp)
88 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CoshOp)
89 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DigammaOp)
90 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
91 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
92 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp)
93 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp)
94 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp)
95 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinhOp)
96 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanOp)
97 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ZetaOp)
98
99 //===----------------------------------------------------------------------===//
100 // BinaryOps
101 //===----------------------------------------------------------------------===//
102
103 namespace {
104 // Gets the resulting type from a broadcast between two types.
getBroadcastType(Type x,Type y,Type elementType,DenseIntElementsAttr broadcastDimensionsAttr)105 ShapedTypeComponents getBroadcastType(
106 Type x, Type y, Type elementType,
107 DenseIntElementsAttr broadcastDimensionsAttr) {
108 auto xRanked = x.dyn_cast<RankedTensorType>();
109 auto yRanked = y.dyn_cast<RankedTensorType>();
110 if (!xRanked || !yRanked) {
111 return {elementType};
112 }
113
114 auto shapeX = xRanked.getShape();
115 auto shapeY = yRanked.getShape();
116
117 // If no broadcast dimensions, assume "numpy" broadcasting.
118 if (shapeX.size() == shapeY.size() || !broadcastDimensionsAttr) {
119 llvm::SmallVector<int64_t, 4> outShape;
120 if (!mlir::OpTrait::util::getBroadcastedShape(shapeX, shapeY, outShape)) {
121 // Signal illegal broadcast_dimensions as unranked.
122 return {elementType};
123 }
124 return {outShape, elementType};
125 }
126
127 auto shapeLarge = shapeX.size() > shapeY.size() ? shapeX : shapeY;
128 auto shapeSmall = shapeX.size() <= shapeY.size() ? shapeX : shapeY;
129
130 auto broadcastDimensions = broadcastDimensionsAttr.getValues<APInt>();
131 if (broadcastDimensions.size() != shapeSmall.size()) {
132 // Signal illegal broadcast_dimensions as unranked.
133 return {elementType};
134 }
135
136 llvm::SmallVector<int64_t, 4> shapeLargeFiltered;
137 shapeLargeFiltered.reserve(shapeSmall.size());
138 for (const auto& dim : broadcastDimensions) {
139 if (dim.getZExtValue() >= shapeLarge.size()) return {elementType};
140 shapeLargeFiltered.push_back(shapeLarge[dim.getZExtValue()]);
141 }
142 llvm::SmallVector<int64_t, 4> outShapeFiltered;
143 if (!mlir::OpTrait::util::getBroadcastedShape(shapeSmall, shapeLargeFiltered,
144 outShapeFiltered)) {
145 // Signal illegal broadcast_dimensions as unranked.
146 return {elementType};
147 }
148
149 // Update according to the broadcast dimensions.
150 llvm::SmallVector<int64_t, 4> outShape(shapeLarge.begin(), shapeLarge.end());
151 for (const auto& indexPair : llvm::enumerate(broadcastDimensions)) {
152 auto newValue = outShapeFiltered[indexPair.index()];
153 outShape[indexPair.value().getZExtValue()] = newValue;
154 }
155
156 return {outShape, elementType};
157 }
158
InferBroadcastBinaryOpReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,Type elementType,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)159 LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
160 MLIRContext* context, Optional<Location> location, ValueRange operands,
161 DictionaryAttr attributes, Type elementType,
162 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
163 // Find broadcast_dimensions.
164 DenseIntElementsAttr broadcastDimensions =
165 attributes.get("broadcast_dimensions")
166 .dyn_cast_or_null<DenseIntElementsAttr>();
167
168 ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
169 ShapedType rhsType = operands[1].getType().dyn_cast<ShapedType>();
170 if (!lhsType || !rhsType ||
171 lhsType.getElementType() != rhsType.getElementType()) {
172 return emitOptionalError(location, "mismatched operand types");
173 }
174 if (!elementType) elementType = lhsType.getElementType();
175 inferredReturnShapes.push_back(
176 getBroadcastType(lhsType, rhsType, elementType, broadcastDimensions));
177 return success();
178 }
179
ReifyBroadcastBinaryOpReturnTypeShapes(OpBuilder & builder,Operation * op,ValueRange operands,SmallVectorImpl<Value> & result)180 LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
181 OpBuilder& builder, Operation* op, ValueRange operands,
182 SmallVectorImpl<Value>& result) {
183 assert(operands.size() == 2 && "expect binary op");
184 auto loc = op->getLoc();
185 auto lhs = operands[0];
186 auto rhs = operands[1];
187
188 // Check for "numpy"-style rank broadcast.
189 auto broadcastDimensions = op->getAttr("broadcast_dimensions")
190 .dyn_cast_or_null<DenseIntElementsAttr>();
191 if (broadcastDimensions &&
192 !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, broadcastDimensions)) {
193 // Note: It is unclear whether the general specification of explicit
194 // broadcast_dimensions on binary ops is a feature we want to carry
195 // forward. While it can technically be implemented for ranked-dynamic,
196 // it is incompatible with unranked inputs. If this warning is emitted
197 // in real programs, it is an indication that the feature should be
198 // implemented versus just falling back on the more standard definition
199 // of numpy-like prefix-padding.
200 return op->emitWarning()
201 << "unsupported non prefix-padded dynamic rank "
202 << "broadcast_dimensions = " << broadcastDimensions;
203 }
204
205 result.push_back(hlo::computeBinaryElementwiseBroadcastingResultExtents(
206 loc, lhs, rhs, builder));
207 return success();
208 }
209 } // namespace
210
211 //===----------------------------------------------------------------------===//
212 // BroadcastComplexOp (has custom type inference due to different result type).
213 //===----------------------------------------------------------------------===//
214
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)215 LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
216 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
217 DictionaryAttr attributes, RegionRange /*regions*/,
218 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
219 ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
220 if (!lhsType) {
221 return emitOptionalError(location, "expected ShapedType");
222 }
223 Type elementType = ComplexType::get(lhsType.getElementType());
224 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
225 attributes, elementType,
226 inferedReturnShapes);
227 }
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)228 LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
229 OpBuilder& builder, ValueRange operands,
230 SmallVectorImpl<Value>& reifiedReturnShapes) {
231 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
232 operands, reifiedReturnShapes);
233 }
234
235 //===----------------------------------------------------------------------===//
236 // BroadcastCompareOp (has custom type inference due to different result type).
237 //===----------------------------------------------------------------------===//
238
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,DenseIntElementsAttr broadcastDimensions,mhlo::ComparisonDirection comparisonDirection,mhlo::ComparisonType compareType)239 void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
240 Value lhs, Value rhs,
241 DenseIntElementsAttr broadcastDimensions,
242 mhlo::ComparisonDirection comparisonDirection,
243 mhlo::ComparisonType compareType) {
244 build(builder, result, lhs, rhs, broadcastDimensions,
245 mhlo::ComparisonDirectionAttr::get(builder.getContext(),
246 comparisonDirection),
247 mhlo::ComparisonTypeAttr::get(builder.getContext(), compareType));
248 }
249
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)250 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
251 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
252 DictionaryAttr attributes, RegionRange /*regions*/,
253 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
254 Type elementType = IntegerType::get(context, 1);
255 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
256 attributes, elementType,
257 inferedReturnShapes);
258 }
259
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)260 LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
261 OpBuilder& builder, ValueRange operands,
262 SmallVectorImpl<Value>& reifiedReturnShapes) {
263 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
264 operands, reifiedReturnShapes);
265 }
266
267 //===----------------------------------------------------------------------===//
268 // IsInfOp
269 //===----------------------------------------------------------------------===//
270
getIsInfLikeReturnType(Value operand)271 static Type getIsInfLikeReturnType(Value operand) {
272 Builder b(operand.getContext());
273 return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
274 b.getI1Type());
275 }
276
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)277 LogicalResult IsInfOp::inferReturnTypes(
278 MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
279 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
280 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
281 return success();
282 }
283
284 //===----------------------------------------------------------------------===//
285 // IsNegInfOp
286 //===----------------------------------------------------------------------===//
287
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)288 LogicalResult IsNegInfOp::inferReturnTypes(
289 MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
290 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
291 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
292 return success();
293 }
294
295 //===----------------------------------------------------------------------===//
296 // IsPosInfOp
297 //===----------------------------------------------------------------------===//
298
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)299 LogicalResult IsPosInfOp::inferReturnTypes(
300 MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
301 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
302 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
303 return success();
304 }
305
306 //===----------------------------------------------------------------------===//
307 // Macros for method definitions that are common to most broadcasting ops.
308 //===----------------------------------------------------------------------===//
309
310 #define BROADCAST_BINARY_OP_DEFS(Op) \
311 LogicalResult Op::inferReturnTypeComponents( \
312 MLIRContext* context, Optional<Location> location, \
313 ValueShapeRange operands, DictionaryAttr attributes, \
314 RegionRange regions, \
315 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
316 return InferBroadcastBinaryOpReturnTypeComponents( \
317 context, location, operands, attributes, /*element_type=*/nullptr, \
318 inferedReturnShapes); \
319 } \
320 LogicalResult Op::reifyReturnTypeShapes( \
321 OpBuilder& builder, ValueRange operands, \
322 SmallVectorImpl<Value>& reifiedReturnShapes) { \
323 return ReifyBroadcastBinaryOpReturnTypeShapes( \
324 builder, getOperation(), operands, reifiedReturnShapes); \
325 }
326
327 BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
328 BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
329 BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
330 BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
331 BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
332 BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
333 BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
334 BROADCAST_BINARY_OP_DEFS(BroadcastNextAfterOp);
335 BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
336 BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
337 BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
338 BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
339 BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
340 BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
341 BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
342 BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
343 BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
344 BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);
345
346 #undef BROADCAST_BINARY_OP_DEFS
347
verify()348 LogicalResult ConstantLikeOp::verify() {
349 if (value().getType() != getType().cast<ShapedType>().getElementType())
350 return emitOpError() << "value's type doesn't match element return type";
351 return success();
352 }
353
354 //===----------------------------------------------------------------------===//
355 // MinimumBroadcastShapesOp
356 //===----------------------------------------------------------------------===//
verify()357 LogicalResult MinimumBroadcastShapesOp::verify() {
358 // Check that the number of operands matches the number of outputs.
359 unsigned resultShapesCount = results().size();
360 unsigned operandShapesCount = shapes().size();
361 if (operandShapesCount != resultShapesCount) {
362 return emitOpError() << "number of operand shapes (" << operandShapesCount
363 << ") does not match number of result shapes ("
364 << resultShapesCount << ")";
365 }
366 if (operandShapesCount < 2) {
367 return emitOpError() << "number of operand shapes (" << operandShapesCount
368 << ") should be >= 2";
369 }
370 return success();
371 }
372
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)373 LogicalResult ConstantLikeOp::inferReturnTypeComponents(
374 MLIRContext* /*context*/, Optional<Location> location,
375 ValueShapeRange operands, DictionaryAttr attributes,
376 RegionRange /*regions*/,
377 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
378 ConstantLikeOp::Adaptor op(operands, attributes);
379 if (failed(op.verify(location.value()))) return failure();
380 Type elementType = op.value().getType();
381 Type operandType = op.operand().getType();
382 if (operandType.isa<UnrankedTensorType>()) {
383 inferedReturnShapes.emplace_back(elementType);
384 } else {
385 const auto& shape = operandType.cast<RankedTensorType>().getShape();
386 inferedReturnShapes.emplace_back(shape, elementType);
387 }
388 return success();
389 }
390
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)391 LogicalResult ConstantLikeOp::reifyReturnTypeShapes(
392 OpBuilder& builder, ValueRange operands,
393 SmallVectorImpl<Value>& reifiedReturnShapes) {
394 return ::mlir::mhlo::deriveShapeFromOperand(
395 &builder, getOperation(), operands.front(), &reifiedReturnShapes);
396 }
397
fold(ArrayRef<Attribute>)398 OpFoldResult ConstantLikeOp::fold(ArrayRef<Attribute> /*operands*/) {
399 auto opType = operand().getType().cast<ShapedType>();
400 if (!opType.hasStaticShape()) return {};
401 auto type = RankedTensorType::get(opType.getShape(), value().getType());
402 if (auto complexAttr = value().dyn_cast<complex::NumberAttr>())
403 return DenseElementsAttr::get(type, complexAttr.getValue());
404 return DenseElementsAttr::get(type, value());
405 }
406
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)407 LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
408 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
409 DictionaryAttr, RegionRange,
410 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
411 BroadcastSelectOp::Adaptor op(operands.getValues());
412 auto predType = op.pred().getType().dyn_cast<ShapedType>();
413 auto onTrueType = op.on_true().getType().dyn_cast<ShapedType>();
414 auto onFalseType = op.on_false().getType().dyn_cast<ShapedType>();
415
416 if (!predType || !onTrueType || !onFalseType ||
417 onTrueType.getElementType() != onFalseType.getElementType()) {
418 return emitOptionalError(location, "mismatched operand types");
419 }
420
421 Type elementType = onTrueType.getElementType();
422
423 // Compute the result shape as two binary broadcasts.
424 ShapedTypeComponents& components = inferredReturnShapes.emplace_back(
425 getBroadcastType(onTrueType, onFalseType, elementType, nullptr));
426 if (components.hasRank()) {
427 components = getBroadcastType(
428 RankedTensorType::get(components.getDims(), elementType), predType,
429 elementType, nullptr);
430 }
431 return success();
432 }
433
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & result)434 LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
435 OpBuilder& builder, ValueRange operands, SmallVectorImpl<Value>& result) {
436 result.push_back(hlo::computeNaryElementwiseBroadcastingResultExtents(
437 getLoc(), operands, builder));
438 return success();
439 }
440
441 //===----------------------------------------------------------------------===//
442 // RankSpecializationClusterOp
443 //===----------------------------------------------------------------------===//
444
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute>,SmallVectorImpl<RegionSuccessor> & regions)445 void RankSpecializationClusterOp::getSuccessorRegions(
446 Optional<unsigned> index, ArrayRef<Attribute> /*operands*/,
447 SmallVectorImpl<RegionSuccessor>& regions) {
448 // RankSpecializationClusterOp has unconditional control flows into the region
449 // and back to the parent, so return the correct RegionSuccessor purely based
450 // on the index being None or 0.
451 if (index.has_value()) {
452 regions.push_back(RegionSuccessor(getResults()));
453 return;
454 }
455 regions.push_back(RegionSuccessor(&body()));
456 }
457
verify()458 LogicalResult RankSpecializationClusterOp::verify() {
459 if (body().getArgumentTypes() != getOperandTypes())
460 return emitOpError() << "block argument types must match operand types";
461
462 // All operands of nested ops must be defined in the body or declared by the
463 // cluster.
464 Block* body = getBody();
465 for (Operation& nested : body->without_terminator()) {
466 if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
467 Operation* def = operand.get().getDefiningOp();
468 if (def != nullptr && def->getBlock() == body) return true;
469 return llvm::is_contained(body->getArguments(), operand.get());
470 })) {
471 return emitOpError() << "nested ops must not depend on implicit operands";
472 }
473 }
474
475 return success();
476 }
477
478 //===----------------------------------------------------------------------===//
479 // TopKOp
480 //===----------------------------------------------------------------------===//
481
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)482 LogicalResult TopKOp::inferReturnTypeComponents(
483 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
484 DictionaryAttr attributes, RegionRange regions,
485 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
486 Builder builder(context);
487 TopKOp::Adaptor adaptor(operands, attributes, regions);
488 Value operand = adaptor.operand();
489 uint64_t k = adaptor.k();
490
491 auto operandTy = operand.getType().dyn_cast<RankedTensorType>();
492 if (!operandTy) {
493 return emitOptionalError(location, "operand must be ranked");
494 }
495 if (operandTy.getRank() < 1) {
496 return emitOptionalError(location, "operand's rank must be at least 1");
497 }
498 auto operandLastDim = operandTy.getShape()[operandTy.getRank() - 1];
499 if (operandLastDim == ShapedType::kDynamicSize) {
500 return emitOptionalError(location,
501 "operand's last dimension must be static");
502 }
503 if (operandLastDim < static_cast<int64_t>(k)) {
504 return emitOptionalError(location,
505 "operand's last dimension must be at least ", k);
506 }
507
508 SmallVector<int64_t> resultShape;
509 append_range(resultShape, operandTy.getShape());
510 resultShape[operandTy.getRank() - 1] = k;
511
512 inferredReturnShapes.emplace_back(resultShape, operandTy.getElementType());
513 inferredReturnShapes.emplace_back(resultShape, builder.getI32Type());
514 return success();
515 }
516
517 } // namespace chlo
518 } // namespace mlir
519
520 #define GET_OP_CLASSES
521 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
522
523 namespace mlir {
524 namespace chlo {
525
526 //===----------------------------------------------------------------------===//
527 // chlo Dialect Constructor
528 //===----------------------------------------------------------------------===//
529
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)530 Operation* ChloDialect::materializeConstant(OpBuilder& builder, Attribute value,
531 Type type, Location loc) {
532 // Mirror MHLO dialect here.
533 if (value.isa<ElementsAttr>())
534 return builder.create<mhlo::ConstantOp>(loc, type,
535 value.cast<ElementsAttr>());
536 return nullptr;
537 }
538
// Registers the dialect's operations. The op class list is produced by
// expanding the generated chlo_ops.cc.inc with GET_OP_LIST defined.
void ChloDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}
545
546 } // namespace chlo
547 } // namespace mlir
548