1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements a pass that rewrites IR so that it uses signless integer
// types, as produced by mhlo::RemoveSignTypeConverter. It handles arbitrary ops
// (MHLO and others), func.func signatures and dense arith.constant values.
18
19 #include <memory>
20 #include <utility>
21
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/BuiltinAttributes.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/MLIRContext.h"
34 #include "mlir/IR/Operation.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/Support/LogicalResult.h"
37 #include "mlir/Transforms/DialectConversion.h"
38
39 namespace mlir {
40 namespace mhlo {
41 namespace {
42
43 // Generic pattern that rewrites any op by rewriting its operands and result
44 // types. Regions are also rewritten.
45 class ConvertToSignless : public ConversionPattern {
46 public:
ConvertToSignless(TypeConverter & typeConverter,MLIRContext * context)47 ConvertToSignless(TypeConverter& typeConverter, MLIRContext* context)
48 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
49
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const50 LogicalResult matchAndRewrite(
51 Operation* op, ArrayRef<Value> operands,
52 ConversionPatternRewriter& rewriter) const final {
53 SmallVector<Type> resultTypes;
54 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
55 return failure();
56
57 auto* newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
58 operands, op->getAttrs(),
59 op->getSuccessors(), op->getNumRegions());
60 for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
61 Region& before = std::get<0>(regions);
62 Region& parent = std::get<1>(regions);
63 rewriter.inlineRegionBefore(before, parent, parent.end());
64 if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
65 return failure();
66 }
67 rewriter.insert(newOp);
68 rewriter.replaceOp(op, newOp->getResults());
69 return success();
70 }
71 };
72
73 // A pattern that converts the type of the attribute used as an operand for
74 // arith.constant
75 class ConvertConstantToSignless
76 : public OpConversionPattern<arith::ConstantOp> {
77 public:
ConvertConstantToSignless(TypeConverter & typeConverter,MLIRContext * context)78 ConvertConstantToSignless(TypeConverter& typeConverter, MLIRContext* context)
79 : OpConversionPattern<arith::ConstantOp>(typeConverter, context) {}
80
matchAndRewrite(arith::ConstantOp constantOp,arith::ConstantOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const81 LogicalResult matchAndRewrite(
82 arith::ConstantOp constantOp, arith::ConstantOpAdaptor adaptor,
83 ConversionPatternRewriter& rewriter) const override {
84 // We only care about unsigned integers
85 if (!adaptor.getValue().isa<DenseIntElementsAttr>()) return failure();
86
87 auto values = llvm::to_vector(
88 adaptor.getValue().cast<DenseIntElementsAttr>().getValues<APInt>());
89 auto newValues = DenseIntElementsAttr::get(
90 typeConverter->convertType(constantOp.getType()), values);
91
92 rewriter.replaceOpWithNewOp<arith::ConstantOp>(constantOp, newValues);
93 return success();
94 }
95 };
96
97 struct ConvertToSignlessPass
98 : public ConvertToSignlessPassBase<ConvertToSignlessPass> {
99 public:
runOnOperationmlir::mhlo::__anon881b8e4a0111::ConvertToSignlessPass100 void runOnOperation() override {
101 auto& context = getContext();
102 ConversionTarget target(context);
103
104 mhlo::RemoveSignTypeConverter converter;
105 target.markUnknownOpDynamicallyLegal([&](auto op) {
106 return converter.isLegal(op->getOperandTypes()) &&
107 converter.isLegal(op->getResultTypes());
108 });
109 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
110 return converter.isSignatureLegal(op.getFunctionType());
111 });
112 target.addDynamicallyLegalOp<arith::ConstantOp>([&](arith::ConstantOp op) {
113 return converter.isLegal(op.getType()) &&
114 converter.isLegal(op.getValue().getType());
115 });
116
117 RewritePatternSet patterns(&getContext());
118 patterns.add<ConvertToSignless, ConvertConstantToSignless>(converter,
119 &context);
120 // FuncOp is special as it has type encoding via attributes.
121 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
122 converter);
123
124 auto module = getOperation();
125 if (failed(applyFullConversion(module, target, std::move(patterns)))) {
126 signalPassFailure();
127 }
128 }
129 };
130
131 } // namespace
132
createConvertToSignlessPass()133 std::unique_ptr<OperationPass<ModuleOp>> createConvertToSignlessPass() {
134 return std::make_unique<ConvertToSignlessPass>();
135 }
136
137 } // namespace mhlo
138 } // namespace mlir
139