/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

namespace mlir {
namespace chlo {

Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
}

Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
                              bool negative) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
}

Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc,
                                         Value val) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
}

Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
                      Value val) {
  Type ty = getElementTypeOrSelf(val.getType());
  return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
}
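
// Example usage (illustrative sketch, not from this file): a lowering or
// canonicalization pattern can use these helpers to materialize scalar
// constants broadcast to an operand's shape. Here `rewriter`, `loc`, and
// `operand` are assumed to be provided by the surrounding pattern:
//
//   Value maxFinite = getConstantLikeMaxFiniteValue(rewriter, loc, operand);
//   Value negInf =
//       getConstantLikeInfValue(rewriter, loc, operand, /*negative=*/true);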

//===----------------------------------------------------------------------===//
// CompatibleOperandsAndResultType
//===----------------------------------------------------------------------===//

// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op)                        \
  LogicalResult Op::inferReturnTypeComponents(                                \
      MLIRContext* context, Optional<Location> location,                      \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      RegionRange regions,                                                    \
      SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {          \
    return inferReturnTypeComponentsFromOperands(context, location, operands, \
                                                 attributes, regions,         \
                                                 inferredReturnShapes);       \
  }
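
// For example, INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcosOp) below
// expands to a definition of AcosOp::inferReturnTypeComponents that simply
// forwards to inferReturnTypeComponentsFromOperands, so each listed unary op
// derives its result type components from its operands.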

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcosOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcoshOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(BesselI1eOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ConjOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CoshOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DigammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ZetaOp)

//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//

namespace {
// Gets the resulting type from a broadcast between two types.
ShapedTypeComponents getBroadcastType(
    Type x, Type y, Type elementType,
    DenseIntElementsAttr broadcastDimensionsAttr) {
  auto xRanked = x.dyn_cast<RankedTensorType>();
  auto yRanked = y.dyn_cast<RankedTensorType>();
  if (!xRanked || !yRanked) {
    return {elementType};
  }

  auto shapeX = xRanked.getShape();
  auto shapeY = yRanked.getShape();

  // If no broadcast dimensions, assume "numpy" broadcasting.
  if (shapeX.size() == shapeY.size() || !broadcastDimensionsAttr) {
    llvm::SmallVector<int64_t, 4> outShape;
    if (!mlir::OpTrait::util::getBroadcastedShape(shapeX, shapeY, outShape)) {
      // Signal illegal broadcast_dimensions as unranked.
      return {elementType};
    }
    return {outShape, elementType};
  }

  auto shapeLarge = shapeX.size() > shapeY.size() ? shapeX : shapeY;
  auto shapeSmall = shapeX.size() <= shapeY.size() ? shapeX : shapeY;

  auto broadcastDimensions = broadcastDimensionsAttr.getValues<APInt>();
  if (broadcastDimensions.size() != shapeSmall.size()) {
    // Signal illegal broadcast_dimensions as unranked.
    return {elementType};
  }

  llvm::SmallVector<int64_t, 4> shapeLargeFiltered;
  shapeLargeFiltered.reserve(shapeSmall.size());
  for (const auto& dim : broadcastDimensions) {
    if (dim.getZExtValue() >= shapeLarge.size()) return {elementType};
    shapeLargeFiltered.push_back(shapeLarge[dim.getZExtValue()]);
  }
  llvm::SmallVector<int64_t, 4> outShapeFiltered;
  if (!mlir::OpTrait::util::getBroadcastedShape(shapeSmall, shapeLargeFiltered,
                                                outShapeFiltered)) {
    // Signal illegal broadcast_dimensions as unranked.
    return {elementType};
  }

  // Update according to the broadcast dimensions.
  llvm::SmallVector<int64_t, 4> outShape(shapeLarge.begin(), shapeLarge.end());
  for (const auto& indexPair : llvm::enumerate(broadcastDimensions)) {
    auto newValue = outShapeFiltered[indexPair.index()];
    outShape[indexPair.value().getZExtValue()] = newValue;
  }

  return {outShape, elementType};
}
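
// Worked example (illustrative): broadcasting shapes [1, 4] and [2, 3, 4] with
// broadcast_dimensions = [1, 2] maps the smaller shape onto dimensions 1 and 2
// of the larger one, broadcasts [1, 4] against [3, 4], and yields the ranked
// result shape [2, 3, 4]. Without a broadcast_dimensions attribute, plain
// numpy-style prefix broadcasting of [3, 4] with [2, 3, 4] gives the same
// result.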

LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, Type elementType,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  // Find broadcast_dimensions.
  DenseIntElementsAttr broadcastDimensions =
      attributes.get("broadcast_dimensions")
          .dyn_cast_or_null<DenseIntElementsAttr>();

  ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
  ShapedType rhsType = operands[1].getType().dyn_cast<ShapedType>();
  if (!lhsType || !rhsType ||
      lhsType.getElementType() != rhsType.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }
  if (!elementType) elementType = lhsType.getElementType();
  inferredReturnShapes.push_back(
      getBroadcastType(lhsType, rhsType, elementType, broadcastDimensions));
  return success();
}

LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
    OpBuilder& builder, Operation* op, ValueRange operands,
    SmallVectorImpl<Value>& result) {
  assert(operands.size() == 2 && "expect binary op");
  auto loc = op->getLoc();
  auto lhs = operands[0];
  auto rhs = operands[1];

  // Check for "numpy"-style rank broadcast.
  auto broadcastDimensions = op->getAttr("broadcast_dimensions")
                                 .dyn_cast_or_null<DenseIntElementsAttr>();
  if (broadcastDimensions &&
      !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, broadcastDimensions)) {
    // Note: It is unclear whether the general specification of explicit
    // broadcast_dimensions on binary ops is a feature we want to carry
    // forward. While it can technically be implemented for ranked-dynamic,
    // it is incompatible with unranked inputs. If this warning is emitted
    // in real programs, it is an indication that the feature should be
    // implemented versus just falling back on the more standard definition
    // of numpy-like prefix-padding.
    return op->emitWarning()
           << "unsupported non prefix-padded dynamic rank "
           << "broadcast_dimensions = " << broadcastDimensions;
  }

  result.push_back(hlo::computeBinaryElementwiseBroadcastingResultExtents(
      loc, lhs, rhs, builder));
  return success();
}
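
// Illustrative note: for operands of ranks 1 and 3, the numpy-compatible
// (prefix-padded) broadcast_dimensions attribute is [2], mapping the
// lower-ranked operand onto the trailing dimension of the higher-ranked one;
// a mapping such as [0] is not prefix-padded and triggers the warning above.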
}  // namespace

//===----------------------------------------------------------------------===//
// BroadcastComplexOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//

LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange /*regions*/,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
  if (!lhsType) {
    return emitOptionalError(location, "expected ShapedType");
  }
  Type elementType = ComplexType::get(lhsType.getElementType());
  return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                    attributes, elementType,
                                                    inferredReturnShapes);
}

LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                operands, reifiedReturnShapes);
}
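
// Illustrative result typing: BroadcastComplexOp infers a complex element type
// from its float operands, e.g. operands of shapes [2, 4] and [4] with f32
// elements infer a [2, 4] result with complex<f32> elements.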

//===----------------------------------------------------------------------===//
// BroadcastCompareOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//

void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
                               Value lhs, Value rhs,
                               DenseIntElementsAttr broadcastDimensions,
                               mhlo::ComparisonDirection comparisonDirection,
                               mhlo::ComparisonType compareType) {
  build(builder, result, lhs, rhs, broadcastDimensions,
        mhlo::ComparisonDirectionAttr::get(builder.getContext(),
                                           comparisonDirection),
        mhlo::ComparisonTypeAttr::get(builder.getContext(), compareType));
}

LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange /*regions*/,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  Type elementType = IntegerType::get(context, 1);
  return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                    attributes, elementType,
                                                    inferredReturnShapes);
}

LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                operands, reifiedReturnShapes);
}
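
// Illustrative result typing: BroadcastCompareOp always infers i1 elements
// regardless of the operands' element type, e.g. comparing two [2, 3] f32
// operands infers a [2, 3] result of i1 values.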

//===----------------------------------------------------------------------===//
// IsInfOp
//===----------------------------------------------------------------------===//

static Type getIsInfLikeReturnType(Value operand) {
  Builder b(operand.getContext());
  return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
                                      b.getI1Type());
}

LogicalResult IsInfOp::inferReturnTypes(
    MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
    DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}

//===----------------------------------------------------------------------===//
// IsNegInfOp
//===----------------------------------------------------------------------===//

LogicalResult IsNegInfOp::inferReturnTypes(
    MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
    DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}

//===----------------------------------------------------------------------===//
// IsPosInfOp
//===----------------------------------------------------------------------===//

LogicalResult IsPosInfOp::inferReturnTypes(
    MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
    DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}

//===----------------------------------------------------------------------===//
// Macros for method definitions that are common to most broadcasting ops.
//===----------------------------------------------------------------------===//

#define BROADCAST_BINARY_OP_DEFS(Op)                                       \
  LogicalResult Op::inferReturnTypeComponents(                             \
      MLIRContext* context, Optional<Location> location,                   \
      ValueShapeRange operands, DictionaryAttr attributes,                 \
      RegionRange regions,                                                 \
      SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {       \
    return InferBroadcastBinaryOpReturnTypeComponents(                     \
        context, location, operands, attributes, /*element_type=*/nullptr, \
        inferredReturnShapes);                                             \
  }                                                                        \
  LogicalResult Op::reifyReturnTypeShapes(                                 \
      OpBuilder& builder, ValueRange operands,                             \
      SmallVectorImpl<Value>& reifiedReturnShapes) {                       \
    return ReifyBroadcastBinaryOpReturnTypeShapes(                         \
        builder, getOperation(), operands, reifiedReturnShapes);           \
  }

BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastNextAfterOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);

#undef BROADCAST_BINARY_OP_DEFS

LogicalResult ConstantLikeOp::verify() {
  if (value().getType() != getType().cast<ShapedType>().getElementType())
    return emitOpError() << "value's type doesn't match element return type";
  return success();
}

//===----------------------------------------------------------------------===//
// MinimumBroadcastShapesOp
//===----------------------------------------------------------------------===//
LogicalResult MinimumBroadcastShapesOp::verify() {
  // Check that the number of operands matches the number of outputs.
  unsigned resultShapesCount = results().size();
  unsigned operandShapesCount = shapes().size();
  if (operandShapesCount != resultShapesCount) {
    return emitOpError() << "number of operand shapes (" << operandShapesCount
                         << ") does not match number of result shapes ("
                         << resultShapesCount << ")";
  }
  if (operandShapesCount < 2) {
    return emitOpError() << "number of operand shapes (" << operandShapesCount
                         << ") should be >= 2";
  }
  return success();
}

LogicalResult ConstantLikeOp::inferReturnTypeComponents(
    MLIRContext* /*context*/, Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes,
    RegionRange /*regions*/,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  ConstantLikeOp::Adaptor op(operands, attributes);
  if (failed(op.verify(location.value()))) return failure();
  Type elementType = op.value().getType();
  Type operandType = op.operand().getType();
  if (operandType.isa<UnrankedTensorType>()) {
    inferredReturnShapes.emplace_back(elementType);
  } else {
    const auto& shape = operandType.cast<RankedTensorType>().getShape();
    inferredReturnShapes.emplace_back(shape, elementType);
  }
  return success();
}

LogicalResult ConstantLikeOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ::mlir::mhlo::deriveShapeFromOperand(
      &builder, getOperation(), operands.front(), &reifiedReturnShapes);
}

OpFoldResult ConstantLikeOp::fold(ArrayRef<Attribute> /*operands*/) {
  auto opType = operand().getType().cast<ShapedType>();
  if (!opType.hasStaticShape()) return {};
  auto type = RankedTensorType::get(opType.getShape(), value().getType());
  if (auto complexAttr = value().dyn_cast<complex::NumberAttr>())
    return DenseElementsAttr::get(type, complexAttr.getValue());
  return DenseElementsAttr::get(type, value());
}
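
// Illustrative fold: a chlo.constant_like with value 1.0 whose operand has the
// static type tensor<2x3xf32> folds to a splat DenseElementsAttr of 1.0 of
// type tensor<2x3xf32>; operands without a static shape are not folded.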

LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr, RegionRange,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  BroadcastSelectOp::Adaptor op(operands.getValues());
  auto predType = op.pred().getType().dyn_cast<ShapedType>();
  auto onTrueType = op.on_true().getType().dyn_cast<ShapedType>();
  auto onFalseType = op.on_false().getType().dyn_cast<ShapedType>();

  if (!predType || !onTrueType || !onFalseType ||
      onTrueType.getElementType() != onFalseType.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }

  Type elementType = onTrueType.getElementType();

  // Compute the result shape as two binary broadcasts.
  ShapedTypeComponents& components = inferredReturnShapes.emplace_back(
      getBroadcastType(onTrueType, onFalseType, elementType, nullptr));
  if (components.hasRank()) {
    components = getBroadcastType(
        RankedTensorType::get(components.getDims(), elementType), predType,
        elementType, nullptr);
  }
  return success();
}
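
// Worked example (illustrative): with pred of shape [1, 4], on_true of shape
// [2, 1], and on_false of shape [2, 4], the first broadcast of on_true and
// on_false yields [2, 4], and broadcasting that against pred again yields a
// [2, 4] result with the element type of the branches.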

LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands, SmallVectorImpl<Value>& result) {
  result.push_back(hlo::computeNaryElementwiseBroadcastingResultExtents(
      getLoc(), operands, builder));
  return success();
}

//===----------------------------------------------------------------------===//
// RankSpecializationClusterOp
//===----------------------------------------------------------------------===//

void RankSpecializationClusterOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> /*operands*/,
    SmallVectorImpl<RegionSuccessor>& regions) {
  // RankSpecializationClusterOp has unconditional control flow into the region
  // and back to the parent, so return the correct RegionSuccessor purely based
  // on the index being None or 0.
  if (index.has_value()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }
  regions.push_back(RegionSuccessor(&body()));
}

LogicalResult RankSpecializationClusterOp::verify() {
  if (body().getArgumentTypes() != getOperandTypes())
    return emitOpError() << "block argument types must match operand types";

  // All operands of nested ops must be defined in the body or declared by the
  // cluster.
  Block* body = getBody();
  for (Operation& nested : body->without_terminator()) {
    if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
          Operation* def = operand.get().getDefiningOp();
          if (def != nullptr && def->getBlock() == body) return true;
          return llvm::is_contained(body->getArguments(), operand.get());
        })) {
      return emitOpError() << "nested ops must not depend on implicit operands";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TopKOp
//===----------------------------------------------------------------------===//

LogicalResult TopKOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  Builder builder(context);
  TopKOp::Adaptor adaptor(operands, attributes, regions);
  Value operand = adaptor.operand();
  uint64_t k = adaptor.k();

  auto operandTy = operand.getType().dyn_cast<RankedTensorType>();
  if (!operandTy) {
    return emitOptionalError(location, "operand must be ranked");
  }
  if (operandTy.getRank() < 1) {
    return emitOptionalError(location, "operand's rank must be at least 1");
  }
  auto operandLastDim = operandTy.getShape()[operandTy.getRank() - 1];
  if (operandLastDim == ShapedType::kDynamicSize) {
    return emitOptionalError(location,
                             "operand's last dimension must be static");
  }
  if (operandLastDim < static_cast<int64_t>(k)) {
    return emitOptionalError(location,
                             "operand's last dimension must be at least ", k);
  }

  SmallVector<int64_t> resultShape;
  append_range(resultShape, operandTy.getShape());
  resultShape[operandTy.getRank() - 1] = k;

  inferredReturnShapes.emplace_back(resultShape, operandTy.getElementType());
  inferredReturnShapes.emplace_back(resultShape, builder.getI32Type());
  return success();
}
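
// Illustrative shape inference: for an operand of type tensor<3x8xf32> and
// k = 5, the inferred results are tensor<3x5xf32> values and tensor<3x5xi32>
// indices; a dynamic last dimension, or k larger than that dimension, is
// rejected above.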

}  // namespace chlo
}  // namespace mlir

#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"

namespace mlir {
namespace chlo {

//===----------------------------------------------------------------------===//
// chlo Dialect Constructor
//===----------------------------------------------------------------------===//

Operation* ChloDialect::materializeConstant(OpBuilder& builder, Attribute value,
                                            Type type, Location loc) {
  // Mirror MHLO dialect here.
  if (value.isa<ElementsAttr>())
    return builder.create<mhlo::ConstantOp>(loc, type,
                                            value.cast<ElementsAttr>());
  return nullptr;
}

void ChloDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}

}  // namespace chlo
}  // namespace mlir