1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering HLO dialect to LHLO dialect.
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
22 #include "mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
28 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
29 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
30 #include "mlir/Dialect/Func/IR/FuncOps.h"
31 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
32 #include "mlir/Dialect/MemRef/IR/MemRef.h"
33 #include "mlir/Dialect/Shape/IR/Shape.h"
34 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
35 #include "mlir/Dialect/Shape/Transforms/Passes.h"
36 #include "mlir/Dialect/Tensor/IR/Tensor.h"
37 #include "mlir/IR/AffineMap.h"
38 #include "mlir/IR/Attributes.h"
39 #include "mlir/IR/BlockAndValueMapping.h"
40 #include "mlir/IR/Builders.h"
41 #include "mlir/IR/BuiltinOps.h"
42 #include "mlir/IR/BuiltinTypes.h"
43 #include "mlir/IR/Location.h"
44 #include "mlir/IR/MLIRContext.h"
45 #include "mlir/IR/Operation.h"
46 #include "mlir/IR/PatternMatch.h"
47 #include "mlir/Pass/Pass.h"
48 #include "mlir/Transforms/DialectConversion.h"
49
50 namespace mlir {
51 namespace mhlo {
52 namespace {
53
54 template <typename T>
55 using BaseOpConversion = OpConversionPattern<T>;
56
insertDynamicAlloc(Location loc,Value result,Value shapeOperand,ConversionPatternRewriter * rewriter)57 Value insertDynamicAlloc(Location loc, Value result, Value shapeOperand,
58 ConversionPatternRewriter* rewriter) {
59 auto resultType = result.getType().dyn_cast<RankedTensorType>();
60 if (!resultType) {
61 result.getDefiningOp()->emitOpError()
62 << "tensor to buffer conversion expects ranked results";
63 }
64 auto memrefType =
65 MemRefType::get(resultType.getShape(), resultType.getElementType());
66
67 // Extract the required element out of the vector.
68 SmallVector<Value, 4> dynamicOperands;
69 for (const auto& shapeElement : llvm::enumerate(resultType.getShape())) {
70 if (shapeElement.value() != ShapedType::kDynamicSize) continue;
71 Value index =
72 rewriter->create<arith::ConstantIndexOp>(loc, shapeElement.index());
73 Value allocOperand =
74 rewriter->create<tensor::ExtractOp>(loc, shapeOperand, index);
75 if (!allocOperand.getType().isIndex()) {
76 allocOperand = rewriter->create<arith::IndexCastOp>(
77 loc, rewriter->getIndexType(), allocOperand);
78 }
79 dynamicOperands.push_back(allocOperand);
80 }
81
82 return rewriter->create<memref::AllocOp>(loc, memrefType, dynamicOperands);
83 }
84
insertAlloc(Location loc,OpResult result,ConversionPatternRewriter * rewriter)85 Value insertAlloc(Location loc, OpResult result,
86 ConversionPatternRewriter* rewriter) {
87 auto resultType = result.getType().dyn_cast<RankedTensorType>();
88 if (!resultType || !resultType.hasStaticShape()) {
89 result.getDefiningOp()->emitOpError()
90 << "tensor to buffer conversion expects statically shaped results";
91 }
92 auto memrefType =
93 MemRefType::get(resultType.getShape(), resultType.getElementType());
94 OpBuilder::InsertionGuard guard(*rewriter);
95 rewriter->setInsertionPoint(result.getDefiningOp());
96 auto alloc = rewriter->create<memref::AllocOp>(loc, memrefType);
97 return alloc;
98 }
99
100 /// Converts the results of the operation `op` to memref types and append them
101 /// to the `results` vector.
convertResults(Operation * op,SmallVectorImpl<Value> & results,ConversionPatternRewriter & rewriter)102 LogicalResult convertResults(Operation* op, SmallVectorImpl<Value>& results,
103 ConversionPatternRewriter& rewriter) {
104 size_t numOperands = results.size();
105 SmallVector<Value, 2> tensorOperands;
106 for (const auto& result : llvm::enumerate(op->getResults())) {
107 RankedTensorType resultType =
108 result.value().getType().dyn_cast<RankedTensorType>();
109 if (!resultType) return failure();
110
111 if (resultType.hasStaticShape()) {
112 results.push_back(insertAlloc(op->getLoc(), result.value(), &rewriter));
113 continue;
114 }
115 auto shapeTypeOp = dyn_cast<InferShapedTypeOpInterface>(op);
116 if (!shapeTypeOp) return failure();
117
118 if (tensorOperands.empty()) {
119 for (auto operand : ArrayRef<Value>(results).take_front(numOperands)) {
120 auto operandType = operand.getType().dyn_cast<MemRefType>();
121 if (!operandType) return failure();
122 tensorOperands.push_back(rewriter.create<bufferization::ToTensorOp>(
123 op->getLoc(),
124 RankedTensorType::get(operandType.getShape(),
125 operandType.getElementType()),
126 operand));
127 }
128 }
129
130 SmallVector<Value, 1> resultsShape;
131 auto status = shapeTypeOp.reifyReturnTypeShapes(rewriter, tensorOperands,
132 resultsShape);
133 if (failed(status)) return failure();
134 results.push_back(insertDynamicAlloc(
135 op->getLoc(), result.value(), resultsShape[result.index()], &rewriter));
136 }
137 return success();
138 }
139
140 template <typename HloOpTy>
141 class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
142 public:
143 using BaseOpConversion<HloOpTy>::BaseOpConversion;
matchAndRewrite(HloOpTy hloOp,typename HloOpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const144 LogicalResult matchAndRewrite(
145 HloOpTy hloOp, typename HloOpTy::Adaptor adaptor,
146 ConversionPatternRewriter& rewriter) const final {
147 Operation* op = hloOp.getOperation();
148 SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
149 if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
150 rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
151 bufferArgs, op->getAttrs());
152 rewriter.replaceOp(op, llvm::makeArrayRef(bufferArgs)
153 .drop_front(adaptor.getOperands().size()));
154 return success();
155 }
156 };
157
158 // This specialization exists so that LMHLO's Dot can be given a specific set of
159 // dimension numbers, when lowering from MHLO's Dot, which does not have
160 // dimension numbers (it uses DotGeneral for this generalized notion of dot
161 // products). When these two dialects are in sync with respect to the
162 // Dot/DotGeneral issue, this specialization should be deleted.
163 template <>
164 class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
165 public:
166 using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
matchAndRewrite(mhlo::DotOp hloOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const167 LogicalResult matchAndRewrite(
168 mhlo::DotOp hloOp, OpAdaptor adaptor,
169 ConversionPatternRewriter& rewriter) const final {
170 Operation* op = hloOp.getOperation();
171 SmallVector<Value, 2> bufferArgs(adaptor.getOperands());
172 if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
173
174 auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
175 bufferArgs, op->getAttrs());
176 // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
177 auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get(
178 rewriter.getContext(), /*lhsBatchingDimensions=*/{},
179 /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{1},
180 /*rhsContractingDimensions=*/{0});
181 dotOp.setDotDimensionNumbersAttr(dimensionNumbers);
182 rewriter.replaceOp(
183 op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));
184 return success();
185 }
186 };
187
188 struct HloToLhloCustomCallOpConverter
189 : public BaseOpConversion<mhlo::CustomCallOp> {
190 public:
191 using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
192
matchAndRewritemlir::mhlo::__anonc86515690111::HloToLhloCustomCallOpConverter193 LogicalResult matchAndRewrite(
194 mhlo::CustomCallOp hloOp, OpAdaptor adaptor,
195 ConversionPatternRewriter& rewriter) const final {
196 Operation* op = hloOp.getOperation();
197 SmallVector<Value, 2> bufferArgs(adaptor.getOperands());
198 if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
199
200 auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
201 op->getLoc(), llvm::None, bufferArgs, op->getAttrs());
202 // Setup AttrSizedOperandSegments attribute to indicate number of operands
203 // for args and outputs.
204 const int32_t segments[2] = {
205 static_cast<int32_t>(adaptor.getOperands().size()),
206 static_cast<int32_t>(op->getNumResults())};
207 lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
208 rewriter.getDenseI32ArrayAttr(segments));
209
210 rewriter.replaceOp(
211 op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));
212 return success();
213 }
214 };
215
216 struct HloToLhloDotGeneralOpConverter
217 : public BaseOpConversion<mhlo::DotGeneralOp> {
218 using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
matchAndRewritemlir::mhlo::__anonc86515690111::HloToLhloDotGeneralOpConverter219 LogicalResult matchAndRewrite(
220 mhlo::DotGeneralOp dotGeneralOp, OpAdaptor adaptor,
221 ConversionPatternRewriter& rewriter) const final {
222 Operation* op = dotGeneralOp.getOperation();
223
224 if (op->getResults().empty()) return failure();
225 OpResult result = op->getResults()[0];
226 RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
227 if (!resultType) return failure();
228
229 // The third buffer argument will be filled with what used to be the return
230 // type of the DotGeneral.
231 if (adaptor.getOperands().size() != 2) return failure();
232 std::array<Value, 3> bufferArgs = {
233 adaptor.getOperands()[0], adaptor.getOperands()[1], {}};
234
235 if (resultType.hasStaticShape()) {
236 bufferArgs[2] = insertAlloc(op->getLoc(), result, &rewriter);
237 } else {
238 SmallVector<Value, 1> resultsShape;
239 auto shapeTypeOp = dyn_cast<InferShapedTypeOpInterface>(op);
240 if (failed(shapeTypeOp.reifyReturnTypeShapes(
241 rewriter, adaptor.getOperands(), resultsShape)))
242 return failure();
243
244 bufferArgs[2] = insertDynamicAlloc(op->getLoc(), result,
245 resultsShape.front(), &rewriter);
246 }
247
248 rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
249 op->getAttrs());
250 rewriter.replaceOp(op, bufferArgs[2]);
251 return success();
252 }
253 };
254
255 template <typename HloOpTy>
256 struct HloToLhloReduceLikeOpConverter : public BaseOpConversion<HloOpTy> {
257 public:
258 using BaseOpConversion<HloOpTy>::BaseOpConversion;
259
matchAndRewritemlir::mhlo::__anonc86515690111::HloToLhloReduceLikeOpConverter260 LogicalResult matchAndRewrite(
261 HloOpTy hloOp, typename HloOpTy::Adaptor adaptor,
262 ConversionPatternRewriter& rewriter) const final {
263 Operation* op = hloOp.getOperation();
264 auto loc = op->getLoc();
265 if (!llvm::hasSingleElement(hloOp.body())) {
266 return op->emitOpError()
267 << "tensor to buffer conversion expects a single block "
268 "in the region containing the operation";
269 }
270 SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
271 if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
272 auto newOp = rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(
273 loc, llvm::None, bufferArgs, op->getAttrs());
274
275 // Copy over the operations inside the region.
276 rewriter.inlineRegionBefore(hloOp.body(), newOp.getBody(),
277 newOp.getBody().end());
278
279 // Convert the region signature to memref and add extra result.
280 auto& entryBlock = newOp.getBody().front();
281 TypeConverter::SignatureConversion sigConversion(
282 adaptor.getOperands().size());
283 for (auto arg : entryBlock.getArguments()) {
284 auto oldType = arg.getType().template cast<TensorType>();
285 auto newType =
286 MemRefType::get(oldType.getShape(), oldType.getElementType());
287 sigConversion.addInputs(arg.getArgNumber(), newType);
288 }
289 auto returnOp = cast<mhlo::ReturnOp>(entryBlock.getTerminator());
290 if (auto tupleTy = returnOp.results()
291 .front()
292 .getType()
293 .template dyn_cast<TupleType>()) {
294 auto* tupleOp = returnOp.getODSOperands(0).front().getDefiningOp();
295 returnOp.getOperation()->dropAllReferences();
296 rewriter.eraseOp(tupleOp);
297 returnOp.getOperation()->setOperands(tupleOp->getOperands());
298 for (auto ty : tupleTy) {
299 auto tensorTy = ty.template cast<TensorType>();
300 sigConversion.addInputs(
301 MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()));
302 }
303 } else {
304 for (auto result : returnOp.results()) {
305 auto resultType = result.getType().template cast<TensorType>();
306 sigConversion.addInputs({MemRefType::get(resultType.getShape(),
307 resultType.getElementType())});
308 }
309 }
310 rewriter.applySignatureConversion(&newOp.getBody(), sigConversion);
311
312 rewriter.replaceOp(
313 op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));
314
315 return success();
316 }
317 };
318
319 // Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
320 struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
321 public:
322 using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
323
matchAndRewritemlir::mhlo::__anonc86515690111::HloToLhloReturnOpConverter324 LogicalResult matchAndRewrite(
325 mhlo::ReturnOp op, OpAdaptor adaptor,
326 ConversionPatternRewriter& rewriter) const final {
327 auto loc = op.getLoc();
328 auto& entryBlock = op->getParentRegion()->front();
329 auto numArguments = entryBlock.getNumArguments();
330 if (adaptor.getOperands().size() > numArguments) {
331 return op.emitError(
332 "The number of operands that need Copy operations is more "
333 "than the number of target function arguments.");
334 }
335
336 // The index of the first output block argument.
337 auto destArgIdx = numArguments - adaptor.getOperands().size();
338
339 // Create a lmhlo.copy for each operand of mhlo.return.
340 for (Value operand : adaptor.getOperands()) {
341 rewriter.create<lmhlo::CopyOp>(loc, operand,
342 entryBlock.getArgument(destArgIdx));
343 ++destArgIdx;
344 }
345 rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
346 return success();
347 }
348 };
349
350 // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
351 // buffers if necessary.
352 //
353 // Example fusion with HLO ops.
354 //
355 // func @fusion(%arg0: memref<2x2xf32>,
356 // %arg1: memref<2x2xf32>,
357 // %arg2: memref<2x2xf32>,
358 // %arg3: memref<2x2xf32>) {
359 // "lmhlo.fusion"() ({
360 // %0 = bufferization.to_tensor %arg1 : memref<2x2xf32>
361 // %1 = bufferization.to_tensor %arg2 : memref<2x2xf32>
362 // %2 = "mhlo.add"(%0, %1) :
363 // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
364 // %3 = bufferization.to_tensor %arg0 : memref<2x2xf32>
365 // %4 = "mhlo.multiply"(%2, %3) :
366 // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
367 // tensor_store %4, %arg3 : memref<2x2xf32>
368 // "lmhlo.terminator"() : () -> ()
369 // }) : () -> ()
370 // return
371 // }
372 //
373 // Transformed fusion with LHLO ops.
374 // func @fusion(%arg0: memref<2x2xf32>,
375 // %arg1: memref<2x2xf32>,
376 // %arg2: memref<2x2xf32>,
377 // %arg3: memref<2x2xf32>) {
378 // "lmhlo.fusion"() ({
379 // %0 = alloc() : memref<2x2xf32>
380 // "lmhlo.add"(%arg1, %arg2, %0) :
381 // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
382 // "lmhlo.multiply"(%0, %arg0, %arg3) :
383 // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
384 // "lmhlo.terminator"() : () -> ()
385 // }) : () -> ()
386 // return
387 // }
388 //
389 // FuncOp signature conversion example:
390 //
391 // func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "mhlo.maximum"(%arg0, %arg1) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %1 = "mhlo.add"(%arg0, %0) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   return %1 : tensor<4xf32>
395 // }
396 //
397 // Transformed function with an extra argument for the result. The types have
398 // been converted from tensor to memref.
399 //
400 // func @func_op(%arg0: memref<4xf32>,
401 // %arg1: memref<4xf32>,
402 // %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//   "lmhlo.maximum"(%arg0, %arg1, %0) :
406 // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
407 // %1 = alloc() : memref<4xf32>
408 // "lmhlo.add"(%arg0, %0, %1) :
409 // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
410 // "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
411 // "lmhlo.terminator"() : () -> ()
412 // }
413
414 struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
415 using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;
416
getDependentDialectsmlir::mhlo::__anonc86515690111::HloLegalizeToLhlo417 void getDependentDialects(DialectRegistry& registry) const override {
418 registry.insert<bufferization::BufferizationDialect, lmhlo::LmhloDialect,
419 memref::MemRefDialect, shape::ShapeDialect>();
420 shape::registerBufferizableOpInterfaceExternalModels(registry);
421 }
422
423 public:
424 HloLegalizeToLhlo() = default;
425
runOpInterfaceBufferizationmlir::mhlo::__anonc86515690111::HloLegalizeToLhlo426 LogicalResult runOpInterfaceBufferization() {
427 // Bufferize ops using BufferizableOpInterface. This could be switched to
428 // One-Shot Bufferize in the future.
429 RewritePatternSet patterns(&getContext());
430 bufferization::BufferizationOptions options =
431 bufferization::getPartialBufferizationOptions();
432 // TODO(springerm): Add dialects to this filter as more and more dialects
433 // will be migrated to BufferizableOpInterface-based bufferization.
434 options.opFilter.allowDialect<shape::ShapeDialect>();
435 return bufferization::bufferizeOp(getOperation(), options);
436 }
437
runOnOperationmlir::mhlo::__anonc86515690111::HloLegalizeToLhlo438 void runOnOperation() override {
439 if (failed(runOpInterfaceBufferization())) {
440 signalPassFailure();
441 return;
442 }
443
444 auto& context = getContext();
445 RewritePatternSet patterns(&context);
446 ConversionTarget target(context);
447 target.addLegalDialect<
448 arith::ArithmeticDialect, bufferization::BufferizationDialect,
449 lmhlo::LmhloDialect, memref::MemRefDialect, shape::ShapeDialect,
450 func::FuncDialect, tensor::TensorDialect>();
451 target.addIllegalDialect<mhlo::MhloDialect>();
452 // bufferization.to_memref is illegal if it has uses.
453 // TODO(b/175670649) Make bufferization.to_memref illegal.
454 target.addDynamicallyLegalOp<mlir::bufferization::ToMemrefOp>(
455 [](auto op) { return op->use_empty(); });
456
457 bufferization::BufferizeTypeConverter converter;
458 auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
459 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
460 return converter.isSignatureLegal(op.getFunctionType()) &&
461 converter.isLegal(&op.getBody());
462 });
463 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
464 return std::all_of(op.operand_type_begin(), op.operand_type_end(),
465 isMemRefType) &&
466 std::all_of(op.result_type_begin(), op.result_type_end(),
467 isMemRefType);
468 });
469 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
470 [&](mlir::func::ReturnOp op) {
471 return std::all_of(op.operand_type_begin(), op.operand_type_end(),
472 isMemRefType);
473 });
474
475 populateHloToLhloConversionPattern(&context, &converter, &patterns);
476 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
477 converter);
478 populateCallOpTypeConversionPattern(patterns, converter);
479 populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
480 populateReturnOpTypeConversionPattern(patterns, converter);
481 populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
482
483 if (failed(applyPartialConversion(getOperation(), target,
484 std::move(patterns))))
485 signalPassFailure();
486 }
487 };
488 } // namespace
489
490 // Simply lowers all mhlo ops to their lmhlo counterparts.
populateDynamicHloToLhloConversionPattern(MLIRContext * context,bufferization::BufferizeTypeConverter * converter,RewritePatternSet * patterns)491 void populateDynamicHloToLhloConversionPattern(
492 MLIRContext* context, bufferization::BufferizeTypeConverter* converter,
493 RewritePatternSet* patterns) {
494 // clang-format off
495 patterns->add<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
496 HloToLhloOpConverter<mhlo::DynamicGatherOp>,
497 HloToLhloOpConverter<mhlo::DynamicIotaOp>,
498 HloToLhloOpConverter<mhlo::DynamicPadOp>,
499 HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
500 HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
501 >(*converter, context);
502 // clang-format on
503 }
504
populateHloToLhloConversionPattern(MLIRContext * context,bufferization::BufferizeTypeConverter * converter,RewritePatternSet * patterns)505 void populateHloToLhloConversionPattern(
506 MLIRContext* context, bufferization::BufferizeTypeConverter* converter,
507 RewritePatternSet* patterns) {
508 populateDynamicHloToLhloConversionPattern(context, converter, patterns);
509
510 // clang-format off
511 patterns->add<
512 HloToLhloCustomCallOpConverter,
513 HloToLhloDotGeneralOpConverter,
514 HloToLhloOpConverter<mhlo::AbsOp>,
515 HloToLhloOpConverter<mhlo::AddOp>,
516 HloToLhloOpConverter<mhlo::AndOp>,
517 HloToLhloOpConverter<mhlo::Atan2Op>,
518 HloToLhloOpConverter<mhlo::BatchNormGradOp>,
519 HloToLhloOpConverter<mhlo::BatchNormTrainingOp>,
520 HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
521 HloToLhloOpConverter<mhlo::CeilOp>,
522 HloToLhloOpConverter<mhlo::ClampOp>,
523 HloToLhloOpConverter<mhlo::CompareOp>,
524 HloToLhloOpConverter<mhlo::ComplexOp>,
525 HloToLhloOpConverter<mhlo::ConcatenateOp>,
526 HloToLhloOpConverter<mhlo::ConstantOp>,
527 HloToLhloOpConverter<mhlo::ConvolutionOp>,
528 HloToLhloOpConverter<mhlo::ConvertOp>,
529 HloToLhloOpConverter<mhlo::CopyOp>,
530 HloToLhloOpConverter<mhlo::CosineOp>,
531 HloToLhloOpConverter<mhlo::DivOp>,
532 HloToLhloOpConverter<mhlo::DotOp>,
533 HloToLhloOpConverter<mhlo::ExpOp>,
534 HloToLhloOpConverter<mhlo::Expm1Op>,
535 HloToLhloOpConverter<mhlo::FloorOp>,
536 HloToLhloOpConverter<mhlo::GatherOp>,
537 HloToLhloOpConverter<mhlo::ImagOp>,
538 HloToLhloOpConverter<mhlo::IotaOp>,
539 HloToLhloOpConverter<mhlo::IsFiniteOp>,
540 HloToLhloOpConverter<mhlo::LogOp>,
541 HloToLhloOpConverter<mhlo::LogisticOp>,
542 HloToLhloOpConverter<mhlo::MaxOp>,
543 HloToLhloOpConverter<mhlo::MinOp>,
544 HloToLhloOpConverter<mhlo::MulOp>,
545 HloToLhloOpConverter<mhlo::NegOp>,
546 HloToLhloOpConverter<mhlo::NotOp>,
547 HloToLhloOpConverter<mhlo::OrOp>,
548 HloToLhloOpConverter<mhlo::PowOp>,
549 HloToLhloOpConverter<mhlo::RealOp>,
550 HloToLhloOpConverter<mhlo::RemOp>,
551 HloToLhloOpConverter<mhlo::RsqrtOp>,
552 HloToLhloOpConverter<mhlo::ReshapeOp>,
553 HloToLhloOpConverter<mhlo::SelectOp>,
554 HloToLhloOpConverter<mhlo::ShiftLeftOp>,
555 HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
556 HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
557 HloToLhloOpConverter<mhlo::SignOp>,
558 HloToLhloOpConverter<mhlo::SineOp>,
559 HloToLhloOpConverter<mhlo::SliceOp>,
560 HloToLhloOpConverter<mhlo::SqrtOp>,
561 HloToLhloOpConverter<mhlo::SubtractOp>,
562 HloToLhloOpConverter<mhlo::TanhOp>,
563 HloToLhloOpConverter<mhlo::TransposeOp>,
564 HloToLhloOpConverter<mhlo::XorOp>,
565 HloToLhloReduceLikeOpConverter<mhlo::ReduceOp>,
566 HloToLhloReduceLikeOpConverter<mhlo::ReduceWindowOp>,
567 HloToLhloReturnOpConverter
568 >(*converter, context);
569 // clang-format on
570 }
571
createLegalizeToLhloPass()572 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
573 return std::make_unique<HloLegalizeToLhlo>();
574 }
575
576 } // namespace mhlo
577 } // namespace mlir
578