1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <utility>
18
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
24
25 namespace tensorflow {
26
27 using llvm::APInt;
28 using llvm::ArrayRef;
29 using llvm::dyn_cast;
30 using llvm::Optional;
31 using llvm::SmallVector;
32 using mlir::ConversionPattern;
33 using mlir::ConversionPatternRewriter;
34 using mlir::ConversionTarget;
35 using mlir::DenseElementsAttr;
36 using mlir::DenseIntElementsAttr;
37 using mlir::IntegerType;
38 using mlir::LogicalResult;
39 using mlir::MLIRContext;
40 using mlir::NamedAttribute;
41 using mlir::Operation;
42 using mlir::OperationPass;
43 using mlir::OperationState;
44 using mlir::RankedTensorType;
45 using mlir::Region;
46 using mlir::RewritePatternSet;
47 using mlir::ShapedType;
48 using mlir::Type;
49 using mlir::TypeConverter;
50 using mlir::Value;
51 using mlir::func::FuncOp;
52
53 #define GEN_PASS_CLASSES
54 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
55
PromoteI1ToI8(Type input_type)56 static Optional<Type> PromoteI1ToI8(Type input_type) {
57 if (auto integer_type = input_type.dyn_cast<IntegerType>()) {
58 if (integer_type.getWidth() == 1)
59 return integer_type.scaleElementBitwidth(8);
60 }
61
62 return llvm::None;
63 }
64
65 /// TypeConverter that turns 'i1' tensors into 'i8' tensors.
66 class I1TypeConverter : public mlir::TypeConverter {
67 public:
68 using TypeConverter::convertType;
69
I1TypeConverter()70 I1TypeConverter() {
71 // Catch-all type conversion.
72 addConversion([](Type type) { return type; });
73
74 addConversion([](RankedTensorType tensor_type) -> Optional<Type> {
75 auto maybe_promoted_i8_type = PromoteI1ToI8(tensor_type.getElementType());
76 if (!maybe_promoted_i8_type) return tensor_type;
77 return RankedTensorType::get(tensor_type.getShape(),
78 *maybe_promoted_i8_type);
79 });
80 }
81 };
82
isLegalType(const Type type)83 static bool isLegalType(const Type type) {
84 if (auto tensor_type = type.dyn_cast<RankedTensorType>()) {
85 if (auto integer_type =
86 tensor_type.getElementType().dyn_cast<IntegerType>()) {
87 return integer_type.getWidth() != 1;
88 }
89 }
90
91 return true;
92 }
93
isLegalAttribute(NamedAttribute attr)94 static bool isLegalAttribute(NamedAttribute attr) {
95 if (auto int_attr = attr.getValue().dyn_cast<DenseIntElementsAttr>()) {
96 // Only RankedTensorType is expected.
97 ShapedType shaped_type = int_attr.getType();
98 if (!shaped_type.isa<RankedTensorType>()) return true;
99 return !shaped_type.getElementType().isInteger(/*width=*/1);
100 }
101
102 // TODO(diegocaballero): Add support for TypeAttr if/when we have a use case.
103
104 return true;
105 }
106
convertAttribute(NamedAttribute attr,ConversionPatternRewriter & rewriter)107 static NamedAttribute convertAttribute(NamedAttribute attr,
108 ConversionPatternRewriter &rewriter) {
109 if (auto int_attr = attr.getValue().dyn_cast<DenseIntElementsAttr>()) {
110 ShapedType shaped_type = int_attr.getType();
111 // Only RankedTensorType is expected.
112 if (!shaped_type.isa<RankedTensorType>()) return attr;
113 if (!shaped_type.getElementType().isInteger(/*width=*/1)) return attr;
114
115 // Convert internal bool attribute representation to 8-bit integer.
116 SmallVector<APInt, 4> new_i8_values;
117 for (bool bool_val : int_attr.getValues<bool>()) {
118 new_i8_values.push_back(
119 bool_val ? APInt::getOneBitSet(/*numBits=*/8, /*bitNo=*/0)
120 : APInt::getZero(/*numBits=*/8));
121 }
122
123 auto i8_tensor_type =
124 RankedTensorType::get(shaped_type.getShape(), rewriter.getI8Type());
125 return NamedAttribute(
126 attr.getName(), DenseElementsAttr::get(i8_tensor_type, new_i8_values));
127 }
128
129 // TODO(diegocaballero): Add support for TypeAttr if/when we have a use case.
130
131 return attr;
132 }
133
134 /// Generic conversion pattern that replaces any operation (except FuncOp) using
135 /// 'i1' tensors with the same operation using 'i8' tensors.
136 struct I1ToI8GenericConversionPattern : public ConversionPattern {
137 using ConversionPattern::ConversionPattern;
138
I1ToI8GenericConversionPatterntensorflow::I1ToI8GenericConversionPattern139 I1ToI8GenericConversionPattern(I1TypeConverter &type_converter,
140 MLIRContext *context)
141 : ConversionPattern(type_converter, MatchAnyOpTypeTag(),
142 /*benefit=*/1, context) {}
143
matchAndRewritetensorflow::I1ToI8GenericConversionPattern144 LogicalResult matchAndRewrite(
145 Operation *op, ArrayRef<Value> converted_operands,
146 ConversionPatternRewriter &rewriter) const override {
147 // Convert attributes.
148 SmallVector<NamedAttribute, 4> new_attrs;
149 for (NamedAttribute attr : op->getAttrs())
150 new_attrs.push_back(convertAttribute(attr, rewriter));
151
152 // Convert result types.
153 SmallVector<Type, 4> new_result_types;
154 if (failed(typeConverter->convertTypes(op->getResultTypes(),
155 new_result_types)))
156 return mlir::failure();
157
158 // Create a new op using the converted attributes, operands and result
159 // types. If the existing op has regions, we move them to the new op and
160 // convert their signature.
161 OperationState new_op_state(op->getLoc(), op->getName().getStringRef(),
162 converted_operands, new_result_types, new_attrs,
163 op->getSuccessors());
164
165 for (Region ®ion : op->getRegions()) {
166 Region *new_region = new_op_state.addRegion();
167 rewriter.inlineRegionBefore(region, *new_region, new_region->begin());
168
169 TypeConverter::SignatureConversion signature_conv(
170 new_region->getNumArguments());
171 if (failed(typeConverter->convertSignatureArgs(
172 new_region->getArgumentTypes(), signature_conv)))
173 return mlir::failure();
174 rewriter.applySignatureConversion(new_region, signature_conv);
175 }
176
177 Operation *new_op = rewriter.create(new_op_state);
178 rewriter.replaceOp(op, new_op->getResults());
179 return mlir::success();
180 }
181 };
182
populateI1TypeConversionPatterns(I1TypeConverter & type_converter,RewritePatternSet & patterns)183 static void populateI1TypeConversionPatterns(I1TypeConverter &type_converter,
184 RewritePatternSet &patterns) {
185 patterns.add<I1ToI8GenericConversionPattern>(type_converter,
186 patterns.getContext());
187 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(
188 patterns, type_converter);
189 }
190
191 struct JitRtLegalizeI1TypesPass
192 : public JitRtLegalizeI1TypesBase<JitRtLegalizeI1TypesPass> {
runOnOperationtensorflow::JitRtLegalizeI1TypesPass193 void runOnOperation() override {
194 MLIRContext &context = getContext();
195 I1TypeConverter type_converter;
196
197 ConversionTarget target(context);
198 target.markUnknownOpDynamicallyLegal([](Operation *op) {
199 // Check legality of attributes.
200 auto attrs = op->getAttrs();
201 if (std::any_of(attrs.begin(), attrs.end(), [&](NamedAttribute attr) {
202 return !isLegalAttribute(attr);
203 }))
204 return false;
205
206 // Check legality of FuncOp.
207 if (FuncOp func_op = dyn_cast<FuncOp>(op)) {
208 auto input_types = func_op.getFunctionType().getInputs();
209 auto result_types = func_op.getFunctionType().getResults();
210 return std::all_of(
211 input_types.begin(), input_types.end(),
212 [&](const Type type) { return isLegalType(type); }) &&
213 std::all_of(result_types.begin(), result_types.end(),
214 [&](const Type type) { return isLegalType(type); });
215 }
216
217 // Check legality of any other op.
218 auto operand_types = op->getOperandTypes();
219 auto result_types = op->getResultTypes();
220 return std::all_of(operand_types.begin(), operand_types.end(),
221 [](Type type) { return isLegalType(type); }) &&
222 std::all_of(result_types.begin(), result_types.end(),
223 [](Type type) { return isLegalType(type); });
224 });
225
226 RewritePatternSet patterns(&context);
227 populateI1TypeConversionPatterns(type_converter, patterns);
228 if (failed(
229 applyFullConversion(getOperation(), target, std::move(patterns))))
230 signalPassFailure();
231 }
232 };
233
234 std::unique_ptr<OperationPass<mlir::ModuleOp>>
CreateJitRtLegalizeI1TypesPass()235 CreateJitRtLegalizeI1TypesPass() {
236 return std::make_unique<JitRtLegalizeI1TypesPass>();
237 }
238
239 } // namespace tensorflow
240