1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <utility>
18 
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
24 
25 namespace tensorflow {
26 
27 using llvm::APInt;
28 using llvm::ArrayRef;
29 using llvm::dyn_cast;
30 using llvm::Optional;
31 using llvm::SmallVector;
32 using mlir::ConversionPattern;
33 using mlir::ConversionPatternRewriter;
34 using mlir::ConversionTarget;
35 using mlir::DenseElementsAttr;
36 using mlir::DenseIntElementsAttr;
37 using mlir::IntegerType;
38 using mlir::LogicalResult;
39 using mlir::MLIRContext;
40 using mlir::NamedAttribute;
41 using mlir::Operation;
42 using mlir::OperationPass;
43 using mlir::OperationState;
44 using mlir::RankedTensorType;
45 using mlir::Region;
46 using mlir::RewritePatternSet;
47 using mlir::ShapedType;
48 using mlir::Type;
49 using mlir::TypeConverter;
50 using mlir::Value;
51 using mlir::func::FuncOp;
52 
53 #define GEN_PASS_CLASSES
54 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
55 
PromoteI1ToI8(Type input_type)56 static Optional<Type> PromoteI1ToI8(Type input_type) {
57   if (auto integer_type = input_type.dyn_cast<IntegerType>()) {
58     if (integer_type.getWidth() == 1)
59       return integer_type.scaleElementBitwidth(8);
60   }
61 
62   return llvm::None;
63 }
64 
65 /// TypeConverter that turns 'i1' tensors into 'i8' tensors.
66 class I1TypeConverter : public mlir::TypeConverter {
67  public:
68   using TypeConverter::convertType;
69 
I1TypeConverter()70   I1TypeConverter() {
71     // Catch-all type conversion.
72     addConversion([](Type type) { return type; });
73 
74     addConversion([](RankedTensorType tensor_type) -> Optional<Type> {
75       auto maybe_promoted_i8_type = PromoteI1ToI8(tensor_type.getElementType());
76       if (!maybe_promoted_i8_type) return tensor_type;
77       return RankedTensorType::get(tensor_type.getShape(),
78                                    *maybe_promoted_i8_type);
79     });
80   }
81 };
82 
isLegalType(const Type type)83 static bool isLegalType(const Type type) {
84   if (auto tensor_type = type.dyn_cast<RankedTensorType>()) {
85     if (auto integer_type =
86             tensor_type.getElementType().dyn_cast<IntegerType>()) {
87       return integer_type.getWidth() != 1;
88     }
89   }
90 
91   return true;
92 }
93 
isLegalAttribute(NamedAttribute attr)94 static bool isLegalAttribute(NamedAttribute attr) {
95   if (auto int_attr = attr.getValue().dyn_cast<DenseIntElementsAttr>()) {
96     // Only RankedTensorType is expected.
97     ShapedType shaped_type = int_attr.getType();
98     if (!shaped_type.isa<RankedTensorType>()) return true;
99     return !shaped_type.getElementType().isInteger(/*width=*/1);
100   }
101 
102   // TODO(diegocaballero): Add support for TypeAttr if/when we have a use case.
103 
104   return true;
105 }
106 
convertAttribute(NamedAttribute attr,ConversionPatternRewriter & rewriter)107 static NamedAttribute convertAttribute(NamedAttribute attr,
108                                        ConversionPatternRewriter &rewriter) {
109   if (auto int_attr = attr.getValue().dyn_cast<DenseIntElementsAttr>()) {
110     ShapedType shaped_type = int_attr.getType();
111     // Only RankedTensorType is expected.
112     if (!shaped_type.isa<RankedTensorType>()) return attr;
113     if (!shaped_type.getElementType().isInteger(/*width=*/1)) return attr;
114 
115     // Convert internal bool attribute representation to 8-bit integer.
116     SmallVector<APInt, 4> new_i8_values;
117     for (bool bool_val : int_attr.getValues<bool>()) {
118       new_i8_values.push_back(
119           bool_val ? APInt::getOneBitSet(/*numBits=*/8, /*bitNo=*/0)
120                    : APInt::getZero(/*numBits=*/8));
121     }
122 
123     auto i8_tensor_type =
124         RankedTensorType::get(shaped_type.getShape(), rewriter.getI8Type());
125     return NamedAttribute(
126         attr.getName(), DenseElementsAttr::get(i8_tensor_type, new_i8_values));
127   }
128 
129   // TODO(diegocaballero): Add support for TypeAttr if/when we have a use case.
130 
131   return attr;
132 }
133 
134 /// Generic conversion pattern that replaces any operation (except FuncOp) using
135 /// 'i1' tensors with the same operation using 'i8' tensors.
136 struct I1ToI8GenericConversionPattern : public ConversionPattern {
137   using ConversionPattern::ConversionPattern;
138 
I1ToI8GenericConversionPatterntensorflow::I1ToI8GenericConversionPattern139   I1ToI8GenericConversionPattern(I1TypeConverter &type_converter,
140                                  MLIRContext *context)
141       : ConversionPattern(type_converter, MatchAnyOpTypeTag(),
142                           /*benefit=*/1, context) {}
143 
matchAndRewritetensorflow::I1ToI8GenericConversionPattern144   LogicalResult matchAndRewrite(
145       Operation *op, ArrayRef<Value> converted_operands,
146       ConversionPatternRewriter &rewriter) const override {
147     // Convert attributes.
148     SmallVector<NamedAttribute, 4> new_attrs;
149     for (NamedAttribute attr : op->getAttrs())
150       new_attrs.push_back(convertAttribute(attr, rewriter));
151 
152     // Convert result types.
153     SmallVector<Type, 4> new_result_types;
154     if (failed(typeConverter->convertTypes(op->getResultTypes(),
155                                            new_result_types)))
156       return mlir::failure();
157 
158     // Create a new op using the converted attributes, operands and result
159     // types. If the existing op has regions, we move them to the new op and
160     // convert their signature.
161     OperationState new_op_state(op->getLoc(), op->getName().getStringRef(),
162                                 converted_operands, new_result_types, new_attrs,
163                                 op->getSuccessors());
164 
165     for (Region &region : op->getRegions()) {
166       Region *new_region = new_op_state.addRegion();
167       rewriter.inlineRegionBefore(region, *new_region, new_region->begin());
168 
169       TypeConverter::SignatureConversion signature_conv(
170           new_region->getNumArguments());
171       if (failed(typeConverter->convertSignatureArgs(
172               new_region->getArgumentTypes(), signature_conv)))
173         return mlir::failure();
174       rewriter.applySignatureConversion(new_region, signature_conv);
175     }
176 
177     Operation *new_op = rewriter.create(new_op_state);
178     rewriter.replaceOp(op, new_op->getResults());
179     return mlir::success();
180   }
181 };
182 
populateI1TypeConversionPatterns(I1TypeConverter & type_converter,RewritePatternSet & patterns)183 static void populateI1TypeConversionPatterns(I1TypeConverter &type_converter,
184                                              RewritePatternSet &patterns) {
185   patterns.add<I1ToI8GenericConversionPattern>(type_converter,
186                                                patterns.getContext());
187   mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(
188       patterns, type_converter);
189 }
190 
191 struct JitRtLegalizeI1TypesPass
192     : public JitRtLegalizeI1TypesBase<JitRtLegalizeI1TypesPass> {
runOnOperationtensorflow::JitRtLegalizeI1TypesPass193   void runOnOperation() override {
194     MLIRContext &context = getContext();
195     I1TypeConverter type_converter;
196 
197     ConversionTarget target(context);
198     target.markUnknownOpDynamicallyLegal([](Operation *op) {
199       // Check legality of attributes.
200       auto attrs = op->getAttrs();
201       if (std::any_of(attrs.begin(), attrs.end(), [&](NamedAttribute attr) {
202             return !isLegalAttribute(attr);
203           }))
204         return false;
205 
206       // Check legality of FuncOp.
207       if (FuncOp func_op = dyn_cast<FuncOp>(op)) {
208         auto input_types = func_op.getFunctionType().getInputs();
209         auto result_types = func_op.getFunctionType().getResults();
210         return std::all_of(
211                    input_types.begin(), input_types.end(),
212                    [&](const Type type) { return isLegalType(type); }) &&
213                std::all_of(result_types.begin(), result_types.end(),
214                            [&](const Type type) { return isLegalType(type); });
215       }
216 
217       // Check legality of any other op.
218       auto operand_types = op->getOperandTypes();
219       auto result_types = op->getResultTypes();
220       return std::all_of(operand_types.begin(), operand_types.end(),
221                          [](Type type) { return isLegalType(type); }) &&
222              std::all_of(result_types.begin(), result_types.end(),
223                          [](Type type) { return isLegalType(type); });
224     });
225 
226     RewritePatternSet patterns(&context);
227     populateI1TypeConversionPatterns(type_converter, patterns);
228     if (failed(
229             applyFullConversion(getOperation(), target, std::move(patterns))))
230       signalPassFailure();
231   }
232 };
233 
234 std::unique_ptr<OperationPass<mlir::ModuleOp>>
CreateJitRtLegalizeI1TypesPass()235 CreateJitRtLegalizeI1TypesPass() {
236   return std::make_unique<JitRtLegalizeI1TypesPass>();
237 }
238 
239 }  // namespace tensorflow
240