1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements a pass that converts IR to use signless integer types.
// It rewrites arbitrary operations, func.func signatures and the dense integer
// values of arith.constant ops using mhlo::RemoveSignTypeConverter.
18 
19 #include <memory>
20 #include <utility>
21 
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/BuiltinAttributes.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/BuiltinTypes.h"
33 #include "mlir/IR/MLIRContext.h"
34 #include "mlir/IR/Operation.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/Support/LogicalResult.h"
37 #include "mlir/Transforms/DialectConversion.h"
38 
39 namespace mlir {
40 namespace mhlo {
41 namespace {
42 
43 // Generic pattern that rewrites any op by rewriting its operands and result
44 // types. Regions are also rewritten.
45 class ConvertToSignless : public ConversionPattern {
46  public:
ConvertToSignless(TypeConverter & typeConverter,MLIRContext * context)47   ConvertToSignless(TypeConverter& typeConverter, MLIRContext* context)
48       : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
49 
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const50   LogicalResult matchAndRewrite(
51       Operation* op, ArrayRef<Value> operands,
52       ConversionPatternRewriter& rewriter) const final {
53     SmallVector<Type> resultTypes;
54     if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
55       return failure();
56 
57     auto* newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
58                                     operands, op->getAttrs(),
59                                     op->getSuccessors(), op->getNumRegions());
60     for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
61       Region& before = std::get<0>(regions);
62       Region& parent = std::get<1>(regions);
63       rewriter.inlineRegionBefore(before, parent, parent.end());
64       if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
65         return failure();
66     }
67     rewriter.insert(newOp);
68     rewriter.replaceOp(op, newOp->getResults());
69     return success();
70   }
71 };
72 
73 // A pattern that converts the type of the attribute used as an operand for
74 // arith.constant
75 class ConvertConstantToSignless
76     : public OpConversionPattern<arith::ConstantOp> {
77  public:
ConvertConstantToSignless(TypeConverter & typeConverter,MLIRContext * context)78   ConvertConstantToSignless(TypeConverter& typeConverter, MLIRContext* context)
79       : OpConversionPattern<arith::ConstantOp>(typeConverter, context) {}
80 
matchAndRewrite(arith::ConstantOp constantOp,arith::ConstantOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const81   LogicalResult matchAndRewrite(
82       arith::ConstantOp constantOp, arith::ConstantOpAdaptor adaptor,
83       ConversionPatternRewriter& rewriter) const override {
84     // We only care about unsigned integers
85     if (!adaptor.getValue().isa<DenseIntElementsAttr>()) return failure();
86 
87     auto values = llvm::to_vector(
88         adaptor.getValue().cast<DenseIntElementsAttr>().getValues<APInt>());
89     auto newValues = DenseIntElementsAttr::get(
90         typeConverter->convertType(constantOp.getType()), values);
91 
92     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constantOp, newValues);
93     return success();
94   }
95 };
96 
97 struct ConvertToSignlessPass
98     : public ConvertToSignlessPassBase<ConvertToSignlessPass> {
99  public:
runOnOperationmlir::mhlo::__anon881b8e4a0111::ConvertToSignlessPass100   void runOnOperation() override {
101     auto& context = getContext();
102     ConversionTarget target(context);
103 
104     mhlo::RemoveSignTypeConverter converter;
105     target.markUnknownOpDynamicallyLegal([&](auto op) {
106       return converter.isLegal(op->getOperandTypes()) &&
107              converter.isLegal(op->getResultTypes());
108     });
109     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
110       return converter.isSignatureLegal(op.getFunctionType());
111     });
112     target.addDynamicallyLegalOp<arith::ConstantOp>([&](arith::ConstantOp op) {
113       return converter.isLegal(op.getType()) &&
114              converter.isLegal(op.getValue().getType());
115     });
116 
117     RewritePatternSet patterns(&getContext());
118     patterns.add<ConvertToSignless, ConvertConstantToSignless>(converter,
119                                                                &context);
120     // FuncOp is special as it has type encoding via attributes.
121     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
122                                                                    converter);
123 
124     auto module = getOperation();
125     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
126       signalPassFailure();
127     }
128   }
129 };
130 
131 }  // namespace
132 
createConvertToSignlessPass()133 std::unique_ptr<OperationPass<ModuleOp>> createConvertToSignlessPass() {
134   return std::make_unique<ConvertToSignlessPass>();
135 }
136 
137 }  // namespace mhlo
138 }  // namespace mlir
139