/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the HLO dialect to the LHLO dialect.

#include <algorithm>
#include <utility>

#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {
namespace {

template <typename T>
using BaseOpConversion = OpConversionPattern<T>;

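// Allocates a dynamically shaped memref for `result`, reading the dynamic
// extents out of `shapeOperand` (a rank-1 tensor of extents produced by shape
// reification).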
Value insertDynamicAlloc(Location loc, Value result, Value shapeOperand,
                         ConversionPatternRewriter* rewriter) {
  auto resultType = result.getType().dyn_cast<RankedTensorType>();
  if (!resultType) {
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects ranked results";
    return {};
  }
  auto memrefType =
      MemRefType::get(resultType.getShape(), resultType.getElementType());

  // Collect one extent from `shapeOperand` for each dynamic dimension.
  SmallVector<Value, 4> dynamicOperands;
  for (const auto& shapeElement : llvm::enumerate(resultType.getShape())) {
    if (shapeElement.value() != ShapedType::kDynamicSize) continue;
    Value index =
        rewriter->create<arith::ConstantIndexOp>(loc, shapeElement.index());
    Value allocOperand =
        rewriter->create<tensor::ExtractOp>(loc, shapeOperand, index);
    if (!allocOperand.getType().isIndex()) {
      allocOperand = rewriter->create<arith::IndexCastOp>(
          loc, rewriter->getIndexType(), allocOperand);
    }
    dynamicOperands.push_back(allocOperand);
  }

  return rewriter->create<memref::AllocOp>(loc, memrefType, dynamicOperands);
}

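// Allocates a statically shaped memref matching the shape and element type of
// `result`. The allocation is inserted right before the defining op.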
Value insertAlloc(Location loc, OpResult result,
                  ConversionPatternRewriter* rewriter) {
  auto resultType = result.getType().dyn_cast<RankedTensorType>();
  if (!resultType || !resultType.hasStaticShape()) {
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects statically shaped results";
    return {};
  }
  auto memrefType =
      MemRefType::get(resultType.getShape(), resultType.getElementType());
  OpBuilder::InsertionGuard guard(*rewriter);
  rewriter->setInsertionPoint(result.getDefiningOp());
  auto alloc = rewriter->create<memref::AllocOp>(loc, memrefType);
  return alloc;
}

/// Converts the results of the operation `op` to memref types and appends them
/// to the `results` vector.
LogicalResult convertResults(Operation* op, SmallVectorImpl<Value>& results,
                             ConversionPatternRewriter& rewriter) {
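  // On entry, `results` holds the already-converted operands of `op`; the
  // allocated result buffers are appended after them.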
  size_t numOperands = results.size();
  SmallVector<Value, 2> tensorOperands;
  for (const auto& result : llvm::enumerate(op->getResults())) {
    RankedTensorType resultType =
        result.value().getType().dyn_cast<RankedTensorType>();
    if (!resultType) return failure();

    if (resultType.hasStaticShape()) {
      results.push_back(insertAlloc(op->getLoc(), result.value(), &rewriter));
      continue;
    }
    auto shapeTypeOp = dyn_cast<InferShapedTypeOpInterface>(op);
    if (!shapeTypeOp) return failure();

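    // Reify the result shapes from tensor views of the operand buffers. This
    // is done lazily, only once a dynamically shaped result is encountered.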
    if (tensorOperands.empty()) {
      for (auto operand : ArrayRef<Value>(results).take_front(numOperands)) {
        auto operandType = operand.getType().dyn_cast<MemRefType>();
        if (!operandType) return failure();
        tensorOperands.push_back(rewriter.create<bufferization::ToTensorOp>(
            op->getLoc(),
            RankedTensorType::get(operandType.getShape(),
                                  operandType.getElementType()),
            operand));
      }
    }

    SmallVector<Value, 1> resultsShape;
    auto status = shapeTypeOp.reifyReturnTypeShapes(rewriter, tensorOperands,
                                                    resultsShape);
    if (failed(status)) return failure();
    results.push_back(insertDynamicAlloc(
        op->getLoc(), result.value(), resultsShape[result.index()], &rewriter));
  }
  return success();
}

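// Generic converter: lowers an HLO op to its LHLO counterpart by allocating
// output buffers, passing them as trailing operands, and replacing the HLO
// results with those buffers. For example,
//
//   %0 = "mhlo.add"(%lhs, %rhs)
//       : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//
// becomes
//
//   %out = memref.alloc() : memref<2x2xf32>
//   "lmhlo.add"(%lhs, %rhs, %out)
//       : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()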
template <typename HloOpTy>
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
 public:
  using BaseOpConversion<HloOpTy>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      HloOpTy hloOp, typename HloOpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
    if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
    rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                bufferArgs, op->getAttrs());
    rewriter.replaceOp(op, llvm::makeArrayRef(bufferArgs)
                               .drop_front(adaptor.getOperands().size()));
    return success();
  }
};

// This specialization exists so that LMHLO's Dot can be given a specific set
// of dimension numbers, when lowering from MHLO's Dot, which does not have
// dimension numbers (it uses DotGeneral for this generalized notion of dot
// products). When these two dialects are in sync with respect to the
// Dot/DotGeneral issue, this specialization should be deleted.
template <>
class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
 public:
  using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      mhlo::DotOp hloOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    SmallVector<Value, 2> bufferArgs(adaptor.getOperands());
    if (failed(convertResults(op, bufferArgs, rewriter))) return failure();

    auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
                                               bufferArgs, op->getAttrs());
    // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
    auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get(
        rewriter.getContext(), /*lhsBatchingDimensions=*/{},
        /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{1},
        /*rhsContractingDimensions=*/{0});
    dotOp.setDotDimensionNumbersAttr(dimensionNumbers);
    rewriter.replaceOp(
        op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));
    return success();
  }
};

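// Lowers mhlo.custom_call to lmhlo.custom_call, annotating the new op with an
// operand-segment attribute that records how many operands are inputs and how
// many are output buffers.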
struct HloToLhloCustomCallOpConverter
    : public BaseOpConversion<mhlo::CustomCallOp> {
 public:
  using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::CustomCallOp hloOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    SmallVector<Value, 2> bufferArgs(adaptor.getOperands());
    if (failed(convertResults(op, bufferArgs, rewriter))) return failure();

    auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
        op->getLoc(), llvm::None, bufferArgs, op->getAttrs());
    // Set up the AttrSizedOperandSegments attribute to indicate the number of
    // operands for args and outputs.
    const int32_t segments[2] = {
        static_cast<int32_t>(adaptor.getOperands().size()),
        static_cast<int32_t>(op->getNumResults())};
    lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
                    rewriter.getDenseI32ArrayAttr(segments));

    rewriter.replaceOp(
        op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));
    return success();
  }
};

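// Lowers mhlo.dot_general to lmhlo.dot. The two operand buffers are forwarded
// and a third buffer is allocated to hold the result.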
struct HloToLhloDotGeneralOpConverter
    : public BaseOpConversion<mhlo::DotGeneralOp> {
  using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp dotGeneralOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = dotGeneralOp.getOperation();

    if (op->getResults().empty()) return failure();
    OpResult result = op->getResults()[0];
    RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
    if (!resultType) return failure();

    // The third buffer argument holds the result; its type is derived from
    // what used to be the result type of the DotGeneral.
    if (adaptor.getOperands().size() != 2) return failure();
    std::array<Value, 3> bufferArgs = {
        adaptor.getOperands()[0], adaptor.getOperands()[1], {}};

    if (resultType.hasStaticShape()) {
      bufferArgs[2] = insertAlloc(op->getLoc(), result, &rewriter);
    } else {
      SmallVector<Value, 1> resultsShape;
      auto shapeTypeOp = dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeTypeOp) return failure();
      if (failed(shapeTypeOp.reifyReturnTypeShapes(
              rewriter, adaptor.getOperands(), resultsShape)))
        return failure();

      bufferArgs[2] = insertDynamicAlloc(op->getLoc(), result,
                                         resultsShape.front(), &rewriter);
    }

    rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
                                  op->getAttrs());
    rewriter.replaceOp(op, bufferArgs[2]);
    return success();
  }
};

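// Lowers reduce-like ops (mhlo.reduce, mhlo.reduce_window) that carry a
// single-block reduction region. The region is moved onto the new LHLO op and
// its signature is converted from tensors to memrefs.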
template <typename HloOpTy>
struct HloToLhloReduceLikeOpConverter : public BaseOpConversion<HloOpTy> {
 public:
  using BaseOpConversion<HloOpTy>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      HloOpTy hloOp, typename HloOpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    auto loc = op->getLoc();
    if (!llvm::hasSingleElement(hloOp.body())) {
      return op->emitOpError()
             << "tensor to buffer conversion expects a single block "
                "in the region containing the operation";
    }
    SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
    if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
    auto newOp = rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(
        loc, llvm::None, bufferArgs, op->getAttrs());

    // Move the operations from the old region into the new op's region.
    rewriter.inlineRegionBefore(hloOp.body(), newOp.getBody(),
                                newOp.getBody().end());

    // Convert the region signature to memref types and add extra block
    // arguments for the results.
    auto& entryBlock = newOp.getBody().front();
    TypeConverter::SignatureConversion sigConversion(
        adaptor.getOperands().size());
    for (auto arg : entryBlock.getArguments()) {
      auto oldType = arg.getType().template cast<TensorType>();
      auto newType =
          MemRefType::get(oldType.getShape(), oldType.getElementType());
      sigConversion.addInputs(arg.getArgNumber(), newType);
    }
    auto returnOp = cast<mhlo::ReturnOp>(entryBlock.getTerminator());
    if (auto tupleTy = returnOp.results()
                           .front()
                           .getType()
                           .template dyn_cast<TupleType>()) {
      auto* tupleOp = returnOp.getODSOperands(0).front().getDefiningOp();
      returnOp.getOperation()->dropAllReferences();
      returnOp.getOperation()->setOperands(tupleOp->getOperands());
      rewriter.eraseOp(tupleOp);
      for (auto ty : tupleTy) {
        auto tensorTy = ty.template cast<TensorType>();
        sigConversion.addInputs(
            MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()));
      }
    } else {
      for (auto result : returnOp.results()) {
        auto resultType = result.getType().template cast<TensorType>();
        sigConversion.addInputs({MemRefType::get(resultType.getShape(),
                                                 resultType.getElementType())});
      }
    }
    rewriter.applySignatureConversion(&newOp.getBody(), sigConversion);

    rewriter.replaceOp(
        op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));

    return success();
  }
};

// Legalizes mhlo.return to lmhlo.copy ops and an lmhlo.terminator.
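// Each returned value is copied into the corresponding trailing block
// argument, which holds the output buffer.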
struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
 public:
  using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::ReturnOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto& entryBlock = op->getParentRegion()->front();
    auto numArguments = entryBlock.getNumArguments();
    if (adaptor.getOperands().size() > numArguments) {
      return op.emitError(
          "The number of operands that need Copy operations is more "
          "than the number of target function arguments.");
    }

    // The index of the first output block argument.
    auto destArgIdx = numArguments - adaptor.getOperands().size();

    // Create an lmhlo.copy for each operand of mhlo.return.
    for (Value operand : adaptor.getOperands()) {
      rewriter.create<lmhlo::CopyOp>(loc, operand,
                                     entryBlock.getArgument(destArgIdx));
      ++destArgIdx;
    }
    rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
    return success();
  }
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
// Example fusion with HLO ops.
//
// func @fusion(%arg0: memref<2x2xf32>,
//              %arg1: memref<2x2xf32>,
//              %arg2: memref<2x2xf32>,
//              %arg3: memref<2x2xf32>) {
//   "lmhlo.fusion"() ({
//     %0 = bufferization.to_tensor %arg1 : memref<2x2xf32>
//     %1 = bufferization.to_tensor %arg2 : memref<2x2xf32>
//     %2 = "mhlo.add"(%0, %1) :
//         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//     %3 = bufferization.to_tensor %arg0 : memref<2x2xf32>
//     %4 = "mhlo.multiply"(%2, %3) :
//         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//     tensor_store %4, %arg3 : memref<2x2xf32>
//     "lmhlo.terminator"() : () -> ()
//   }) : () -> ()
//   return
// }
//
// Transformed fusion with LHLO ops.
//
// func @fusion(%arg0: memref<2x2xf32>,
//              %arg1: memref<2x2xf32>,
//              %arg2: memref<2x2xf32>,
//              %arg3: memref<2x2xf32>) {
//   "lmhlo.fusion"() ({
//     %0 = alloc() : memref<2x2xf32>
//     "lmhlo.add"(%arg1, %arg2, %0) :
//         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//     "lmhlo.multiply"(%0, %arg0, %arg3) :
//         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//     "lmhlo.terminator"() : () -> ()
//   }) : () -> ()
//   return
// }
//
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "mhlo.maximum"(%arg0, %arg1) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %1 = "mhlo.add"(%arg0, %0) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   return %1 : tensor<4xf32>
// }
//
// Transformed function with an extra argument for the result. The types have
// been converted from tensor to memref.
//
// func @func_op(%arg0: memref<4xf32>,
//               %arg1: memref<4xf32>,
//               %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//   "lmhlo.maximum"(%arg0, %arg1, %0) :
//         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   %1 = alloc() : memref<4xf32>
//   "lmhlo.add"(%arg0, %0, %1) :
//         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
//   "lmhlo.terminator"() : () -> ()
// }

struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
  using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;

  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<bufferization::BufferizationDialect, lmhlo::LmhloDialect,
                    memref::MemRefDialect, shape::ShapeDialect>();
    shape::registerBufferizableOpInterfaceExternalModels(registry);
  }

 public:
  HloLegalizeToLhlo() = default;

  LogicalResult runOpInterfaceBufferization() {
    // Bufferize ops using BufferizableOpInterface. This could be switched to
    // One-Shot Bufferize in the future.
    bufferization::BufferizationOptions options =
        bufferization::getPartialBufferizationOptions();
    // TODO(springerm): Add dialects to this filter as more and more dialects
    // will be migrated to BufferizableOpInterface-based bufferization.
    options.opFilter.allowDialect<shape::ShapeDialect>();
    return bufferization::bufferizeOp(getOperation(), options);
  }

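  // Runs interface-based bufferization first, then applies the HLO-to-LHLO
  // conversion patterns via a partial dialect conversion.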
  void runOnOperation() override {
    if (failed(runOpInterfaceBufferization())) {
      signalPassFailure();
      return;
    }

    auto& context = getContext();
    RewritePatternSet patterns(&context);
    ConversionTarget target(context);
    target.addLegalDialect<
        arith::ArithmeticDialect, bufferization::BufferizationDialect,
        lmhlo::LmhloDialect, memref::MemRefDialect, shape::ShapeDialect,
        func::FuncDialect, tensor::TensorDialect>();
    target.addIllegalDialect<mhlo::MhloDialect>();
    // bufferization.to_memref is illegal if it has uses.
    // TODO(b/175670649) Make bufferization.to_memref illegal.
    target.addDynamicallyLegalOp<mlir::bufferization::ToMemrefOp>(
        [](auto op) { return op->use_empty(); });

    bufferization::BufferizeTypeConverter converter;
    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                         isMemRefType) &&
             std::all_of(op.result_type_begin(), op.result_type_end(),
                         isMemRefType);
    });
    target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
        [&](mlir::func::ReturnOp op) {
          return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                             isMemRefType);
        });

    populateHloToLhloConversionPattern(&context, &converter, &patterns);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    populateEliminateBufferizeMaterializationsPatterns(converter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
}  // namespace

// Lowers the dynamically shaped mhlo ops to their lmhlo counterparts.
void populateDynamicHloToLhloConversionPattern(
    MLIRContext* context, bufferization::BufferizeTypeConverter* converter,
    RewritePatternSet* patterns) {
  // clang-format off
  patterns->add<
      HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
      HloToLhloOpConverter<mhlo::DynamicGatherOp>,
      HloToLhloOpConverter<mhlo::DynamicIotaOp>,
      HloToLhloOpConverter<mhlo::DynamicPadOp>,
      HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
      HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
  >(*converter, context);
  // clang-format on
}

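// Registers all HLO-to-LHLO conversion patterns, for both the statically and
// the dynamically shaped mhlo ops.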
void populateHloToLhloConversionPattern(
    MLIRContext* context, bufferization::BufferizeTypeConverter* converter,
    RewritePatternSet* patterns) {
  populateDynamicHloToLhloConversionPattern(context, converter, patterns);

  // clang-format off
  patterns->add<
      HloToLhloCustomCallOpConverter,
      HloToLhloDotGeneralOpConverter,
      HloToLhloOpConverter<mhlo::AbsOp>,
      HloToLhloOpConverter<mhlo::AddOp>,
      HloToLhloOpConverter<mhlo::AndOp>,
      HloToLhloOpConverter<mhlo::Atan2Op>,
      HloToLhloOpConverter<mhlo::BatchNormGradOp>,
      HloToLhloOpConverter<mhlo::BatchNormTrainingOp>,
      HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
      HloToLhloOpConverter<mhlo::CeilOp>,
      HloToLhloOpConverter<mhlo::ClampOp>,
      HloToLhloOpConverter<mhlo::CompareOp>,
      HloToLhloOpConverter<mhlo::ComplexOp>,
      HloToLhloOpConverter<mhlo::ConcatenateOp>,
      HloToLhloOpConverter<mhlo::ConstantOp>,
      HloToLhloOpConverter<mhlo::ConvolutionOp>,
      HloToLhloOpConverter<mhlo::ConvertOp>,
      HloToLhloOpConverter<mhlo::CopyOp>,
      HloToLhloOpConverter<mhlo::CosineOp>,
      HloToLhloOpConverter<mhlo::DivOp>,
      HloToLhloOpConverter<mhlo::DotOp>,
      HloToLhloOpConverter<mhlo::ExpOp>,
      HloToLhloOpConverter<mhlo::Expm1Op>,
      HloToLhloOpConverter<mhlo::FloorOp>,
      HloToLhloOpConverter<mhlo::GatherOp>,
      HloToLhloOpConverter<mhlo::ImagOp>,
      HloToLhloOpConverter<mhlo::IotaOp>,
      HloToLhloOpConverter<mhlo::IsFiniteOp>,
      HloToLhloOpConverter<mhlo::LogOp>,
      HloToLhloOpConverter<mhlo::LogisticOp>,
      HloToLhloOpConverter<mhlo::MaxOp>,
      HloToLhloOpConverter<mhlo::MinOp>,
      HloToLhloOpConverter<mhlo::MulOp>,
      HloToLhloOpConverter<mhlo::NegOp>,
      HloToLhloOpConverter<mhlo::NotOp>,
      HloToLhloOpConverter<mhlo::OrOp>,
      HloToLhloOpConverter<mhlo::PowOp>,
      HloToLhloOpConverter<mhlo::RealOp>,
      HloToLhloOpConverter<mhlo::RemOp>,
      HloToLhloOpConverter<mhlo::RsqrtOp>,
      HloToLhloOpConverter<mhlo::ReshapeOp>,
      HloToLhloOpConverter<mhlo::SelectOp>,
      HloToLhloOpConverter<mhlo::ShiftLeftOp>,
      HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
      HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
      HloToLhloOpConverter<mhlo::SignOp>,
      HloToLhloOpConverter<mhlo::SineOp>,
      HloToLhloOpConverter<mhlo::SliceOp>,
      HloToLhloOpConverter<mhlo::SqrtOp>,
      HloToLhloOpConverter<mhlo::SubtractOp>,
      HloToLhloOpConverter<mhlo::TanhOp>,
      HloToLhloOpConverter<mhlo::TransposeOp>,
      HloToLhloOpConverter<mhlo::XorOp>,
      HloToLhloReduceLikeOpConverter<mhlo::ReduceOp>,
      HloToLhloReduceLikeOpConverter<mhlo::ReduceWindowOp>,
      HloToLhloReturnOpConverter
  >(*converter, context);
  // clang-format on
}

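// Creates the pass that rewrites a module from the MHLO (tensor) dialect to
// the LMHLO (buffer) dialect.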
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
  return std::make_unique<HloLegalizeToLhlo>();
}

}  // namespace mhlo
}  // namespace mlir