xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <string>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/Twine.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
32 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
33 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
34 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
35 #include "mlir/IR/Attributes.h"  // from @llvm-project
36 #include "mlir/IR/Builders.h"  // from @llvm-project
37 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
40 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
41 #include "mlir/IR/FunctionImplementation.h"  // from @llvm-project
42 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
43 #include "mlir/IR/Matchers.h"  // from @llvm-project
44 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
45 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
46 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
47 #include "mlir/IR/Types.h"  // from @llvm-project
48 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
49 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
53 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
54 
55 namespace mlir {
56 
57 namespace TFR {
58 
59 //===----------------------------------------------------------------------===//
60 // InlinerInterface
61 //===----------------------------------------------------------------------===//
62 
63 namespace {
64 /// This class defines the interface for inlining within the TFR dialect.
65 class TFRInlinerInterface : public DialectInlinerInterface {
66   using DialectInlinerInterface::DialectInlinerInterface;
67 
68  public:
69   // Allow all call operations to be inlined.
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const70   bool isLegalToInline(Operation *call, Operation *callable,
71                        bool wouldBeCloned) const final {
72     return true;
73   }
74   // Returns true if the given region 'src' can be inlined into the region
75   // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping &) const76   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
77                        BlockAndValueMapping &) const final {
78     return true;
79   }
80 
81   // Returns true if the given operation 'op', that is registered to this
82   // dialect, can be inlined into the region 'dest' that is attached to an
83   // operation registered to the current dialect.
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping &) const84   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
85                        BlockAndValueMapping &) const final {
86     return true;
87   }
88 
89   // Handle the given inlined terminator by replacing it with a new operation
90   // as necessary. Required when the region has only one block.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const91   void handleTerminator(Operation *op,
92                         ArrayRef<Value> valuesToRepl) const final {
93     auto retValOp = dyn_cast<TFRReturnOp>(op);
94     if (!retValOp) return;
95 
96     for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) {
97       std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value));
98     }
99   }
100 
101   // Attempts to materialize a conversion for a type mismatch between a call
102   // from this dialect, and a callable region. This method should generate an
103   // operation that takes 'input' as the only operand, and produces a single
104   // result of 'resultType'. If a conversion can not be generated, nullptr
105   // should be returned.
materializeCallConversion(OpBuilder & builder,Value input,Type result_type,Location conversion_loc) const106   Operation *materializeCallConversion(OpBuilder &builder, Value input,
107                                        Type result_type,
108                                        Location conversion_loc) const final {
109     if (!input.getType().isa<IntegerType>() ||
110         !result_type.isa<IntegerType>()) {
111       return nullptr;
112     }
113     auto input_itype = input.getType().cast<IntegerType>();
114     auto result_itype = result_type.cast<IntegerType>();
115     if (input_itype.getWidth() == result_itype.getWidth()) return nullptr;
116     if (input_itype.getWidth() > result_itype.getWidth()) {
117       return builder.create<arith::TruncIOp>(conversion_loc, result_type,
118                                              input);
119     } else {
120       return builder.create<arith::ExtSIOp>(conversion_loc, result_type, input);
121     }
122   }
123 };
124 }  // namespace
125 
126 //===----------------------------------------------------------------------===//
127 // TFR Dialect
128 //===----------------------------------------------------------------------===//
129 
TFRDialect(MLIRContext * context)130 TFRDialect::TFRDialect(MLIRContext *context)
131     : Dialect(/*name=*/"tfr", context, TypeID::get<TFRDialect>()) {
132   // TFR depends on TensorFlow for its canonicalization
133   context->getOrLoadDialect<TF::TensorFlowDialect>();
134 
135   addTypes<TFRTensorType, TFRTensorListType, TFRAttrType>();
136   addOperations<
137 #define GET_OP_LIST
138 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
139       >();
140 
141   addInterfaces<TFRInlinerInterface>();
142 }
143 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)144 Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value,
145                                            Type type, Location loc) {
146   if (arith::ConstantOp::isBuildableWith(value, type))
147     return builder.create<arith::ConstantOp>(loc, type, value);
148   if (func::ConstantOp::isBuildableWith(value, type))
149     return builder.create<func::ConstantOp>(loc, type,
150                                             value.cast<FlatSymbolRefAttr>());
151   return nullptr;
152 }
153 
classof(Type type)154 bool TFRType::classof(Type type) {
155   return llvm::isa<TFRDialect>(type.getDialect());
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Custom op methods
160 //===----------------------------------------------------------------------===//
161 
verify()162 LogicalResult ConstantTensorOp::verify() {
163   ConstantTensorOp op = *this;
164   auto input_type = op.arg().getType();
165   auto output_type = op.out().getType();
166 
167   if (auto output_tensor_type = output_type.dyn_cast<TFRTensorType>()) {
168     return success();
169   }
170 
171   auto output_tensor_type = output_type.dyn_cast<RankedTensorType>();
172   if (!output_tensor_type || !output_tensor_type.hasStaticShape()) {
173     op.emitError("output type should be static and ranked.");
174     return failure();
175   }
176 
177   if (output_tensor_type.getRank() == 0) {
178     bool same_scalar = output_tensor_type.getElementType() == input_type;
179     if (!same_scalar) {
180       op.emitError("input and output should have the same scalar types.");
181     }
182     return success(same_scalar);
183   }
184 
185   if (auto input_vector_type = input_type.dyn_cast<VectorType>()) {
186     bool same_element_type = output_tensor_type.getElementType() ==
187                              input_vector_type.getElementType();
188     bool same_shape =
189         output_tensor_type.getShape() == input_vector_type.getShape();
190     if (!same_element_type || !same_shape) {
191       op.emitError("input and output should have same shape and element type.");
192     }
193     return success(same_element_type && same_shape);
194   }
195 
196   op.emitError("input can not be converted to an output tensor.");
197   return failure();
198 }
199 
// Verifies a tfr.func definition:
//  * arguments must be ordered tfr.tensor, tfr.tensor_list, then attribute
//    arguments; builtin TensorType arguments are rejected;
//  * every attribute argument must carry a tfr.name argument attribute;
//  * results must be tfr.tensor values followed by at most one
//    tfr.tensor_list;
//  * attribute keys referenced from tfr.tensor/tfr.tensor_list types should
//    be defined as attributes on the function.
// Some ordering checks are only enforced when the referenced attribute keys
// are undefined; see the inline comments below.
LogicalResult TFRFuncOp::verify() {
  TFRFuncOp func = *this;
  // Collect all attribute names used by the tensor and tensor list arguments
  // and returns. Also, collect the names of all the attribute arguments as the
  // defined list. Later on, the used attribute names will be verified to be in
  // the defined list.
  llvm::SmallVector<StringAttr, 4> input_used_attrs, output_used_attrs;

  // While scanning the arguments, record the start/end indices of each argument
  // type, so the order can be verified as well.
  // TODO(fengliuai): the attribute arguments with default values need to be
  // at the end?
  int first_tensor = -1, last_tensor = -1, first_tensor_list = -1,
      last_tensor_list = -1, first_attr = -1;
  for (auto arg : llvm::enumerate(func.getFunctionType().getInputs())) {
    Type arg_type = arg.value();

    if (auto tensor = arg_type.dyn_cast<TFRTensorType>()) {
      if (first_tensor == -1) {
        first_tensor = arg.index();
      }
      last_tensor = arg.index();
      // Remember every attribute key this tensor type references.
      auto used = tensor.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    if (auto tensor_list = arg_type.dyn_cast<TFRTensorListType>()) {
      if (first_tensor_list == -1) {
        first_tensor_list = arg.index();
      }
      last_tensor_list = arg.index();
      auto used = tensor_list.getAttrKeys();
      input_used_attrs.append(used.begin(), used.end());
      continue;
    }

    // Any other non-TensorType argument is treated as an attribute argument
    // and must be named via the tfr.name argument attribute.
    if (!arg_type.isa<TensorType>()) {
      if (first_attr == -1) {
        first_attr = arg.index();
      }
      auto name =
          func.getArgAttrOfType<StringAttr>(arg.index(), kAttrArgumentNameAttr);
      if (!name) {
        func.emitError(
            llvm::Twine(arg.index()) +
            " attribute argument doesn't have a tfr.name attribute.");
        return failure();
      }
      continue;
    }

    func.emitError("Builtin TensorType isn't allowed as the argument.");
    return failure();
  }

  // Collect all the undefined attributes used in the inputs.
  llvm::SmallVector<StringAttr, 4> undefined_attrs;
  for (auto attr : input_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify the argument order: tensors, tensor list, attributes; and also
  // verify there is at most one tensor list argument.
  if (first_attr != -1 &&
      (first_attr < last_tensor_list || first_attr < last_tensor)) {
    func.emitError(
        "tfr.tensor/tfr.tensor_list argument should be before non tensor "
        "arguments.");
    return failure();
  }
  // The order between tensor arguments and tensor list arguments and the number
  // of tensor list arguments are verified only when they couldn't be determined
  // by the attributes.
  if (!undefined_attrs.empty()) {
    if (first_tensor_list != -1 && first_tensor_list < last_tensor) {
      func.emitError(
          "tfr.tensor argument should be before tfr.tensor_list argument.");
      return failure();
    }
    if (first_tensor_list != last_tensor_list) {
      func.emitError("More than one tfr.tensor_list argument isn't allowed.");
      return failure();
    }
  }

  // Verify the result order: tensor, tensor list, and also verify at most one
  // tensor list result.
  // Problems found while scanning are recorded in flags and only reported
  // later, once we know whether the results referenced undefined attributes.
  int undefined_input_attrs_number = undefined_attrs.size();
  bool seen_tensor_list = false, has_tensor_list_order_error = false,
       has_multiple_tensor_lists_error = false;
  for (auto result_type : func.getFunctionType().getResults()) {
    if (auto tensor = result_type.dyn_cast<TFRTensorType>()) {
      if (seen_tensor_list) {
        // A tfr.tensor result after a tfr.tensor_list result is out of order.
        has_tensor_list_order_error = true;
      } else {
        auto used = tensor.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    if (auto tensor_list = result_type.dyn_cast<TFRTensorListType>()) {
      if (seen_tensor_list) {
        has_multiple_tensor_lists_error = true;
      } else {
        seen_tensor_list = true;
        auto used = tensor_list.getAttrKeys();
        output_used_attrs.append(used.begin(), used.end());
      }
      continue;
    }

    func.emitError(
        "None tfr.tensor/tfr.tensor_list results aren't allowed as a "
        "result.");
    return failure();
  }

  // Collect all the undefined attributes used in the outputs.
  for (auto attr : output_used_attrs) {
    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }

  // Verify there are no tensor/tensor list order error and multiple tensor
  // list arguments error.
  // The deferred result-order errors are reported only when the result scan
  // appended new entries to undefined_attrs, i.e. when the result types
  // referenced attribute keys not defined on the function.
  // NOTE(review): this compares an int against size_t; harmless while the
  // attribute counts stay small.
  if (undefined_input_attrs_number != undefined_attrs.size()) {
    if (has_tensor_list_order_error) {
      func.emitError(
          "tfr.tensor result should be before tfr.tensor_list result.");
      return failure();
    } else if (has_multiple_tensor_lists_error) {
      func.emitError("More than one tfr.tensor_list result isn't allowed.");
      return failure();
    }
  }

  // TODO(fengliuai): We might want to refine this constraint because the
  // tensor element type can be derived.
  if (!undefined_attrs.empty()) {
    // Report every referenced-but-undefined attribute key in a single error.
    llvm::SmallVector<std::string, 4> attr_names(undefined_attrs.size());
    std::transform(undefined_attrs.begin(), undefined_attrs.end(),
                   attr_names.begin(),
                   [](StringAttr attr) { return attr.getValue().str(); });
    func.emitError(llvm::Twine("Undefined attributes are used: ",
                               llvm::join(attr_names, ",")));
    return failure();
  }

  return success();
}
355 
parse(OpAsmParser & parser,OperationState & result)356 ParseResult TFRFuncOp::parse(OpAsmParser &parser, OperationState &result) {
357   auto build_func_type =
358       [](Builder &builder, ArrayRef<Type> arg_types, ArrayRef<Type> results,
359          function_interface_impl::VariadicFlag,
360          std::string &) { return builder.getFunctionType(arg_types, results); };
361   return function_interface_impl::parseFunctionOp(
362       parser, result, /*allowVariadic=*/false, build_func_type);
363 }
364 
// Prints this tfr.func using the shared function-op printer (no variadic
// signature support).
void TFRFuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
368 
369 }  // namespace TFR
370 }  // namespace mlir
371 
372 //===----------------------------------------------------------------------===//
373 // TableGen'd op method definitions
374 //===----------------------------------------------------------------------===//
375 
376 #define GET_OP_CLASSES
377 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc.inc"
378 
379 namespace mlir {
380 namespace TFR {
381 namespace {
382 class ConvertConstToTensorConst : public OpRewritePattern<ConstantTensorOp> {
383   using OpRewritePattern<ConstantTensorOp>::OpRewritePattern;
384 
385  public:
matchAndRewrite(ConstantTensorOp cst_tensor_op,PatternRewriter & rewriter) const386   LogicalResult matchAndRewrite(ConstantTensorOp cst_tensor_op,
387                                 PatternRewriter &rewriter) const override {
388     Location loc = cst_tensor_op.getLoc();
389     Type out_type = cst_tensor_op.getType();
390     Operation *new_cst = nullptr;
391 
392     ArrayAttr array;
393     if (matchPattern(cst_tensor_op.arg(), m_Constant(&array))) {
394       llvm::DenseSet<Type> all_types;
395       for (auto it : array) {
396         TypedAttr typed_attr = it.dyn_cast<TypedAttr>();
397         if (!typed_attr) return failure();
398         all_types.insert(typed_attr.getType());
399       }
400       if (all_types.size() != 1) return failure();
401       ShapedType new_out_type = RankedTensorType::get(
402           {static_cast<int64_t>(array.size())}, *all_types.begin());
403       DenseElementsAttr attr =
404           DenseElementsAttr::get(new_out_type, array.getValue());
405       new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, attr);
406       if (out_type.isa<TFRTensorType>()) {
407         new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
408       }
409       rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
410       return success();
411     }
412 
413     TypedAttr scalar;
414     if (matchPattern(cst_tensor_op.arg(), m_Constant(&scalar))) {
415       Type new_out_type = RankedTensorType::get({}, scalar.getType());
416       new_cst = rewriter.create<TF::ConstOp>(loc, new_out_type, scalar);
417       if (out_type.isa<TFRTensorType>()) {
418         new_cst = rewriter.create<CastOp>(loc, out_type, new_cst->getResult(0));
419       }
420       rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0));
421       return success();
422     }
423     return failure();
424   }
425 };
426 
isQuantizedType(Type type)427 inline bool isQuantizedType(Type type) {
428   auto tensor_type = type.dyn_cast<TensorType>();
429   return (tensor_type &&
430           tensor_type.getElementType().isa<quant::QuantizedType>());
431 }
432 
433 class RemoveRedundantCast : public OpRewritePattern<CastOp> {
434   using OpRewritePattern<CastOp>::OpRewritePattern;
435 
436  public:
matchAndRewrite(CastOp cast_op,PatternRewriter & rewriter) const437   LogicalResult matchAndRewrite(CastOp cast_op,
438                                 PatternRewriter &rewriter) const override {
439     auto preceding_cast =
440         llvm::dyn_cast_or_null<CastOp>(cast_op.arg().getDefiningOp());
441     if (!preceding_cast) {
442       return failure();
443     }
444     Value input = preceding_cast.arg();
445     Type input_type = input.getType();
446     Type output_type = cast_op.getType();
447 
448     // Preserve quantization information for intermediate tensors.
449     auto intermediate_type = preceding_cast.getType();
450     if (isQuantizedType(intermediate_type) || isQuantizedType(output_type)) {
451       return failure();
452     }
453 
454     auto input_tensor_type = input_type.dyn_cast<TensorType>();
455     auto output_tensor_type = output_type.dyn_cast<TensorType>();
456     if (!input_tensor_type || !output_tensor_type) {
457       return failure();
458     }
459 
460     // Canonicalize two tfr.cast pairs with different element type to
461     // two tfr.casts with the same element type followed by a tf.Cast.
462     if ((input_tensor_type.getElementType() !=
463          output_tensor_type.getElementType()) &&
464         !isQuantizedType(input_type) && !isQuantizedType(output_type)) {
465       auto new_tfr_cast = rewriter.create<TFR::CastOp>(
466           cast_op.getLoc(),
467           output_tensor_type.clone(input_tensor_type.getElementType()),
468           cast_op.arg());
469       rewriter.replaceOpWithNewOp<TF::CastOp>(cast_op, output_type,
470                                               new_tfr_cast);
471       return success();
472     }
473 
474     // If the two types are the same, the back-to-back tfr.cast ops can be
475     // removed.
476     if (input_type == output_type || output_type.isa<UnrankedTensorType>()) {
477       rewriter.replaceOp(cast_op, {input});
478       return success();
479     }
480 
481     // If the rank of the input tensor isn't ranked, we replace the pair
482     // with tf.EnsureShape op so it can be removed after shape inference or
483     // confirmed at runtime.
484     if (input_type.isa<UnrankedTensorType>()) {
485       auto shape = output_type.cast<ShapedType>().getShape();
486       auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape);
487       rewriter.replaceOpWithNewOp<TF::EnsureShapeOp>(cast_op, output_type,
488                                                      input, shape_attr);
489       return success();
490     }
491 
492     return failure();
493   }
494 };
495 
496 class GetTensorShape : public OpRewritePattern<GetShapeOp> {
497   using OpRewritePattern<GetShapeOp>::OpRewritePattern;
498 
499  public:
matchAndRewrite(GetShapeOp shape_op,PatternRewriter & rewriter) const500   LogicalResult matchAndRewrite(GetShapeOp shape_op,
501                                 PatternRewriter &rewriter) const override {
502     Operation *preceding_op = shape_op.arg().getDefiningOp();
503     if (auto cast_op = llvm::dyn_cast_or_null<CastOp>(preceding_op)) {
504       // replace this pair by shape.shape_of, so the folding works.
505       rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(shape_op, cast_op.arg());
506       return success();
507     }
508     return failure();
509   }
510 };
511 
512 class RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
513   using OpRewritePattern<GetElementOp>::OpRewritePattern;
514 
515  public:
matchAndRewrite(GetElementOp ge_op,PatternRewriter & rewriter) const516   LogicalResult matchAndRewrite(GetElementOp ge_op,
517                                 PatternRewriter &rewriter) const override {
518     IntegerAttr index;
519     if (!matchPattern(ge_op.index(), m_Constant(&index))) {
520       return failure();
521     }
522     auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
523         ge_op.tensor_list().getDefiningOp());
524     if (!preceding_build_list ||
525         preceding_build_list.getNumOperands() <= index.getInt()) {
526       return failure();
527     }
528     Value input = preceding_build_list.getOperand(index.getInt());
529     Type output_type = ge_op.getType();
530     if (input.getType() != output_type &&
531         !output_type.isa<UnrankedTensorType>()) {
532       return failure();
533     }
534     rewriter.replaceOp(ge_op, {input});
535     return success();
536   }
537 };
538 
539 class RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
540   using OpRewritePattern<GetLengthOp>::OpRewritePattern;
541 
542  public:
matchAndRewrite(GetLengthOp gl_op,PatternRewriter & rewriter) const543   LogicalResult matchAndRewrite(GetLengthOp gl_op,
544                                 PatternRewriter &rewriter) const override {
545     auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
546         gl_op.tensor_list().getDefiningOp());
547     if (!preceding_build_list) {
548       return failure();
549     }
550     int64_t num_tensors = preceding_build_list.getNumOperands();
551     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
552         gl_op, rewriter.getIndexAttr(num_tensors));
553     return success();
554   }
555 };
556 
557 class BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
558   using OpRewritePattern<BuildListOp>::OpRewritePattern;
559 
560  public:
matchAndRewrite(BuildListOp bl_op,PatternRewriter & rewriter) const561   LogicalResult matchAndRewrite(BuildListOp bl_op,
562                                 PatternRewriter &rewriter) const override {
563     SmallVector<Attribute, 4> array_list;
564     array_list.reserve(bl_op.getNumOperands());
565     for (const auto &operand : bl_op.getOperands()) {
566       Attribute array_elt;
567       if (!matchPattern(operand, m_Constant(&array_elt))) {
568         return failure();
569       }
570       array_list.push_back(array_elt);
571     }
572     auto array_attr = rewriter.getArrayAttr(array_list);
573     rewriter.replaceOpWithNewOp<TFR::ConstOp>(bl_op, array_attr);
574     return success();
575   }
576 };
577 
getQuantizedElementType(CastOp cast_op)578 quant::QuantizedType getQuantizedElementType(CastOp cast_op) {
579   if (!cast_op || !cast_op.getInputElementType()) {
580     return {};
581   }
582   return cast_op.getInputElementType()
583       .cast<TypeAttr>()
584       .getValue()
585       .dyn_cast<quant::QuantizedType>();
586 }
587 
588 class RemoveRawDataOp : public OpRewritePattern<TFRQuantRawDataOp> {
589   using OpRewritePattern<TFRQuantRawDataOp>::OpRewritePattern;
590 
591  public:
matchAndRewrite(TFRQuantRawDataOp raw_data_op,PatternRewriter & rewriter) const592   LogicalResult matchAndRewrite(TFRQuantRawDataOp raw_data_op,
593                                 PatternRewriter &rewriter) const override {
594     auto preceding_op = raw_data_op.input().getDefiningOp();
595     if (isa<BuildListOp>(preceding_op)) {
596       return rewritePrecedingListOp(raw_data_op, rewriter);
597     }
598 
599     auto preceding_cast = dyn_cast_or_null<CastOp>(preceding_op);
600     if (!preceding_cast || !getQuantizedElementType(preceding_cast)) {
601       return failure();
602     }
603     // If there are redundant casts, hoist output of raw data op originating op.
604     if (preceding_cast.arg().getDefiningOp()) {
605       auto redundant_cast = preceding_cast.arg().getDefiningOp<CastOp>();
606       if (!redundant_cast ||
607           redundant_cast.arg().getType() != preceding_cast.out().getType()) {
608         return failure();
609       }
610       raw_data_op.output().replaceAllUsesWith(redundant_cast.arg());
611     } else {
612       // If the argument of cast op is input, then simply remove the RawData op.
613       raw_data_op.output().replaceAllUsesWith(preceding_cast.out());
614     }
615     return success();
616   }
617 
rewritePrecedingListOp(TFRQuantRawDataOp raw_data_op,PatternRewriter & rewriter) const618   LogicalResult rewritePrecedingListOp(TFRQuantRawDataOp raw_data_op,
619                                        PatternRewriter &rewriter) const {
620     llvm::SmallVector<Value> new_list_values;
621     auto preceding_list = raw_data_op.input().getDefiningOp<BuildListOp>();
622     for (Value operand : preceding_list.tensors()) {
623       auto preceding_cast = operand.getDefiningOp<CastOp>();
624       if (!preceding_cast || !getQuantizedElementType(preceding_cast)) {
625         return failure();
626       }
627 
628       // This function currently only supports the case with redundant casts.
629       auto redundant_cast = preceding_cast.arg().getDefiningOp<CastOp>();
630       if (!redundant_cast ||
631           redundant_cast.arg().getType() != preceding_cast.out().getType()) {
632         return failure();
633       }
634 
635       new_list_values.push_back(redundant_cast.arg());
636     }
637 
638     auto new_list = rewriter.create<BuildListOp>(
639         raw_data_op.getLoc(), preceding_list.getType(), new_list_values);
640     raw_data_op.output().replaceAllUsesWith(new_list.out());
641     return success();
642   }
643 };
644 
645 class RemoveQParamsOp : public OpRewritePattern<TFRQuantQParamsOp> {
646   using OpRewritePattern<TFRQuantQParamsOp>::OpRewritePattern;
647 
648  public:
matchAndRewrite(TFRQuantQParamsOp qparams_op,PatternRewriter & rewriter) const649   LogicalResult matchAndRewrite(TFRQuantQParamsOp qparams_op,
650                                 PatternRewriter &rewriter) const override {
651     auto cast_op = dyn_cast<TFR::CastOp>(qparams_op.input().getDefiningOp());
652     auto cast_qtype = getQuantizedElementType(cast_op);
653     if (!cast_qtype) {
654       return failure();
655     }
656 
657     TF::ConstOp scale_op;
658     TF::ConstOp zp_op;
659 
660     // Reads quantization parameters from the quantized type, and converts
661     // them to constants.
662     rewriter.setInsertionPoint(qparams_op);
663     Location loc = qparams_op->getLoc();
664     if (auto qtype = cast_qtype.dyn_cast<quant::UniformQuantizedType>()) {
665       scale_op = rewriter.create<TF::ConstOp>(
666           loc, RankedTensorType::get({}, rewriter.getF32Type()),
667           rewriter.getF32FloatAttr(qtype.getScale()));
668       zp_op = rewriter.create<TF::ConstOp>(
669           loc, RankedTensorType::get({}, rewriter.getI32Type()),
670           rewriter.getI32IntegerAttr(qtype.getZeroPoint()));
671     } else if (auto qtype =
672                    cast_qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
673       SmallVector<float> scales(qtype.getScales().begin(),
674                                 qtype.getScales().end());
675       SmallVector<int32_t> zps(qtype.getZeroPoints().begin(),
676                                qtype.getZeroPoints().end());
677       const size_t num_channels = qtype.getScales().size();
678 
679       auto scales_type = RankedTensorType::get(
680           {static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
681       auto scales_attr =
682           DenseElementsAttr::get(scales_type, llvm::makeArrayRef(scales));
683       scale_op = rewriter.create<TF::ConstOp>(loc, scales_attr);
684 
685       auto zps_type = RankedTensorType::get(
686           {static_cast<int64_t>(num_channels)}, rewriter.getI32Type());
687       auto zps_attr = DenseElementsAttr::get(zps_type, llvm::makeArrayRef(zps));
688       zp_op = rewriter.create<TF::ConstOp>(loc, zps_attr);
689     }
690     if (!scale_op || !zp_op) {
691       return failure();
692     }
693     auto scale_cast = rewriter.create<CastOp>(loc, qparams_op.scale().getType(),
694                                               scale_op.output());
695     auto zp_cast =
696         rewriter.create<CastOp>(loc, qparams_op.zp().getType(), zp_op.output());
697 
698     qparams_op.scale().replaceAllUsesWith(scale_cast.out());
699     qparams_op.zp().replaceAllUsesWith(zp_cast.out());
700     return success();
701   }
702 };
703 
704 // TODO(b/193731721): Migrate tfr_ builtin canonicalizations to LowerTFROpPass
705 class RemoveScaleFactorOp : public OpRewritePattern<TFRQuantScaleFactorOp> {
706   using OpRewritePattern<TFRQuantScaleFactorOp>::OpRewritePattern;
707 
708  public:
709   // Replace quant_scale_factor with constant tensor equivalent to
710   // TFR_ConstantTensorOp (
711   //   ConstantOp (ConstAttr<F32Attr (in_scale[0] * in_scale[1] /
712   //   out_scale))
713   // )
714   // Currently, all decompositions using this pattern (Conv2D, FC) have the
715   // following preconditions:
716   // * out_scale: float scalar attribute
717   // * in_scale[0] (input scale): float scalar, given by tf.Const -> tfr.cast
718   // * in_scale[1] (filter scale): float scalar/vector
719   //     (per-tensor vs per-channel) quantization, given by tf.Const -> tfr.cast
matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,PatternRewriter & rewriter) const720   LogicalResult matchAndRewrite(TFRQuantScaleFactorOp scale_factor_op,
721                                 PatternRewriter &rewriter) const override {
722     auto out_scale_op =
723         scale_factor_op.out_scale().getDefiningOp<arith::ConstantOp>();
724     if (!out_scale_op) {
725       return failure();
726     }
727     const double out_scale =
728         out_scale_op.getValue().cast<FloatAttr>().getValueAsDouble();
729 
730     auto in_scales_op =
731         scale_factor_op.in_scales().getDefiningOp<BuildListOp>();
732     if (!in_scales_op || in_scales_op.getNumOperands() != 2) {
733       // BuildListOp is variadic, but we require two values: input_scale
734       // and filter_scale.
735       return failure();
736     }
737 
738     auto in_scale_op = in_scales_op.getOperand(0).getDefiningOp<CastOp>();
739     if (!in_scale_op) {
740       return failure();
741     }
742 
743     DenseFPElementsAttr in_scale_attr;
744     if (!matchPattern(in_scale_op.arg(), m_Constant(&in_scale_attr)) ||
745         in_scale_attr.size() != 1) {
746       return failure();
747     }
748     const float in_scale = in_scale_attr.getValues<float>()[0];
749     auto filter_scale_op = in_scales_op.getOperand(1).getDefiningOp<CastOp>();
750     if (!filter_scale_op) {
751       return failure();
752     }
753     DenseFPElementsAttr filter_scale_attr;
754     if (!matchPattern(filter_scale_op.arg(), m_Constant(&filter_scale_attr))) {
755       return failure();
756     }
757 
758     // The shape of scale_type is {} (rank 0) for per-tensor quantized tensor,
759     // and {num_channels} (rank 1) for per-channel quantized one.
760     auto scale_type = filter_scale_attr.getType().dyn_cast<RankedTensorType>();
761     if (scale_type.getRank() != 0 && scale_type.getRank() != 1) {
762       return failure();
763     }
764     SmallVector<float> scale_factors;
765     scale_factors.reserve(filter_scale_attr.size());
766     for (auto value : filter_scale_attr.getValues<APFloat>()) {
767       scale_factors.push_back(in_scale * value.convertToFloat() / out_scale);
768     }
769     rewriter.setInsertionPoint(scale_factor_op);
770     const Location loc = scale_factor_op->getLoc();
771     auto result_scale_op = rewriter.create<TF::ConstOp>(
772         loc,
773         DenseElementsAttr::get(scale_type, llvm::makeArrayRef(scale_factors)));
774     auto result_scale_cast_op = rewriter.create<CastOp>(
775         loc, scale_factor_op.getType(), result_scale_op.output());
776     scale_factor_op.scale_factor().replaceAllUsesWith(
777         result_scale_cast_op.out());
778     return success();
779   }
780 };
781 
782 class RemoveRescaleOp : public OpRewritePattern<TFRQuantRescaleOp> {
783   using OpRewritePattern<TFRQuantRescaleOp>::OpRewritePattern;
784 
785  public:
786   // Replace quant_rescale (input, scale, zp) with
787   // tf.Cast(tf.Round(tf.Cast(input, f32) * scale) + tf.Cast(zp, f32), i32)
matchAndRewrite(TFRQuantRescaleOp rescale_op,PatternRewriter & rewriter) const788   LogicalResult matchAndRewrite(TFRQuantRescaleOp rescale_op,
789                                 PatternRewriter &rewriter) const override {
790     Value input = rescale_op.input();
791     Value scale = rescale_op.scale();
792     Value zp = rescale_op.zp();
793 
794     const Location loc = rescale_op->getLoc();
795     const auto result_types = rescale_op->getResultTypes();
796     auto c_false =
797         rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));
798     TypeAttr f32_attr = TypeAttr::get(rewriter.getF32Type());
799     TFRAttrType output_type = TFRAttrType::get(rewriter.getContext());
800     auto constant_f32_op = rewriter.create<ConstOp>(loc, output_type, f32_attr);
801     TypeAttr i32_attr = TypeAttr::get(rewriter.getI32Type());
802     auto constant_i32_op = rewriter.create<ConstOp>(loc, output_type, i32_attr);
803 
804     IntegerAttr zp_attr;
805     if (!matchPattern(zp, m_Constant(&zp_attr))) {
806       return failure();
807     }
808     rewriter.setInsertionPoint(zp.getDefiningOp());
809     auto zp_tensor = rewriter.create<TF::ConstOp>(
810         loc, RankedTensorType::get({}, zp.getType()), zp_attr);
811     auto zp_cast = rewriter.create<CastOp>(
812         loc, rewriter.getType<TFRTensorType>(), zp_tensor.output());
813 
814     rewriter.setInsertionPoint(rescale_op);
815     auto cast_input_to_float_op = rewriter.create<CallOp>(
816         loc, result_types,
817         SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
818         ArrayRef<Value>{input, constant_f32_op, c_false});
819     auto input_x_scale_op = rewriter.create<CallOp>(
820         loc, result_types, SymbolRefAttr::get(rewriter.getContext(), "tf__mul"),
821         ArrayRef<Value>{cast_input_to_float_op.getResult(0), scale});
822     auto round_rescaled_op = rewriter.create<CallOp>(
823         loc, result_types,
824         SymbolRefAttr::get(rewriter.getContext(), "tf__round"),
825         ArrayRef<Value>{input_x_scale_op->getResult(0)});
826     auto cast_zp_to_float_op = rewriter.create<CallOp>(
827         loc, result_types,
828         SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
829         ArrayRef<Value>{zp_cast, constant_f32_op, c_false});
830     auto recentered_op = rewriter.create<CallOp>(
831         loc, result_types, SymbolRefAttr::get(rewriter.getContext(), "tf__add"),
832         ArrayRef<Value>{round_rescaled_op->getResult(0),
833                         cast_zp_to_float_op->getResult(0)});
834     auto cast_output_to_i32 = rewriter.create<CallOp>(
835         loc, result_types,
836         SymbolRefAttr::get(rewriter.getContext(), "tf__cast"),
837         ArrayRef<Value>{recentered_op->getResult(0), constant_i32_op, c_false});
838     rescale_op.output().replaceAllUsesWith(cast_output_to_i32.getResult(0));
839     return success();
840   }
841 };
842 
843 }  // namespace
844 
// Registers the canonicalization pattern for tfr.constant_tensor.
void ConstantTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<ConvertConstToTensorConst>(context);
}
849 
// Registers the canonicalization pattern for tfr.cast.
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<RemoveRedundantCast>(context);
}
854 
// Registers the canonicalization pattern for tfr.get_shape.
void GetShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<GetTensorShape>(context);
}
859 
// Registers the canonicalization pattern for tfr.get_element.
void GetElementOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<RemoveRedundantGetElement>(context);
}
864 
// Registers the canonicalization pattern for tfr.get_length.
void GetLengthOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<RemoveRedundantGetLength>(context);
}
869 
// Registers the canonicalization pattern for tfr.build_list.
void BuildListOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BuildConstantListAsAttr>(context);
}
874 
// Registers the canonicalization pattern for tfr.quant_raw_data.
void TFRQuantRawDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<RemoveRawDataOp>(context);
}
879 
// Registers the canonicalization pattern for tfr.quant_qparams.
void TFRQuantQParamsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<RemoveQParamsOp>(context);
}
884 
// Registers the canonicalization pattern for tfr.quant_rescale.
void TFRQuantRescaleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<RemoveRescaleOp>(context);
}
889 
// Registers the canonicalization pattern for tfr.quant_scale_factor.
void TFRQuantScaleFactorOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveScaleFactorOp>(context);
}
894 
fold(ArrayRef<Attribute> operands)895 OpFoldResult TFR::EqualOp::fold(ArrayRef<Attribute> operands) {
896   assert(operands.size() == 2 && "equal op has two operands");
897   auto ctx = getContext();
898   if (operands[0] == operands[1]) return BoolAttr::get(ctx, true);
899   return BoolAttr::get(ctx, false);
900 }
901 
// Folds tfr.const to the attribute it holds.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}
908 
909 // CallableOpInterface
getCallableRegion()910 Region *TFRFuncOp::getCallableRegion() {
911   return isExternal() ? nullptr : &body().front();
912 }
913 
914 // CallableOpInterface
// Returns the result types of this callable, taken from the function type.
ArrayRef<Type> TFRFuncOp::getCallableResults() {
  return getFunctionType().getResults();
}
918 
919 //===----------------------------------------------------------------------===//
920 // Dialect type definitions
921 //===----------------------------------------------------------------------===//
922 
923 // Parses a TFR type.
924 //   tfr_type ::= tensor_type | tensor_list_type | attr_type
925 //   string_list ::= `[` string-literal (, string-literal)+ `]`
926 //   tensor_type ::= `tensor`
927 //                 | `tensor<` (string-literal | string_list)  '>'
928 //   tensor_list_type ::= `tensor_list`
929 //                      | `tensor_list<` (string-literal | string_list)  '>'
930 //   attr_type ::= `attr`
parseType(DialectAsmParser & parser) const931 Type TFRDialect::parseType(DialectAsmParser &parser) const {
932   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
933   MLIRContext *ctx = loc.getContext();
934 
935   StringRef typeNameSpelling;
936   if (failed(parser.parseKeyword(&typeNameSpelling))) return {};
937   llvm::SmallVector<StringAttr, 4> attrs;
938   if (succeeded(parser.parseOptionalLess())) {
939     bool l_square_parsed = false;
940     if (succeeded(parser.parseOptionalLSquare())) {
941       l_square_parsed = true;
942     }
943 
944     do {
945       StringRef attr;
946       if (failed(parser.parseKeyword(&attr))) return {};
947       attrs.push_back(StringAttr::get(ctx, attr));
948     } while (succeeded(parser.parseOptionalComma()));
949 
950     if (l_square_parsed && failed(parser.parseRSquare())) {
951       parser.emitError(parser.getNameLoc(), "expected ']'");
952     }
953 
954     if (failed(parser.parseGreater())) {
955       parser.emitError(parser.getNameLoc(), "expected '>'");
956     }
957   }
958 
959   if (typeNameSpelling == "tensor") {
960     return TFRTensorType::getChecked(attrs, loc);
961   } else if (typeNameSpelling == "tensor_list") {
962     return TFRTensorListType::getChecked(attrs, loc);
963   } else if (typeNameSpelling == "attr") {
964     return TFRAttrType::getChecked(loc, loc.getContext());
965   } else {
966     parser.emitError(parser.getNameLoc(), "unknown type " + typeNameSpelling);
967     return {};
968   }
969 }
970 
printType(Type type,DialectAsmPrinter & os) const971 void TFRDialect::printType(Type type, DialectAsmPrinter &os) const {
972   llvm::ArrayRef<StringAttr> attrs;
973 
974   if (type.isa<TFRAttrType>()) {
975     os << "attr";
976     return;
977   }
978   if (auto tensor_ty = type.dyn_cast<TFRTensorType>()) {
979     attrs = tensor_ty.getAttrKeys();
980     os << "tensor";
981   } else if (auto tensor_list_ty = type.dyn_cast<TFRTensorListType>()) {
982     attrs = tensor_list_ty.getAttrKeys();
983     os << "tensor_list";
984   } else {
985     llvm_unreachable("Unhandled tfr type");
986   }
987 
988   if (attrs.empty()) return;
989   os << "<";
990 
991   if (attrs.size() > 1) {
992     os << "[";
993   }
994 
995   llvm::interleaveComma(attrs, os,
996                         [&](StringAttr attr) { os << attr.getValue(); });
997 
998   if (attrs.size() > 1) {
999     os << "]";
1000   }
1001   os << ">";
1002 }
1003 
1004 }  // namespace TFR
1005 }  // namespace mlir
1006