1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <string>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/Twine.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
32 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
33 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
34 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
35 #include "mlir/IR/Attributes.h" // from @llvm-project
36 #include "mlir/IR/Builders.h" // from @llvm-project
37 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
38 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
39 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
40 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
41 #include "mlir/IR/FunctionImplementation.h" // from @llvm-project
42 #include "mlir/IR/MLIRContext.h" // from @llvm-project
43 #include "mlir/IR/Matchers.h" // from @llvm-project
44 #include "mlir/IR/OpDefinition.h" // from @llvm-project
45 #include "mlir/IR/OpImplementation.h" // from @llvm-project
46 #include "mlir/IR/PatternMatch.h" // from @llvm-project
47 #include "mlir/IR/Types.h" // from @llvm-project
48 #include "mlir/Support/LogicalResult.h" // from @llvm-project
49 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
53 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
54
55 namespace mlir {
56
57 namespace TFR {
58
59 //===----------------------------------------------------------------------===//
60 // InlinerInterface
61 //===----------------------------------------------------------------------===//
62
63 namespace {
64 /// This class defines the interface for inlining within the TFR dialect.
65 class TFRInlinerInterface : public DialectInlinerInterface {
66 using DialectInlinerInterface::DialectInlinerInterface;
67
68 public:
69 // Allow all call operations to be inlined.
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const70 bool isLegalToInline(Operation *call, Operation *callable,
71 bool wouldBeCloned) const final {
72 return true;
73 }
74 // Returns true if the given region 'src' can be inlined into the region
75 // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping &) const76 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
77 BlockAndValueMapping &) const final {
78 return true;
79 }
80
81 // Returns true if the given operation 'op', that is registered to this
82 // dialect, can be inlined into the region 'dest' that is attached to an
83 // operation registered to the current dialect.
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping &) const84 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
85 BlockAndValueMapping &) const final {
86 return true;
87 }
88
89 // Handle the given inlined terminator by replacing it with a new operation
90 // as necessary. Required when the region has only one block.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const91 void handleTerminator(Operation *op,
92 ArrayRef<Value> valuesToRepl) const final {
93 auto retValOp = dyn_cast<TFRReturnOp>(op);
94 if (!retValOp) return;
95
96 for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
97 std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
98 }
99 }
100
101 // Attempts to materialize a conversion for a type mismatch between a call
102 // from this dialect, and a callable region. This method should generate an
103 // operation that takes 'input' as the only operand, and produces a single
104 // result of 'resultType'. If a conversion can not be generated, nullptr
105 // should be returned.
materializeCallConversion(OpBuilder & builder,Value input,Type result_type,Location conversion_loc) const106 Operation *materializeCallConversion(OpBuilder &builder, Value input,
107 Type result_type,
108 Location conversion_loc) const final {
109 if (!input.getType().isa<IntegerType>() ||
110 !result_type.isa<IntegerType>()) {
111 return nullptr;
112 }
113 auto input_itype = input.getType().cast<IntegerType>();
114 auto result_itype = result_type.cast<IntegerType>();
115 if (input_itype.getWidth() == result_itype.getWidth()) return nullptr;
116 if (input_itype.getWidth() > result_itype.getWidth()) {
117 return builder.create<arith::TruncIOp>(conversion_loc, result_type,
118 input);
119 } else {
120 return builder.create<arith::ExtSIOp>(conversion_loc, result_type, input);
121 }
122 }
123 };
124 } // namespace
125
126 //===----------------------------------------------------------------------===//
127 // TFR Dialect
128 //===----------------------------------------------------------------------===//
129
// Constructs the TFR dialect: loads the TensorFlow dialect it depends on,
// then registers the TFR types, the TableGen-generated operation list, and
// the inliner interface.
TFRDialect::TFRDialect(MLIRContext *context)
    : Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
  // TFR depends on TensorFlow for its canonicalization
  context->getOrLoadDialect<TF::TensorFlowDialect>();

  addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
  // Operations come from the TableGen-generated .cc.inc file.
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
      >();

  addInterfaces<TFRInlinerInterface>();
}
143
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)144 Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
145 Type type, Location loc) {
146 if (arith::ConstantOp::isBuildableWith(value, type))
147 return builder.create<arith::ConstantOp>(loc, type, value);
148 if (func::ConstantOp::isBuildableWith(value, type))
149 return builder.create<func::ConstantOp>(loc, type,
150 value.cast<FlatSymbolRefAttr>());
151 return nullptr;
152 }
153
classof(Type type)154 bool TFRType::classof(Type type) {
155 return llvm::isa<TFRDialect>(type.getDialect());
156 }
157
158 //===----------------------------------------------------------------------===//
159 // Custom op methods
160 //===----------------------------------------------------------------------===//
161
verify()162 LogicalResult ConstantTensorOp::verify() {
163 ConstantTensorOp op = *this;
164 auto input_type = op.arg().getType();
165 auto output_type = op.out().getType();
166
167 if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
168 return success();
169 }
170
171 auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
172 if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
173 op.emitError("output type should be static and ranked.");
174 return failure();
175 }
176
177 if (output_tensor_type.getRank() == 0) {
178 bool same_scalar = output_tensor_type.getElementType() == input_type;
179 if (!same_scalar) {
180 op.emitError("input and output should have the same scalar types.");
181 }
182 return success(same_scalar);
183 }
184
185 if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
186 bool same_element_type = output_tensor_type.getElementType() ==
187 input_vector_type.getElementType();
188 bool same_shape =
189 output_tensor_type.getShape() == input_vector_type.getShape();
190 if (!same_element_type || !same_shape) {
191 op.emitError("input and output should have same shape and element type.");
192 }
193 return success(same_element_type && same_shape);
194 }
195
196 op.emitError("input can not be converted to an output tensor.");
197 return failure();
198 }
199
// Verifies the structural contract of a tfr.func:
//   * arguments appear in the order: tensors, tensor lists, attributes;
//   * every attribute argument carries a tfr.name argument attribute;
//   * results appear in the order: tensors, then at most one tensor list
//     (order/count checked only when it can't be derived from attributes);
//   * every attribute key referenced from a tensor/tensor_list type must be
//     defined as an attribute on the function, unless used for derivation.
LogicalResult TFRFuncOp::verify() {
  TFRFuncOp func = *this;
  // Collect all attribute names used by the tensor and tensor list arguments
  // and returns. Also, collect the names of all the attribute arguments as the
  // defined list. Later on, the used attribute names will be verified to be in
  // the defined list.
  llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;

  // While scanning the arguments, record the start/end indices of each argument
  // type, so the order can be verified as well.
  // TODO(fengliuai): the attribute arguments with default values need to be
  // at the end?
  int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
      last_tensor_list = -1, first_attr = -1;
  for (auto arg : llvm::enumerate(func.getFunctionType().getInputs())) {
    Type arg_type = arg.value();

    if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
      if (first_tensor == -1) {
        first_tensor = arg.index();
      }
      last_tensor = arg.index();
      auto used = tensor.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
      if (first_tensor_list == -1) {
        first_tensor_list = arg.index();
      }
      last_tensor_list = arg.index();
      auto used = tensor_list.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    if (!arg_type.isa<TensorType>()) {
      // A non-tensor argument is an attribute argument; it must be named via
      // the tfr.name argument attribute.
      if (first_attr == -1) {
        first_attr = arg.index();
      }
      auto name =
          func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
      if (!name) {
        func.emitError(
            llvm::Twine(arg.index()) +
            " attribute argument doesn't have a tfr.name attribute.");
        return failure();
      }
      continue;
    }

    func.emitError("Builtin TensorType isn't allowed as the argument.");
    return failure();
  }

  // Collect all the undefined attributes used in the inputs.
  llvm::SmallVector<StringAttr, 4> undefined_attrs;
  for (auto attr : input_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify the argument order: tensors, tensor list, attributes; and also
  // verify there is at most one tensor list argument.
  if (first_attr != -1 &&
      (first_attr < last_tensor_list || first_attr < last_tensor)) {
    func.emitError(
        "tfr.tensor/tfr.tensor_list argument should be before non tensor "
        "arguments.");
    return failure();
  }
  // The order between tensor arguments and tensor list arguments and the number
  // of tensor list arguments are verified only when they couldn't be determined
  // by the attributes.
  if (!undefined_attrs.empty()) {
    if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
      func.emitError(
          "tfr.tensor argument should be before tfr.tensor_list argument.");
      return failure();
    }
    if (first_tensor_list != last_tensor_list) {
      func.emitError("More than one tfr.tensor_list argument isn't allowed.");
      return failure();
    }
  }

  // Verify the result order: tensor, tensor list, and also verify at most one
  // tensor list result.
  int undefined_input_attrs_number = undefined_attrs.size();
  bool seen_tensor_list = false, has_tensor_list_order_error = false,
       has_multiple_tensor_lists_error = false;
  for (auto result_type : func.getFunctionType().getResults()) {
    if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
      // A tensor result after a tensor list result is an ordering error, but
      // it is only reported below when the attributes can't resolve it.
      if (seen_tensor_list) {
        has_tensor_list_order_error = true;
      } else {
        auto used = tensor.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
      if (seen_tensor_list) {
        has_multiple_tensor_lists_error = true;
      } else {
        seen_tensor_list = true;
        auto used = tensor_list.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    func.emitError(
        "None tfr.tensor/tfr.tensor_list results aren't allowed as a "
        "result.");
    return failure();
  }

  // Collect all the undefined attributes used in the outputs.
  for (auto attr : output_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify there are no tensor/tensor list order error and multiple tensor
  // list arguments error.
  // NOTE(review): the deferred result-order errors are reported only when the
  // output scan added new undefined attributes — i.e. when the order can't be
  // derived from the attributes.
  if (undefined_input_attrs_number != undefined_attrs.size()) {
    if (has_tensor_list_order_error) {
      func.emitError(
          "tfr.tensor result should be before tfr.tensor_list result.");
      return failure();
    } else if (has_multiple_tensor_lists_error) {
      func.emitError("More than one tfr.tensor_list result isn't allowed.");
      return failure();
    }
  }

  // TODO(fengliuai): We might want to refine this constraint because the
  // tensor element type can be derived.
  if (!undefined_attrs.empty()) {
    llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
    std::transform(undefined_attrs.begin(), undefined_attrs.end(),
                   attr_names.begin(),
                   [](StringAttr attr) { return attr.getValue().str(); });
    func.emitError(llvm::Twine("Undefined attributes are used: ",
                               llvm::join(attr_names, ",")));
    return failure();
  }

  return success();
}
355
parse(OpAsmParser & parser,OperationState & result)356 ParseResult TFRFuncOp::parse(OpAsmParser &parser, OperationState &result) {
357 auto build_func_type =
358 [](Builder &builder, ArrayRef<Type> arg_types, ArrayRef<Type> results,
359 function_interface_impl::VariadicFlag,
360 std::string &) { return builder.getFunctionType(arg_types, results); };
361 return function_interface_impl::parseFunctionOp(
362 parser, result, /*allowVariadic=*/false, build_func_type);
363 }
364
// Prints a tfr.func using the shared function-op printer (non-variadic).
void TFRFuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
368
369 } // namespace TFR
370 } // namespace mlir
371
372 //===----------------------------------------------------------------------===//
373 // TableGen'd op method definitions
374 //===----------------------------------------------------------------------===//
375
376 #define GET_OP_CLASSES
377 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
378
379 namespace mlir {
380 namespace TFR {
381 namespace {
382 class ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
383 using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
384
385 public:
matchAndRewrite(ConstantTensorOp cst_tensor_op,PatternRewriter & rewriter) const386 LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
387 PatternRewriter &rewriter) const override {
388 Location loc = cst_tensor_op.getLoc();
389 Type out_type = cst_tensor_op.getType();
390 Operation *new_cst = nullptr;
391
392 ArrayAttr array;
393 if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
394 llvm::DenseSet<Type> all_types;
395 for (auto it : array) {
396 TypedAttr typed_attr = it.dyn_cast<TypedAttr>();
397 if (!typed_attr) return failure();
398 all_types.insert(typed_attr.getType());
399 }
400 if (all_types.size() != 1) return failure();
401 ShapedType new_out_type = RankedTensorType::get(
402 {static_cast<int64_t>(array.size())}, *all_types.begin());
403 DenseElementsAttr attr =
404 DenseElementsAttr::get(new_out_type, array.getValue());
405 new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
406 if (out_type.isa<TFRTensorType>()) {
407 new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
408 }
409 rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
410 return success();
411 }
412
413 TypedAttr scalar;
414 if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
415 Type new_out_type = RankedTensorType::get({}, scalar.getType());
416 new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
417 if (out_type.isa<TFRTensorType>()) {
418 new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
419 }
420 rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
421 return success();
422 }
423 return failure();
424 }
425 };
426
isQuantizedType(Type type)427 inline bool isQuantizedType(Type type) {
428 auto tensor_type = type.dyn_cast<TensorType>();
429 return (tensor_type &&
430 tensor_type.getElementType().isa<quant::QuantizedType>());
431 }
432
433 class RemoveRedundantCast : public OpRewritePattern<CastOp> {
434 using OpRewritePattern<CastOp>::OpRewritePattern;
435
436 public:
matchAndRewrite(CastOp cast_op,PatternRewriter & rewriter) const437 LogicalResult matchAndRewrite(CastOp cast_op,
438 PatternRewriter &rewriter) const override {
439 auto preceding_cast =
440 llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
441 if (!preceding_cast) {
442 return failure();
443 }
444 Value input = preceding_cast.arg();
445 Type input_type = input.getType();
446 Type output_type = cast_op.getType();
447
448 // Preserve quantization information for intermediate tensors.
449 auto intermediate_type = preceding_cast.getType();
450 if (isQuantizedType(intermediate_type) || isQuantizedType(output_type)) {
451 return failure();
452 }
453
454 auto input_tensor_type = input_type.dyn_cast<TensorType>();
455 auto output_tensor_type = output_type.dyn_cast<TensorType>();
456 if (!input_tensor_type || !output_tensor_type) {
457 return failure();
458 }
459
460 // Canonicalize two tfr.cast pairs with different element type to
461 // two tfr.casts with the same element type followed by a tf.Cast.
462 if ((input_tensor_type.getElementType() !=
463 output_tensor_type.getElementType()) &&
464 !isQuantizedType(input_type) && !isQuantizedType(output_type)) {
465 auto new_tfr_cast = rewriter.create<TFR::CastOp>(
466 cast_op.getLoc(),
467 output_tensor_type.clone(input_tensor_type.getElementType()),
468 cast_op.arg());
469 rewriter.replaceOpWithNewOp<TF::CastOp>(cast_op, output_type,
470 new_tfr_cast);
471 return success();
472 }
473
474 // If the two types are the same, the back-to-back tfr.cast ops can be
475 // removed.
476 if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
477 rewriter.replaceOp(cast_op, {input});
478 return success();
479 }
480
481 // If the rank of the input tensor isn't ranked, we replace the pair
482 // with tf.EnsureShape op so it can be removed after shape inference or
483 // confirmed at runtime.
484 if (input_type.isa<UnrankedTensorType>()) {
485 auto shape = output_type.cast<ShapedType>().getShape();
486 auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
487 rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
488 input, shape_attr);
489 return success();
490 }
491
492 return failure();
493 }
494 };
495
496 class GetTensorShape : public OpRewritePattern<GetShapeOp> {
497 using OpRewritePattern<GetShapeOp>::OpRewritePattern;
498
499 public:
matchAndRewrite(GetShapeOp shape_op,PatternRewriter & rewriter) const500 LogicalResult matchAndRewrite(GetShapeOp shape_op,
501 PatternRewriter &rewriter) const override {
502 Operation *preceding_op = shape_op.arg().getDefiningOp();
503 if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
504 // replace this pair by shape.shape_of, so the folding works.
505 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
506 return success();
507 }
508 return failure();
509 }
510 };
511
512 class RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
513 using OpRewritePattern<GetElementOp>::OpRewritePattern;
514
515 public:
matchAndRewrite(GetElementOp ge_op,PatternRewriter & rewriter) const516 LogicalResult matchAndRewrite(GetElementOp ge_op,
517 PatternRewriter &rewriter) const override {
518 IntegerAttr index;
519 if (!matchPattern(ge_op.index(), m_Constant(&index))) {
520 return failure();
521 }
522 auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
523 ge_op.tensor_list().getDefiningOp());
524 if (!preceding_build_list ||
525 preceding_build_list.getNumOperands() <= index.getInt()) {
526 return failure();
527 }
528 Value input = preceding_build_list.getOperand(index.getInt());
529 Type output_type = ge_op.getType();
530 if (input.getType() != output_type &&
531 !output_type.isa<UnrankedTensorType>()) {
532 return failure();
533 }
534 rewriter.replaceOp(ge_op, {input});
535 return success();
536 }
537 };
538
539 class RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
540 using OpRewritePattern<GetLengthOp>::OpRewritePattern;
541
542 public:
matchAndRewrite(GetLengthOp gl_op,PatternRewriter & rewriter) const543 LogicalResult matchAndRewrite(GetLengthOp gl_op,
544 PatternRewriter &rewriter) const override {
545 auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
546 gl_op.tensor_list().getDefiningOp());
547 if (!preceding_build_list) {
548 return failure();
549 }
550 int64_t num_tensors = preceding_build_list.getNumOperands();
551 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
552 gl_op, rewriter.getIndexAttr(num_tensors));
553 return success();
554 }
555 };
556
557 class BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
558 using OpRewritePattern<BuildListOp>::OpRewritePattern;
559
560 public:
matchAndRewrite(BuildListOp bl_op,PatternRewriter & rewriter) const561 LogicalResult matchAndRewrite(BuildListOp bl_op,
562 PatternRewriter &rewriter) const override {
563 SmallVector<Attribute, 4> array_list;
564 array_list.reserve(bl_op.getNumOperands());
565 for (const auto &operand : bl_op.getOperands()) {
566 Attribute array_elt;
567 if (!matchPattern(operand, m_Constant(&array_elt))) {
568 return failure();
569 }
570 array_list.push_back(array_elt);
571 }
572 auto array_attr = rewriter.getArrayAttr(array_list);
573 rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
574 return success();
575 }
576 };
577
getQuantizedElementType(CastOp cast_op)578 quant::QuantizedType getQuantizedElementType(CastOp cast_op) {
579 if (!cast_op || !cast_op.getInputElementType()) {
580 return {};
581 }
582 return cast_op.getInputElementType()
583 .cast<TypeAttr>()
584 .getValue()
585 .dyn_cast<quant::QuantizedType>();
586 }
587
588 class RemoveRawDataOp : public OpRewritePattern<TFRQuantRawDataOp> {
589 using OpRewritePattern<TFRQuantRawDataOp>::OpRewritePattern;
590
591 public:
matchAndRewrite(TFRQuantRawDataOp raw_data_op,PatternRewriter & rewriter) const592 LogicalResult matchAndRewrite(TFRQuantRawDataOp raw_data_op,
593 PatternRewriter &rewriter) const override {
594 auto preceding_op = raw_data_op.input().getDefiningOp();
595 if (isa<BuildListOp>(preceding_op)) {
596 return rewritePrecedingListOp(raw_data_op, rewriter);
597 }
598
599 auto preceding_cast = dyn_cast_or_null<CastOp>(preceding_op);
600 if (!preceding_cast || !getQuantizedElementType(preceding_cast)) {
601 return failure();
602 }
603 // If there are redundant casts, hoist output of raw data op originating op.
604 if (preceding_cast.arg().getDefiningOp()) {
605 auto redundant_cast = preceding_cast.arg().getDefiningOp<CastOp>();
606 if (!redundant_cast ||
607 redundant_cast.arg().getType() != preceding_cast.out().getType()) {
608 return failure();
609 }
610 raw_data_op.output().replaceAllUsesWith(redundant_cast.arg());
611 } else {
612 // If the argument of cast op is input, then simply remove the RawData op.
613 raw_data_op.output().replaceAllUsesWith(preceding_cast.out());
614 }
615 return success();
616 }
617
rewritePrecedingListOp(TFRQuantRawDataOp raw_data_op,PatternRewriter & rewriter) const618 LogicalResult rewritePrecedingListOp(TFRQuantRawDataOp raw_data_op,
619 PatternRewriter &rewriter) const {
620 llvm::SmallVector<Value> new_list_values;
621 auto preceding_list = raw_data_op.input().getDefiningOp<BuildListOp>();
622 for (Value operand : preceding_list.tensors()) {
623 auto preceding_cast = operand.getDefiningOp<CastOp>();
624 if (!preceding_cast || !getQuantizedElementType(preceding_cast)) {
625 return failure();
626 }
627
628 // This function currently only supports the case with redundant casts.
629 auto redundant_cast = preceding_cast.arg().getDefiningOp<CastOp>();
630 if (!redundant_cast ||
631 redundant_cast.arg().getType() != preceding_cast.out().getType()) {
632 return failure();
633 }
634
635 new_list_values.push_back(redundant_cast.arg());
636 }
637
638 auto new_list = rewriter.create<BuildListOp>(
639 raw_data_op.getLoc(), preceding_list.getType(), new_list_values);
640 raw_data_op.output().replaceAllUsesWith(new_list.out());
641 return success();
642 }
643 };
644
645 class RemoveQParamsOp : public OpRewritePattern<TFRQuantQParamsOp> {
646 using OpRewritePattern<TFRQuantQParamsOp>::OpRewritePattern;
647
648 public:
matchAndRewrite(TFRQuantQParamsOp qparams_op,PatternRewriter & rewriter) const649 LogicalResult matchAndRewrite(TFRQuantQParamsOp qparams_op,
650 PatternRewriter &rewriter) const override {
651 auto cast_op = dyn_cast<TFR::CastOp>(qparams_op.input().getDefiningOp());
652 auto cast_qtype = getQuantizedElementType(cast_op);
653 if (!cast_qtype) {
654 return failure();
655 }
656
657 TF::ConstOp scale_op;
658 TF::ConstOp zp_op;
659
660 // Reads quantization parameters from the quantized type, and converts
661 // them to constants.
662 rewriter.setInsertionPoint(qparams_op);
663 Location loc = qparams_op->getLoc();
664 if (auto qtype = cast_qtype.dyn_cast<quant::UniformQuantizedType>()) {
665 scale_op = rewriter.create<TF::ConstOp>(
666 loc, RankedTensorType::get({}, rewriter.getF32Type()),
667 rewriter.getF32FloatAttr(qtype.getScale()));
668 zp_op = rewriter.create<TF::ConstOp>(
669 loc, RankedTensorType::get({}, rewriter.getI32Type()),
670 rewriter.getI32IntegerAttr(qtype.getZeroPoint()));
671 } else if (auto qtype =
672 cast_qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
673 SmallVector<float> scales(qtype.getScales().begin(),
674 qtype.getScales().end());
675 SmallVector<int32_t> zps(qtype.getZeroPoints().begin(),
676 qtype.getZeroPoints().end());
677 const size_t num_channels = qtype.getScales().size();
678
679 auto scales_type = RankedTensorType::get(
680 {static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
681 auto scales_attr =
682 DenseElementsAttr::get(scales_type, llvm::makeArrayRef(scales));
683 scale_op = rewriter.create<TF::ConstOp>(loc, scales_attr);
684
685 auto zps_type = RankedTensorType::get(
686 {static_cast<int64_t>(num_channels)}, rewriter.getI32Type());
687 auto zps_attr = DenseElementsAttr::get(zps_type, llvm::makeArrayRef(zps));
688 zp_op = rewriter.create<TF::ConstOp>(loc, zps_attr);
689 }
690 if (!scale_op || !zp_op) {
691 return failure();
692 }
693 auto scale_cast = rewriter.create<CastOp>(loc, qparams_op.scale().getType(),
694 scale_op.output());
695 auto zp_cast =
696 rewriter.create<CastOp>(loc, qparams_op.zp().getType(), zp_op.output());
697
698 qparams_op.scale().replaceAllUsesWith(scale_cast.out());
699 qparams_op.zp().replaceAllUsesWith(zp_cast.out());
700 return success();
701 }
702 };
703
704 // TODO(b/193731721): Migrate tfr_ builtin canonicalizations to LowerTFROpPass
705 class RemoveScaleFactorOp : public OpRewritePattern<TFRQuantScaleFactorOp> {
706 using OpRewritePattern<TFRQuantScaleFactorOp>::OpRewritePattern;
707
708 public:
709 // Replace quant_scale_factor with constant tensor equivalent to
710 // TFR_ConstantTensorOp (
711 // ConstantOp (ConstAttr<F32Attr (in_scale[0] * in_scale[1] /
712 // out_scale))
713 // )
714 // Currently, all decompositions using this pattern (Conv2D, FC) have the
715 // following preconditions:
716 // * out_scale: float scalar attribute
717 // * in_scale[0] (input scale): float scalar, given by tf.Const -> tfr.cast
718 // * in_scale[1] (filter scale): float scalar/vector
719 // (per-tensor vs per-channel) quantization, given by tf.Const -> tfr.cast
matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,PatternRewriter & rewriter) const720 LogicalResult matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,
721 PatternRewriter &rewriter) const override {
722 auto out_scale_op =
723 scale_factor_op.out_scale().getDefiningOp<arith::ConstantOp>();
724 if (!out_scale_op) {
725 return failure();
726 }
727 const double out_scale =
728 out_scale_op.getValue().cast<FloatAttr>().getValueAsDouble();
729
730 auto in_scales_op =
731 scale_factor_op.in_scales().getDefiningOp<BuildListOp>();
732 if (!in_scales_op || in_scales_op.getNumOperands() != 2) {
733 // BuildListOp is variadic, but we require two values: input_scale
734 // and filter_scale.
735 return failure();
736 }
737
738 auto in_scale_op = in_scales_op.getOperand(0).getDefiningOp<CastOp>();
739 if (!in_scale_op) {
740 return failure();
741 }
742
743 DenseFPElementsAttr in_scale_attr;
744 if (!matchPattern(in_scale_op.arg(), m_Constant(&in_scale_attr)) ||
745 in_scale_attr.size() != 1) {
746 return failure();
747 }
748 const float in_scale = in_scale_attr.getValues<float>()[0];
749 auto filter_scale_op = in_scales_op.getOperand(1).getDefiningOp<CastOp>();
750 if (!filter_scale_op) {
751 return failure();
752 }
753 DenseFPElementsAttr filter_scale_attr;
754 if (!matchPattern(filter_scale_op.arg(), m_Constant(&filter_scale_attr))) {
755 return failure();
756 }
757
758 // The shape of scale_type is {} (rank 0) for per-tensor quantized tensor,
759 // and {num_channels} (rank 1) for per-channel quantized one.
760 auto scale_type = filter_scale_attr.getType().dyn_cast<RankedTensorType>();
761 if (scale_type.getRank() != 0 && scale_type.getRank() != 1) {
762 return failure();
763 }
764 SmallVector<float> scale_factors;
765 scale_factors.reserve(filter_scale_attr.size());
766 for (auto value : filter_scale_attr.getValues<APFloat>()) {
767 scale_factors.push_back(in_scale * value.convertToFloat() / out_scale);
768 }
769 rewriter.setInsertionPoint(scale_factor_op);
770 const Location loc = scale_factor_op->getLoc();
771 auto result_scale_op = rewriter.create<TF::ConstOp>(
772 loc,
773 DenseElementsAttr::get(scale_type, llvm::makeArrayRef(scale_factors)));
774 auto result_scale_cast_op = rewriter.create<CastOp>(
775 loc, scale_factor_op.getType(), result_scale_op.output());
776 scale_factor_op.scale_factor().replaceAllUsesWith(
777 result_scale_cast_op.out());
778 return success();
779 }
780 };
781
782 class RemoveRescaleOp : public OpRewritePattern<TFRQuantRescaleOp> {
783 using OpRewritePattern<TFRQuantRescaleOp>::OpRewritePattern;
784
785 public:
786 // Replace quant_rescale (input, scale, zp) with
787 // tf.Cast(tf.Round(tf.Cast(input, f32) * scale) + tf.Cast(zp, f32), i32)
matchAndRewrite(TFRQuantRescaleOp rescale_op,PatternRewriter & rewriter) const788 LogicalResult matchAndRewrite(TFRQuantRescaleOp rescale_op,
789 PatternRewriter &rewriter) const override {
790 Value input = rescale_op.input();
791 Value scale = rescale_op.scale();
792 Value zp = rescale_op.zp();
793
794 const Location loc = rescale_op->getLoc();
795 const auto result_types = rescale_op->getResultTypes();
796 auto c_false =
797 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));
798 TypeAttr f32_attr = TypeAttr::get(rewriter.getF32Type());
799 TFRAttrType output_type = TFRAttrType::get(rewriter.getContext());
800 auto constant_f32_op = rewriter.create<ConstOp>(loc, output_type, f32_attr);
801 TypeAttr i32_attr = TypeAttr::get(rewriter.getI32Type());
802 auto constant_i32_op = rewriter.create<ConstOp>(loc, output_type, i32_attr);
803
804 IntegerAttr zp_attr;
805 if (!matchPattern(zp, m_Constant(&zp_attr))) {
806 return failure();
807 }
808 rewriter.setInsertionPoint(zp.getDefiningOp());
809 auto zp_tensor = rewriter.create<TF::ConstOp>(
810 loc, RankedTensorType::get({}, zp.getType()), zp_attr);
811 auto zp_cast = rewriter.create<CastOp>(
812 loc, rewriter.getType<TFRTensorType>(), zp_tensor.output());
813
814 rewriter.setInsertionPoint(rescale_op);
815 auto cast_input_to_float_op = rewriter.create<CallOp>(
816 loc, result_types,
817 SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
818 ArrayRef<Value>{input, constant_f32_op, c_false});
819 auto input_x_scale_op = rewriter.create<CallOp>(
820 loc, result_types, SymbolRefAttr::get(rewriter.getContext(), "tf__mul"),
821 ArrayRef<Value>{cast_input_to_float_op.getResult(0), scale});
822 auto round_rescaled_op = rewriter.create<CallOp>(
823 loc, result_types,
824 SymbolRefAttr::get(rewriter.getContext(), "tf__round"),
825 ArrayRef<Value>{input_x_scale_op->getResult(0)});
826 auto cast_zp_to_float_op = rewriter.create<CallOp>(
827 loc, result_types,
828 SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
829 ArrayRef<Value>{zp_cast, constant_f32_op, c_false});
830 auto recentered_op = rewriter.create<CallOp>(
831 loc, result_types, SymbolRefAttr::get(rewriter.getContext(), "tf__add"),
832 ArrayRef<Value>{round_rescaled_op->getResult(0),
833 cast_zp_to_float_op->getResult(0)});
834 auto cast_output_to_i32 = rewriter.create<CallOp>(
835 loc, result_types,
836 SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
837 ArrayRef<Value>{recentered_op->getResult(0), constant_i32_op, c_false});
838 rescale_op.output().replaceAllUsesWith(cast_output_to_i32.getResult(0));
839 return success();
840 }
841 };
842
843 } // namespace
844
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)845 void ConstantTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
846 MLIRContext *context) {
847 results.add<ConvertConstToTensorConst>(context);
848 }
849
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)850 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
851 MLIRContext *context) {
852 results.add<RemoveRedundantCast>(context);
853 }
854
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)855 void GetShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
856 MLIRContext *context) {
857 results.add<GetTensorShape>(context);
858 }
859
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)860 void GetElementOp::getCanonicalizationPatterns(RewritePatternSet &results,
861 MLIRContext *context) {
862 results.add<RemoveRedundantGetElement>(context);
863 }
864
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)865 void GetLengthOp::getCanonicalizationPatterns(RewritePatternSet &results,
866 MLIRContext *context) {
867 results.add<RemoveRedundantGetLength>(context);
868 }
869
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)870 void BuildListOp::getCanonicalizationPatterns(RewritePatternSet &results,
871 MLIRContext *context) {
872 results.add<BuildConstantListAsAttr>(context);
873 }
874
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)875 void TFRQuantRawDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
876 MLIRContext *context) {
877 results.add<RemoveRawDataOp>(context);
878 }
879
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)880 void TFRQuantQParamsOp::getCanonicalizationPatterns(RewritePatternSet &results,
881 MLIRContext *context) {
882 results.add<RemoveQParamsOp>(context);
883 }
884
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)885 void TFRQuantRescaleOp::getCanonicalizationPatterns(RewritePatternSet &results,
886 MLIRContext *context) {
887 results.add<RemoveRescaleOp>(context);
888 }
889
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)890 void TFRQuantScaleFactorOp::getCanonicalizationPatterns(
891 RewritePatternSet &results, MLIRContext *context) {
892 results.add<RemoveScaleFactorOp>(context);
893 }
894
fold(ArrayRef<Attribute> operands)895 OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
896 assert(operands.size() == 2 && "equal op has two operands");
897 auto ctx = getContext();
898 if (operands[0] == operands[1]) return BoolAttr::get(ctx, true);
899 return BoolAttr::get(ctx, false);
900 }
901
fold(ArrayRef<Attribute> operands)902 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
903 assert(operands.empty() && "constant has no operands");
904
905 // Return the held attribute value.
906 return value();
907 }
908
909 // CallableOpInterface
getCallableRegion()910 Region *TFRFuncOp::getCallableRegion() {
911 return isExternal() ? nullptr : &body().front();
912 }
913
914 // CallableOpInterface
getCallableResults()915 ArrayRef<Type> TFRFuncOp::getCallableResults() {
916 return getFunctionType().getResults();
917 }
918
919 //===----------------------------------------------------------------------===//
920 // Dialect type definitions
921 //===----------------------------------------------------------------------===//
922
923 // Parses a TFR type.
924 // tfr_type ::= tensor_type | tensor_list_type | attr_type
925 // string_list ::= `[` string-literal (, string-literal)+ `]`
926 // tensor_type ::= `tensor`
927 // | `tensor<` (string-literal | string_list) '>'
928 // tensor_list_type ::= `tensor_list`
929 // | `tensor_list<` (string-literal | string_list) '>'
930 // attr_type ::= `attr`
parseType(DialectAsmParser & parser) const931 Type TFRDialect::parseType(DialectAsmParser &parser) const {
932 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
933 MLIRContext *ctx = loc.getContext();
934
935 StringRef typeNameSpelling;
936 if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
937 llvm::SmallVector<StringAttr, 4> attrs;
938 if (succeeded(parser.parseOptionalLess())) {
939 bool l_square_parsed = false;
940 if (succeeded(parser.parseOptionalLSquare())) {
941 l_square_parsed = true;
942 }
943
944 do {
945 StringRef attr;
946 if (failed(parser.parseKeyword(&attr))) return {};
947 attrs.push_back(StringAttr::get(ctx, attr));
948 } while (succeeded(parser.parseOptionalComma()));
949
950 if (l_square_parsed && failed(parser.parseRSquare())) {
951 parser.emitError(parser.getNameLoc(), "expected ']'");
952 }
953
954 if (failed(parser.parseGreater())) {
955 parser.emitError(parser.getNameLoc(), "expected '>'");
956 }
957 }
958
959 if (typeNameSpelling == "tensor") {
960 return TFRTensorType::getChecked(attrs, loc);
961 } else if (typeNameSpelling == "tensor_list") {
962 return TFRTensorListType::getChecked(attrs, loc);
963 } else if (typeNameSpelling == "attr") {
964 return TFRAttrType::getChecked(loc, loc.getContext());
965 } else {
966 parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
967 return {};
968 }
969 }
970
printType(Type type,DialectAsmPrinter & os) const971 void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
972 llvm::ArrayRef<StringAttr> attrs;
973
974 if (type.isa<TFRAttrType>()) {
975 os << "attr";
976 return;
977 }
978 if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
979 attrs = tensor_ty.getAttrKeys();
980 os << "tensor";
981 } else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
982 attrs = tensor_list_ty.getAttrKeys();
983 os << "tensor_list";
984 } else {
985 llvm_unreachable("Unhandled tfr type");
986 }
987
988 if (attrs.empty()) return;
989 os << "<";
990
991 if (attrs.size() > 1) {
992 os << "[";
993 }
994
995 llvm::interleaveComma(attrs, os,
996 [&](StringAttr attr) { os << attr.getValue(); });
997
998 if (attrs.size() > 1) {
999 os << "]";
1000 }
1001 os << ">";
1002 }
1003
1004 } // namespace TFR
1005 } // namespace mlir
1006