/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>

#include "absl/strings/escaping.h"
#include "third_party/eigen3/Eigen/Core"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/kernel_shape_util.h"

namespace mlir {
namespace TFL {
namespace {

ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                            OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
  Type type;
  // If the operand list is in-between parentheses, then we have a generic form
  // (see the fallback in `printOneResultOp`).
  SMLoc loc = parser.getCurrentLocation();
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperandList(ops) || parser.parseRParen() ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColon() || parser.parseType(type))
      return failure();
    auto fnType = type.dyn_cast<FunctionType>();
    if (!fnType) {
      parser.emitError(loc, "expected function type");
      return failure();
    }
    if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
      return failure();
    result.addTypes(fnType.getResults());
    return success();
  }
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperands(ops, type, result.operands) ||
                 parser.addTypeToList(type, result.types));
}
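
// As an illustration (hypothetical op text, not a specific TFL op), the
// parser above accepts both the generic form with parenthesized operands and
// an explicit function type,
//   %0 = "tfl.some_op"(%a, %b) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// and the compact form emitted by `printOneResultOp` below when all operand
// and result types match:
//   %0 = tfl.some_op %a, %b : tensor<4xf32>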

void printOneResultOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumResults() == 1 && "op should have one result");

  // If not all the operand and result types are the same, just use the
  // generic assembly form to avoid omitting information in printing.
  auto resultType = op->getResult(0).getType();
  if (llvm::any_of(op->getOperandTypes(),
                   [&](Type type) { return type != resultType; })) {
    p.printGenericOp(op, /*printOpName=*/false);
    return;
  }

  p << ' ';
  p.printOperands(op->getOperands());
  p.printOptionalAttrDict(op->getAttrs());
  // Now we can output only one type for all operands and the result.
  p << " : " << resultType;
}

Operation *getDefiningBroadcastArgsOp(Value operand) {
  auto *defining_op = operand.getDefiningOp();
  if (!llvm::dyn_cast_or_null<TF::BroadcastToOp>(defining_op) &&
      !llvm::dyn_cast_or_null<TFL::BroadcastToOp>(defining_op)) {
    return nullptr;
  }

  // Broadcasted shape operand of the BroadcastTo op.
  Value broadcast_shape = defining_op->getOperand(1);
  Operation *parent_of_defining_op = broadcast_shape.getDefiningOp();
  if (!llvm::dyn_cast_or_null<TF::BroadcastArgsOp>(parent_of_defining_op) &&
      !llvm::dyn_cast_or_null<TFL::BroadcastArgsOp>(parent_of_defining_op)) {
    return nullptr;
  }
  return parent_of_defining_op;
}
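
// The IR pattern matched by getDefiningBroadcastArgsOp above looks roughly
// like this (sketch, not taken from a real module):
//   %shape = "tfl.broadcast_args"(%shape_a, %shape_b) ...
//   %operand = "tfl.broadcast_to"(%x, %shape) ...
// in which case the BroadcastArgs op defining %shape is returned.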
}  // namespace

// Returns true when the given type lists contain a single element of shaped
// type with compatible shapes (unranked shape is compatible with any ranked
// shape, ranked shapes are compatible if their respective dimensions are
// compatible, dynamic dimensions are compatible with any size, static
// dimensions must be equal to be compatible) and identical element types.
bool VerifyCompatibleShapesSameElementType(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  auto lhsShaped = lhs[0].cast<ShapedType>();
  auto rhsShaped = rhs[0].cast<ShapedType>();
  return lhsShaped.getElementType() == rhsShaped.getElementType();
}

// Returns true when the given operands have the same shape or broadcastable
// shapes within the given rank. If any of the given shapes is non-static and
// the maximum rank is within the given rank, this method also returns true.
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
    Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
  if (indices.empty()) return true;

  // First, check whether any of the inputs has unknown rank.
  bool has_unknown_shape_input = false;
  bool has_same_shape = true;
  bool reach_first_known_shape = false;
  int64_t max_rank = -1;

  ArrayRef<int64_t> pivot_shape;
  SmallVector<int64_t, 4> current_shape;
  SmallVector<int64_t, 4> result_shape;

  for (unsigned index : indices) {
    ShapedType shaped_type =
        op->getOperand(index).getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasRank()) {
      // Marks that we have an unknown rank input.
      has_unknown_shape_input = true;
      continue;
    }
    max_rank = std::max(max_rank, shaped_type.getRank());
    if (!shaped_type.hasStaticShape()) {
      // Marks that we have an unknown shape input.
      has_unknown_shape_input = true;
      continue;
    }

    ArrayRef<int64_t> shape = shaped_type.getShape();
    if (!reach_first_known_shape) {
      pivot_shape = shape;
      current_shape.assign(shape.begin(), shape.end());
      reach_first_known_shape = true;
      continue;
    }

    if (!pivot_shape.equals(shape)) {
      has_same_shape = false;
    }
    // Check whether all the inputs are broadcastable, since they do not all
    // have the same shape.
    if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
                                            result_shape)) {
      return false;
    }
    current_shape = result_shape;
  }

  // If all shapes are known and identical, CPU kernels are able to handle
  // inputs regardless of dimension size.
  if (!has_unknown_shape_input) {
    return has_same_shape || max_rank <= max_bcast_rank;
  }

  // Treat unknown shape inputs as acceptable for model compatibility if all
  // known ranks are no bigger than the allowed broadcast maximum rank.
  if (max_rank <= max_bcast_rank) {
    return true;
  }

  // Check whether all operands are broadcasted by BroadcastTo ops whose shape
  // is calculated from the same BroadcastArgs op. In that case, all operands
  // will have the same shape.
  Operation *broadcast_args_pivot = nullptr;
  for (unsigned index : indices) {
    Operation *parent_broadcast_args =
        getDefiningBroadcastArgsOp(op->getOperand(index));
    if (parent_broadcast_args == nullptr) {
      return false;
    }

    if (broadcast_args_pivot == nullptr) {
      broadcast_args_pivot = parent_broadcast_args;
      continue;
    }

    if (broadcast_args_pivot != parent_broadcast_args) {
      return false;
    }
  }
  return true;
}

// Return true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         quantized_type.isSigned();
}

// Return true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         !quantized_type.isSigned();
}

// Return true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 16 &&
         quantized_type.isSigned();
}

// Return true when the given element_type is I32.
bool IsI32Type(Type element_type) {
  return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}

// Return true when the given element_type is I64.
bool IsI64Type(Type element_type) {
  return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}

// Return true if the value is a splat tensor constant zero.
bool EqualsZero(Value value) {
  DenseElementsAttr constant;
  if (!matchPattern(value, m_Constant(&constant)) || !constant.isSplat()) {
    return false;
  }

  Type element_type = value.getType().cast<ShapedType>().getElementType();
  if (element_type.isa<FloatType>()) {
    return constant.getSplatValue<APFloat>().isZero();
  } else {
    return false;
  }
}

// Replaces the bias operand with a "none" type value if the bias value is
// constant zero.
// `ConcreteOpType` must be a concrete MLIR op class that has an optional
// bias operand named 'bias'.
template <typename ConcreteOpType>
struct RemoveOptionalZeroBias : public OpRewritePattern<ConcreteOpType> {
  using OpRewritePattern<ConcreteOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcreteOpType op,
                                PatternRewriter &rewriter) const override {
    if (EqualsZero(op.bias())) {
      auto none_value = rewriter.create<TFL::NoValueOp>(
          rewriter.getUnknownLoc(), rewriter.getNoneType(),
          rewriter.getUnitAttr());
      op.biasMutable().assign(none_value);
    }

    return success();
  }
};
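
// Sketch of the rewrite performed by RemoveOptionalZeroBias (hypothetical IR,
// assuming an op whose bias is a splat zero constant):
//   %bias = "tfl.pseudo_const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
//   %fc = "tfl.fully_connected"(%in, %filter, %bias) ...
// becomes
//   %none = "tfl.no_value"() {value} : () -> none
//   %fc = "tfl.fully_connected"(%in, %filter, %none) ...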

// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, QI8, QUI8, I32 and I64 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to four dimensions or the same
  // shape.
  if (element_type.isF32() || IsQI8Type(element_type) ||
      IsQUI8Type(element_type) || IsI32Type(element_type) ||
      IsI64Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows QI16 output when operands have the same shape.
  if (IsQI16Type(element_type)) {
    return succeeded(
        mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
  }
  return false;
}

// Return true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, I32, I64, QUI8 and QI16 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to five dimensions or the same
  // shape.
  if (element_type.isF32() || IsI32Type(element_type) ||
      IsI64Type(element_type) || IsQUI8Type(element_type) ||
      IsQI16Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
  }

  // Allows QI8 output when the operands have valid shapes, which are
  // broadcastable shapes up to four dimensions or the same shape.
  if (IsQI8Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

// Return true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows QI8 and QUI8 outputs with operand broadcasting up to four
  // dimensions. If the lhs operand is QI16, allows only operands with the
  // same shape.
  if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
    if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
      return succeeded(
          mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
    }
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows I32, I64, QI16, complex and F32 outputs when the operands have
  // valid shapes, which are broadcastable shapes up to four dimensions or the
  // same shape.
  if (IsI32Type(element_type) || IsI64Type(element_type) ||
      IsQI16Type(element_type) || element_type.isa<ComplexType>() ||
      element_type.isF32()) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//

struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    // No TFLite op restricts inlining today; revise as needed in the future.
    return true;
  }
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    return isa<WhileOp>(dest->getParentOp());
  }
};

struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  // Registered hook to check if the given region, which is attached to an
  // operation that is *not* isolated from above (i.e. no internal regions
  // reference values defined in an enclosing region), should be used when
  // materializing constants.
  // In the TFLite dialect we materialize inside WhileOp regions, as it is
  // slightly more efficient computationally.
  bool shouldMaterializeInto(Region *region) const final {
    return isa<WhileOp>(region->getParentOp());
  }
};

void TFLDialect::printType(Type type, DialectAsmPrinter &os) const {
  if (type.isa<ControlType>()) {
    os << "control";
    return;
  }
  os << "<unknown TFL type>";
}

Type TFLDialect::parseType(DialectAsmParser &parser) const {
  StringRef data_type;
  if (parser.parseKeyword(&data_type)) return Type();
  if (data_type == "control") return ControlType::get(getContext());
  parser.emitError(parser.getNameLoc()) << "unknown TFL type: " << data_type;
  return nullptr;
}

void TFLDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_attrdefs.cc.inc"
      >();
  addInterfaces<TensorFlowLiteInlinerInterface,
                TensorFlowLiteDialectFoldInterface>();
  addTypes<ControlType>();
}

//===----------------------------------------------------------------------===//
// Common support logic
//===----------------------------------------------------------------------===//

namespace {

// Returns true if the dimensions in `a` are a suffix of the ones in `b`.
// For example, dimensions {2}, {1, 2}, and {3, 1, 2} are all suffixes of
// {5, 4, 3, 1, 2}, while {1}, {5, 4}, and {1, 3, 2} are not.
inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  if (a.size() > b.size()) return false;

  return std::equal(a.rbegin(), a.rend(), b.rbegin());
}

// Returns true if it is a shaped type of f32 elements.
inline bool IsF32ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isF32();
  }
  return false;
}

// Returns true if it is a shaped type of bf16 elements.
inline bool IsBF16ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isBF16();
  }
  return false;
}

// Returns a new shape of rank 'new_dims', padded with ones on the left if
// needed.
inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
                                           int new_dims) {
  std::vector<int64_t> new_shape(new_dims, 1);
  std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
  return new_shape;
}
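
// For example, GetPaddedShape({5, 4}, /*new_dims=*/4) returns {1, 1, 5, 4}.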

// Helper method that, given a 'current_index' into the broadcasted tensor,
// returns the index into the flat buffer of the original tensor.
// 'shape' is the original shape, padded with ones to match the result shape;
// broadcast dimensions of size one wrap around via the modulo below.
int64_t GetElementIndex(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &current_index) {
  int64_t ind = 0;
  int64_t mul = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    ind += (current_index[i] % shape[i]) * mul;
    mul *= shape[i];
  }
  return ind;
}
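
// For example, with shape = {1, 3} (an original one-row tensor broadcast
// along rows) and current_index = {2, 1}, the modulo wraps the broadcast
// dimension: (1 % 3) * 1 + (2 % 1) * 3 == 1.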

// Helper method that increments the multi-dimensional index in
// 'current_index_ptr' within the bounds given by 'result_shape', in
// row-major (odometer) order.
void IncrementIndex(ArrayRef<int64_t> result_shape,
                    std::vector<int64_t> *current_index_ptr) {
  std::vector<int64_t> &current_index = *current_index_ptr;
  for (int i = result_shape.size() - 1; i >= 0; --i) {
    current_index[i]++;
    if (current_index[i] == result_shape[i]) {
      current_index[i] = 0;
    } else {
      break;
    }
  }
}
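
// For example, with result_shape = {2, 3} the index advances
// {0, 2} -> {1, 0}: the innermost dimension overflows back to zero and the
// carry increments the next dimension to the left.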

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `lhs` and `rhs` and returns the result if possible.
/// This function assumes that both operands are verified to have value
/// attributes of broadcastable types.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
                                      DenseElementsAttr rhs,
                                      const CalculationT &calculate) {
  auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
                  .dyn_cast_or_null<ShapedType>();
  if (!type) {
    return {};
  }

  const bool rhs_is_splat = rhs.isSplat();
  const bool lhs_is_splat = lhs.isSplat();

  // If both of them are splat, compute and return.
  if (lhs_is_splat && rhs_is_splat) {
    auto element_result = AttrElementT::get(
        type.getElementType(), calculate(lhs.getSplatValue<ElementValueT>(),
                                         rhs.getSplatValue<ElementValueT>()));
    if (!element_result) return {};

    return DenseElementsAttr::get(type, element_result);
  }

  auto num_elements = type.getNumElements();

  SmallVector<ElementValueT, 16> new_values;
  new_values.reserve(num_elements);
  const auto result_shape = type.getShape();
  std::vector<int64_t> current_index(type.getRank(), 0);
  // Create the new shapes with ones padded to the left.
  const std::vector<int64_t> lhs_new_shape =
      GetPaddedShape(lhs.getType().getShape(), type.getRank());
  const std::vector<int64_t> rhs_new_shape =
      GetPaddedShape(rhs.getType().getShape(), type.getRank());

  auto lhs_old_values = lhs.getValues<ElementValueT>();
  auto rhs_old_values = rhs.getValues<ElementValueT>();

  // Apply `calculate` to each pair of corresponding values in the dense
  // elements attributes.
  for (int64_t i = 0; i < num_elements; ++i) {
    // current_index represents the index in the N-dimensional result tensor.
    // GetElementIndex returns the index into the flat representation of the
    // original tensor to use.
    const int64_t lhs_index =
        lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
    const int64_t rhs_index =
        rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);

    new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index),
                                   *(rhs_old_values.begin() + rhs_index)));
    IncrementIndex(result_shape, &current_index);
  }
  return DenseElementsAttr::get(type, ArrayRef<ElementValueT>(new_values));
}
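
// Worked example of the broadcast walk above (values chosen for
// illustration): folding addition of lhs = [[1], [2]] : tensor<2x1xi32> and
// rhs = [10, 20, 30] : tensor<3xi32> broadcasts both to tensor<2x3xi32> and
// produces [[11, 21, 31], [12, 22, 32]].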

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes the two operands are verified to have value
/// attributes of broadcastable types.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
                            Attribute operand2, const CalculationT &calculate) {
  if (operand1.dyn_cast_or_null<DenseElementsAttr>() &&
      operand2.dyn_cast_or_null<DenseElementsAttr>()) {
    return ConstFoldBinaryOpDenseDense<AttrElementT, ElementValueT>(
        result_type, operand1.cast<DenseElementsAttr>(),
        operand2.cast<DenseElementsAttr>(), calculate);
  }

  // TODO: support other attribute kinds

  return {};
}

/// Performs const folding with broadcast behavior on the two attributes in
/// `operands` and returns the result if possible.
/// Depending on the given `result_type`, either `float_calculate` or
/// `int_calculate` is chosen to conduct the calculation.
Attribute ConstFoldBinaryOp(
    Type result_type, ArrayRef<Attribute> operands,
    llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
    llvm::function_ref<APInt(APInt, APInt)> int_calculate) {
  // Note: All types are wrapped in tensor types in TFLite. E.g., f32 is
  // represented as tensor<f32>. So we are only handling tensor types here.
  auto type = result_type.dyn_cast<ShapedType>();
  if (!type) return {};

  auto elemType = type.getElementType();

  if (elemType.isa<FloatType>())
    return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                        float_calculate);

  if (elemType.isSignlessInteger())
    return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
                                          int_calculate);

  return {};
}

/// Performs const folding of the attribute `operand` and returns the result
/// if possible.
/// The function currently asserts that `result_type` is an f32 or bf16
/// tensor type.
/// TODO: Extend this function to handle integral tensors for ops like
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
                           llvm::function_ref<APFloat(APFloat)> calculate) {
  assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
  auto result_shape_type = result_type.cast<ShapedType>();

  if (!result_shape_type.hasStaticShape()) return {};

  if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
    SmallVector<APFloat, 16> new_values;
    const int num_elements = result_shape_type.getNumElements();
    new_values.reserve(num_elements);

    for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
      new_values.push_back(calculate(old_value));
    }

    return DenseElementsAttr::get(result_shape_type, new_values);
  }

  return {};
}

void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
                          Value rhs) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();
  result.addOperands({lhs, rhs});
  // Comparison binary ops always return an i1 tensor.
  if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
    auto result_shape = shaped_type.getShape();
    result.types.push_back(
        RankedTensorType::get(result_shape, builder->getI1Type()));
  } else {
    result.types.push_back(UnrankedTensorType::get(builder->getI1Type()));
  }
}

void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
                                  Value lhs, Value rhs,
                                  StringAttr fused_activation_function) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());

  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();

  result.addOperands({lhs, rhs});
  result.addAttribute("fused_activation_function", fused_activation_function);
  result.types.push_back(result_type);
}

}  // end anonymous namespace

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a + b; },
      [](APInt a, APInt b) { return a + b; });
}

int64_t AddOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// ConcatenationOp
//===----------------------------------------------------------------------===//
// TODO(ashwinm): Implement shape inference for Concatenation.

namespace {

int64_t GetConcatenationOpAxis(ConcatenationOp op) {
  auto output_type = op.output().getType().cast<RankedTensorType>();
  int32_t axis = op.axis();
  if (axis < 0) axis += output_type.getRank();
  return axis;
}

// Verify operand types and the result type:
//
// 1. Operand type ranks must be equal to the output type rank.
//
// 2. Operand dimension sizes (except dimension `axis`) must be equal to
//    previously seen dimension sizes of the same dimension.
//
// 3. Sum of operand dimension sizes of the `axis` dimension must be equal to
//    the dimension size of the `axis` dimension of output.
//
// Note: If an operand has unranked tensor type or has dynamic dimension size,
// those dimensions will be skipped.
LogicalResult VerifyConcatenationOpTypes(Operation *op,
                                         RankedTensorType output_type,
                                         ArrayRef<TensorType> operand_types,
                                         int64_t axis) {
  const int64_t output_rank = output_type.getRank();

  constexpr int64_t kDynamicSize = -1;
  SmallVector<int64_t, 4> result_dim_sizes_loc(output_rank, -1);
  SmallVector<int64_t, 4> result_dim_sizes(output_type.getShape().begin(),
                                           output_type.getShape().end());
  result_dim_sizes[axis] = 0;

  auto FormatLoc = [&result_dim_sizes_loc](int64_t dim) {
    const int64_t loc = result_dim_sizes_loc[dim];
    if (loc == -1) return std::string("output");
    return llvm::formatv("operand #{0}", loc).str();
  };

  for (const auto &operand : llvm::enumerate(operand_types)) {
    auto operand_type = operand.value().dyn_cast<RankedTensorType>();
    if (!operand_type) {
      result_dim_sizes[axis] = kDynamicSize;
      continue;
    }

    const int64_t operand_rank = operand_type.getRank();
    if (operand_rank != output_rank)
      return op->emitOpError() << "rank of operand #" << operand.index()
                               << " must be equal to rank of output, expected "
                               << output_rank << ", got " << operand_rank;

    for (int64_t dim = 0; dim < output_rank; ++dim) {
      const int64_t operand_dim_size = operand_type.getDimSize(dim);
      const int64_t result_dim_size = result_dim_sizes[dim];

      if (dim == axis) {
        if (ShapedType::isDynamic(operand_dim_size) ||
            ShapedType::isDynamic(result_dim_size)) {
          result_dim_sizes[axis] = kDynamicSize;
        } else {
          result_dim_sizes[axis] += operand_dim_size;
        }
        continue;
      }

      if (ShapedType::isDynamic(operand_dim_size)) continue;

      if (ShapedType::isDynamic(result_dim_size)) {
        result_dim_sizes[dim] = operand_dim_size;
        result_dim_sizes_loc[dim] = operand.index();
        continue;
      }

      if (result_dim_size != operand_dim_size)
        return op->emitOpError()
               << "dimension size of dimension #" << dim << " of operand #"
               << operand.index() << " must be equal to "
               << "dimension size of dimension #" << dim << " of "
               << FormatLoc(dim) << ", expected " << result_dim_size << ", got "
               << operand_dim_size;
    }
  }

  const int64_t output_concated_dim_size = output_type.getDimSize(axis);
  if (!ShapedType::isDynamic(output_concated_dim_size) &&
      !ShapedType::isDynamic(result_dim_sizes[axis]) &&
      result_dim_sizes[axis] != output_concated_dim_size)
    return op->emitOpError()
           << "dimension size of dimension #" << axis << " of output "
           << "must be equal to the sum of dimension sizes of dimension #"
           << axis << ", expected " << result_dim_sizes[axis] << ", got "
           << output_concated_dim_size;

  return success();
}

// Returns true when all operands are instances of DenseElementsAttr and the
// output type has a static shape.
bool IsConcatenationOpConstFoldable(ConcatenationOp op,
                                    ArrayRef<Attribute> operands,
                                    RankedTensorType output_type,
                                    int64_t axis) {
  if (operands.empty()) return false;
  if (!output_type.hasStaticShape()) return false;
  if (axis < 0) return false;

  return llvm::all_of(operands, [](Attribute operand) {
    return operand && operand.isa<DenseElementsAttr>();
  });
}

DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
                                              RankedTensorType output_type,
                                              int64_t axis) {
  const auto outer_dims = output_type.getShape().take_front(axis);
  const int64_t outer_size = std::accumulate(
      outer_dims.begin(), outer_dims.end(), 1, std::multiplies<int64_t>());

  const auto base_inner_dims = output_type.getShape().drop_front(axis + 1);
  const int64_t base_inner_size =
      std::accumulate(base_inner_dims.begin(), base_inner_dims.end(), 1,
                      std::multiplies<int64_t>());

  // Splits each input operand into outer_size pieces and combines them in
  // round-robin ordering.
  std::vector<Attribute> out_attrs(output_type.getNumElements());
  int64_t out = 0;
  for (int64_t outer = 0; outer < outer_size; ++outer) {
    for (auto op : operands) {
      auto typed_attr = op.cast<TypedAttr>();
      const int64_t dim_size =
          typed_attr.getType().cast<RankedTensorType>().getDimSize(axis);
      const int64_t inner_size = dim_size * base_inner_size;

      auto input_attrs = op.cast<DenseElementsAttr>().getValues<Attribute>();
      auto input_iter = input_attrs.begin() + outer * inner_size;
      for (int64_t inner = 0; inner < inner_size; ++inner)
        out_attrs[out++] = *input_iter++;
    }
  }

  return DenseElementsAttr::get(output_type, out_attrs);
}
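
// Worked example of the round-robin interleaving above (values chosen for
// illustration): concatenating [[1], [2]] : tensor<2x1xi32> with
// [[3, 4], [5, 6]] : tensor<2x2xi32> along axis 1 walks each outer row in
// turn and produces [[1, 3, 4], [2, 5, 6]] : tensor<2x3xi32>.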

}  // end anonymous namespace

LogicalResult ConcatenationOp::verify() {
  ConcatenationOp op = *this;
  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();

  // If the output type is unranked, there is nothing else to be verified.
  if (!output_type) return success();

  const int64_t axis = GetConcatenationOpAxis(op);
  if (axis < 0 || axis >= output_type.getRank())
    return op.emitOpError("concatenation dimension must be in [-rank, rank)");

  SmallVector<TensorType, 4> operand_types;
  for (Value operand : op.values())
    operand_types.push_back(operand.getType().cast<TensorType>());

  return VerifyConcatenationOpTypes(op.getOperation(), output_type,
                                    operand_types, axis);
}

OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
  if (fused_activation_function() == "NONE") {
    if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
      const int64_t axis = GetConcatenationOpAxis(*this);
      if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
        return ConstFoldConcatenateOpDense(operands, output_type, axis);
    }
  }

  // Remove all empty values.
  SmallVector<Value, 4> non_empty_values;
  for (Value value : this->values()) {
    const auto shaped_type = value.getType().cast<ShapedType>();
    if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
      continue;
    }
    non_empty_values.push_back(value);
  }

  // If no operand is empty, do nothing.
  if (non_empty_values.size() == getNumOperands()) return nullptr;

  // If only one input is non-empty, just return it as the result of folding.
  if (non_empty_values.size() == 1) {
    return non_empty_values[0];
  }

  // Otherwise, build a new concatenation op with non-empty values.
  mlir::OpBuilder builder(getOperation());
  auto new_concat = builder.create<TFL::ConcatenationOp>(
      getLoc(), getType(), non_empty_values,
      builder.getIntegerAttr(builder.getIntegerType(32), axis()),
      builder.getStringAttr(fused_activation_function()));
  return new_concat.getResult();
}

//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//

// TODO(b/241745316): Confirm that this is always valid
mlir::LogicalResult CustomOp::verify() {
  // Currently, this is always valid as it is a wrapper around a StringRef of 0
  // or more characters.
  return success();
}

//===----------------------------------------------------------------------===//
// CustomTfOp
//===----------------------------------------------------------------------===//

LogicalResult CustomTfOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    DictionaryAttr attr, RegionRange ranges,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  CustomTfOpAdaptor op(operands, attr, ranges);

  if (op.getRegions().empty()) return success();
  auto *real_op = &op.body().front().front();
  if (llvm::isa<TF::FakeQuantWithMinMaxArgsOp, TF::FakeQuantWithMinMaxVarsOp,
                TF::FakeQuantWithMinMaxVarsPerChannelOp>(real_op)) {
    Value input = *operands.begin();
    inferredReturnTypes.assign({input.getType()});
  }
  return success();
}

bool CustomTfOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.empty()) return true;
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verify() {
  GatherOp op = *this;
  ShapedType params_type = op.params().getType().cast<ShapedType>();
  // The TFLite gather kernel supports 1D string input only.
  if (params_type.getElementType().isa<mlir::TF::StringType>()) {
    if (params_type.hasRank() && params_type.getRank() != 1) {
      return op.emitOpError(
                 "expect 1d input when the given type is string, got ")
             << params_type;
    }
  }
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// BroadcastToOp
//===----------------------------------------------------------------------===//

// Canonicalizes BroadcastToOp to ReshapeOp if the input and output have the
// same number of elements.
struct ConvertBroadcastToReshape : public OpRewritePattern<BroadcastToOp> {
  using OpRewritePattern<BroadcastToOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastToOp op,
                                PatternRewriter &rewriter) const override {
    auto input_type = op.input().getType().cast<ShapedType>();
    auto output_type = op.getType().cast<ShapedType>();
    if (!input_type.hasStaticShape() || !output_type.hasStaticShape() ||
        input_type.getNumElements() != output_type.getNumElements()) {
      return failure();
    }
    // The reshape op supports the new shape only as I32. Always add a cast to
    // I32 to make sure the introduced reshape op is a valid one.
    auto result_type = RankedTensorType::get(
        op.shape().getType().cast<RankedTensorType>().getShape(),
        rewriter.getI32Type());
    auto cast_op =
        rewriter.create<TFL::CastOp>(op->getLoc(), result_type, op.shape());

    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.input(),
                                           cast_op);
    return success();
  }
};
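
// Sketch of the canonicalization performed by ConvertBroadcastToReshape
// (hypothetical IR): a broadcast that replicates no elements, e.g.
// tensor<2x3xf32> -> tensor<1x2x3xf32>, is rewritten into a tfl.cast of the
// shape operand to i32 followed by a tfl.reshape, avoiding a real broadcast.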

void BroadcastToOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<ConvertBroadcastToReshape>(context);
}

//===----------------------------------------------------------------------===//
// FullyConnectedOp
//===----------------------------------------------------------------------===//

LogicalResult FullyConnectedOp::verify() {
  FullyConnectedOp op = *this;
  ShapedType input_type = op.input().getType().cast<ShapedType>();
  ShapedType filter_type = op.filter().getType().cast<ShapedType>();
  if (filter_type.hasRank() && filter_type.getRank() != 2) {
    return op.emitOpError("expect 2d filter, got ") << filter_type;
  }

  if (!input_type.hasStaticShape() || !filter_type.hasStaticShape()) {
    return mlir::success();
  }

  // The input's number of elements must be a multiple of the filter's z_in
  // dimension.
  const int z_in = filter_type.getDimSize(1);
  const int num_input_elements = input_type.getNumElements();
  if (z_in != 0 && num_input_elements % z_in != 0) {
    return op.emitOpError(llvm::formatv(
               "expect 'input' num_elements % {0} == 0, got input type ", z_in))
           << input_type;
  }

  // TODO(jpienaar): Include more shape verification for SHUFFLED4x16INT8
  // format.
  if (op.weights_format() == "DEFAULT") {
    ShapedType output_type =
        (*op.output().begin()).getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) {
      return mlir::success();
    }

    const int num_output_elements = output_type.getNumElements();
    const int z_out = filter_type.getDimSize(0);
    if (num_output_elements % z_out != 0) {
      return op.emitOpError(llvm::formatv(
                 "expect 'output' num_elements % {0} == 0, got ", z_out))
             << output_type;
    }

    if (z_in != 0 && num_input_elements / z_in != num_output_elements / z_out) {
      return op.emitOpError(
          "num_input_elements / z_in != num_output_elements / z_out");
    }
  }

  return mlir::success();
}

LogicalResult FullyConnectedOp::fold(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  assert(operands.size() == 3);

  // Folding is not implemented with any activation function or any weight
  // type besides the default.
  if (fused_activation_function() != "NONE") return failure();
  if (weights_format() != "DEFAULT") return failure();

  // The bias tensor is optional.
  const bool has_bias = !(!bias() || bias().getType().isa<NoneType>());

  // Get the tensors.
  DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
  if (!matchPattern(input(), m_Constant(&input_tensor)) ||
      !matchPattern(filter(), m_Constant(&weights_tensor)) ||
      (has_bias && !matchPattern(bias(), m_Constant(&bias_tensor)))) {
    return failure();
  }

  // Get the tensor types.
  const auto input_type = input_tensor.getType().cast<ShapedType>();
  const auto weights_type = weights_tensor.getType().cast<ShapedType>();
  const auto bias_type =
      has_bias ? bias_tensor.getType().cast<ShapedType>() : ShapedType{};

  const auto output_type = getType(0).cast<ShapedType>();

  // Folding is only implemented for float tensors.
  if (!input_type.getElementType().isF32() ||
      !weights_type.getElementType().isF32() ||
      !output_type.getElementType().isF32() ||
      (has_bias && !bias_type.getElementType().isF32())) {
    return failure();
  }

  // Folding is only implemented for static shapes.
  if (!input_type.hasStaticShape() || !weights_type.hasStaticShape() ||
      (has_bias && !bias_type.hasStaticShape())) {
    return failure();
  }

  // Folding is only implemented for 1D input, 2D weights and 1D bias.
  if (input_type.getShape().size() != 1 ||
      weights_type.getShape().size() != 2 ||
      (has_bias && bias_type.getShape().size() != 1)) {
    return failure();
  }

  // Get the sizes.
  const auto input_size = input_type.getNumElements();
  const auto output_size = output_type.getNumElements();

  // Get iterators to the tensors.
  const auto input_values_it = input_tensor.getValues<float>().begin();
  const auto weights_values_ptr = weights_tensor.getValues<float>().begin();
  auto weights_row_it = weights_values_ptr;
  // Without a bias this would ideally be nullptr, but the iterator types
  // don't match; input_values_it serves as a placeholder that is never
  // dereferenced in that case.
  auto bias_values_it =
      has_bias ? bias_tensor.getValues<float>().begin() : input_values_it;

  // Do the actual folding, one output at a time.
  std::vector<float> result_values;
  result_values.reserve(output_size);

  for (int i = 0; i < output_size; ++i) {
    // Dot product with Kahan/Neumaier summation to minimize numeric errors.
    float sum = has_bias ? *bias_values_it : 0.0f;
    float compensation = 0.0f;
    for (int j = 0; j < input_size; ++j) {
      const float addend = input_values_it[j] * weights_row_it[j];
      const float new_sum = sum + addend;
      // DO NOT enable -funsafe-math-optimizations here.
      // There is a test detecting unsafe optimizations.
      // Unsafe math optimizations can reorder float formulas, and set the
      // compensation to constant 0. The formula must be evaluated as written
      // for the algorithm to work.
      // (Note: -ffast-math is a superset of -funsafe-math-optimizations.)
      if (std::abs(sum) >= std::abs(addend)) {
        compensation += (sum - new_sum) + addend;
      } else {
        compensation += (addend - new_sum) + sum;
      }
      sum = new_sum;
    }
    result_values.push_back(sum + compensation);
    weights_row_it += input_size;
    bias_values_it++;
  }

  // Set the result tensor.
  const auto folded =
      DenseElementsAttr::get(output_type, ArrayRef<float>(result_values));
  results.assign({folded});

  return success();
}

void FullyConnectedOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveOptionalZeroBias<FullyConnectedOp>>(context);
}

int64_t FullyConnectedOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// Conv2DOp
//===----------------------------------------------------------------------===//

void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.add<RemoveOptionalZeroBias<Conv2DOp>>(context);
}

static LogicalResult ComputeConvWindowedOutputSize(
    int64_t input_size, int64_t filter_size, int64_t dilation_rate,
    int64_t stride, tensorflow::Padding padding, int64_t *output_size) {
  int64_t pad_low;
  int64_t pad_high;

  tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
      input_size, filter_size, dilation_rate, stride, padding, output_size,
      &pad_low, &pad_high);
  // Return failure if the output size could not be calculated.
  if (!status.ok()) return failure();
  return success();
}
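
// For reference, GetWindowedOutputSizeVerboseV2 implements the standard
// TensorFlow convolution arithmetic, with
// effective_filter_size = (filter_size - 1) * dilation_rate + 1:
//   VALID: output_size = ceil((input_size - effective_filter_size + 1) / stride)
//   SAME:  output_size = ceil(input_size / stride)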
1202 
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1203 LogicalResult Conv2DOp::inferReturnTypes(
1204     MLIRContext *, Optional<Location> location, ValueRange operands,
1205     DictionaryAttr attr, RegionRange,
1206     SmallVectorImpl<Type> &inferredReturnTypes) {
1207   Conv2DOpAdaptor op(operands, attr);
1208 
1209   const Value input = op.input();
1210   const Value filter = op.filter();
1211 
1212   const RankedTensorType input_ty =
1213       input.getType().dyn_cast_or_null<RankedTensorType>();
1214   const RankedTensorType filter_ty =
1215       filter.getType().dyn_cast_or_null<RankedTensorType>();
1216   // If indeed both input type & filter type are ranked type and have ranks.
1217   // We will need to check their ranks are valid.
1218   if ((input_ty && input_ty.hasRank() && input_ty.getRank() != 4) ||
1219       (filter_ty && filter_ty.hasRank() && filter_ty.getRank() != 4)) {
1220     return emitOptionalError(location, "Invalid ranks");
1221   }
1222 
1223   // If either input or filter is unranked, we will just return unranked output
1224   // shape.
1225   if (!input_ty || !filter_ty || !input_ty.hasRank() || !filter_ty.hasRank()) {
1226     Type result_type;
1227     result_type = UnrankedTensorType::get(
1228         input.getType().cast<ShapedType>().getElementType());
1229     inferredReturnTypes.assign({result_type});
1230     return success();
1231   }
1232 
1233   auto stride_h = op.stride_hAttr().getInt();
1234   auto stride_w = op.stride_wAttr().getInt();
1235   auto dilation_h = op.dilation_h_factorAttr().getInt();
1236   auto dilation_w = op.dilation_w_factorAttr().getInt();
1237 
1238   // We don't have EXPLICIT PADDING in TfLite.
1239   auto paddings = op.padding();
1240   tensorflow::Padding padding;
1241   auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1242   if (!padding_is_valid.ok()) {
1243     return emitOptionalError(location, "invalid padding format provided");
1244   }
1245 
1246   // Output always have rank 4. All dimensions are initialized to
1247   // dynamic size and can be partially inferred.
1248   // TFL's conv2d is always NHWC format & the filter is OHWI.
  SmallVector<int64_t, 4> return_shape(4, ShapedType::kDynamicSize);
  return_shape[0] = input_ty.getDimSize(0);
  return_shape[3] = filter_ty.getDimSize(0);

  // Spatial dimensions can be inferred only when both input and filter are
  // ranked because we need to get their spatial dimensions.

  // Height.
  if (!input_ty.isDynamicDim(1) && !filter_ty.isDynamicDim(1)) {
    int64_t output_height;
    if (failed(ComputeConvWindowedOutputSize(
            input_ty.getDimSize(1), filter_ty.getDimSize(1), dilation_h,
            stride_h, padding, &output_height))) {
      return failure();
    }
    return_shape[1] = output_height;
  }

  // Width.
  if (!input_ty.isDynamicDim(2) && !filter_ty.isDynamicDim(2)) {
    int64_t output_width;
    if (failed(ComputeConvWindowedOutputSize(
            input_ty.getDimSize(2), filter_ty.getDimSize(2), dilation_w,
            stride_w, padding, &output_width))) {
      return failure();
    }
    return_shape[2] = output_width;
  }

  auto result_type =
      mlir::RankedTensorType::get(return_shape, input_ty.getElementType());

  inferredReturnTypes.assign({result_type});
  return success();
}

bool Conv2DOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

int64_t Conv2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// DepthwiseConv2DOp
//===----------------------------------------------------------------------===//

void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.add<RemoveOptionalZeroBias<DepthwiseConv2DOp>>(context);
}

int64_t DepthwiseConv2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

static void BuildGatherOp(OpBuilder *builder, OperationState &result,
                          Value params, Value indices, IntegerAttr axis,
                          IntegerAttr batch_dims) {
  auto params_type = params.getType().cast<TensorType>();
  auto indices_type = indices.getType().cast<TensorType>();

  // If params/indices is unranked, then the output is unranked.
  if (!params_type.hasRank() || !indices_type.hasRank())
    return TFL::GatherOp::build(
        *builder, result, UnrankedTensorType::get(params_type.getElementType()),
        params, indices, axis, batch_dims);

  int64_t params_rank = params_type.getRank();
  int64_t indices_rank = indices_type.getRank();

  // params rank is guaranteed to be at least 1.
  // Produces an output tensor with shape:
  // params.shape[:axis] + indices.shape + params.shape[axis + 1:]
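  // For example (illustrative shapes, batch_dims = 0): params [4, 5, 6],
  // indices [2, 3], and axis = 1 produce the output shape
  // [4] + [2, 3] + [6] = [4, 2, 3, 6].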
  std::vector<int64_t> shape(params_type.getShape());
  int64_t axis_i = axis.getInt();

  // For negative axis values, wrap around params, e.g. axis = -1 => params[:-1]
  if (axis_i < 0) {
    axis_i += params_rank;
  }

  // params must be at least rank axis + 1
  if (params_rank < axis_i + 1) {
    emitError(result.location, "params must be at least rank axis + 1");
  }

  int64_t batch_dims_i = batch_dims.getInt();
  if (batch_dims_i < 0) {
    batch_dims_i += indices_rank;
  }

  if (batch_dims_i > axis_i) {
    emitError(result.location,
              "axis should be greater than or equal to batch_dims");
  }
  if (batch_dims_i >= params_rank || batch_dims_i > indices_rank) {
    emitError(result.location,
              "batch_dims must be smaller than params' rank and smaller than "
              "or equal to indices' rank");
  }
  for (int i = 0; i < batch_dims_i; ++i) {
    if (indices_type.getShape()[i] != params_type.getShape()[i]) {
      emitError(result.location,
                "batch dimensions of params must be equal to batch dimensions "
                "of indices");
    }
  }

  if ((indices_rank == 0) || (indices_rank == batch_dims_i)) {
    // Scalar indices (output is rank(params) - 1).
    // Erase shape[axis].
    shape.erase(shape.begin() + axis_i);
  } else if (indices_rank == 1) {
    // Vector indices (output is rank(params)).
    // Copy indices.shape into params.shape[axis].
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  } else {
    // Higher-rank indices (output is rank(params) + rank(indices) - 1).
    shape.resize(params_rank + indices_rank - 1 - batch_dims_i);
    // Copy params.shape[axis + 1:] into shape[axis + indices_rank:].
    std::copy(std::begin(params_type.getShape()) + axis_i + 1,
              std::end(params_type.getShape()),
              std::begin(shape) + axis_i + indices_rank - batch_dims_i);

    // Copy indices.shape into params.shape[axis].
    std::copy(std::begin(indices_type.getShape()) + batch_dims_i,
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  }

  TFL::GatherOp::build(
      *builder, result,
      RankedTensorType::get(shape, params_type.getElementType()), params,
      indices, axis, batch_dims);
}

//===----------------------------------------------------------------------===//
// ScatterNdOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult ScatterNdOp::verify() {
  ScatterNdOp op = *this;
  auto indices = op.indices();
  auto updates = op.updates();
  auto shape = op.shape();
  auto output = op.output();

  auto updates_type = updates.getType().cast<ShapedType>();
  auto indices_type = indices.getType().cast<ShapedType>();

  if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) {
    return success();
  }

  // Checks that `updates` is a tensor of shape
  // `indices.shape[:-1] + shape[indices.shape[-1]:]`, as described in the
  // ScatterNd op description.
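  // For example (illustrative shapes): indices [4, 1] and shape [8, 5] imply
  // updates of shape indices.shape[:-1] + shape[1:] = [4] + [5] = [4, 5].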

  auto outer_dims = indices_type.getRank() - 1;
  auto outermost_dim = indices_type.getDimSize(outer_dims);
  // Checks whether the first `outer_dims` dimensions of `indices` and
  // `updates` are equal.
  for (auto i = 0; i < outer_dims; i++) {
    if (indices_type.getDimSize(i) != updates_type.getDimSize(i)) {
      return op.emitOpError()
             << "indices.Dims(" << i << ") == " << indices_type.getDimSize(i)
             << ", but updates.Dims(" << i
             << ") == " << updates_type.getDimSize(i);
    }
  }

  auto output_type = output.getType().cast<ShapedType>();
  auto shape_type = shape.getType().cast<ShapedType>();
  if (shape_type.hasStaticShape()) {
    // Check the rank of `shape`.
    auto output_rank = outermost_dim + updates_type.getRank() - outer_dims;
    if (shape_type.getDimSize(0) != output_rank) {
      return op.emitOpError()
             << "shape must be a vector of length " << output_rank;
    }
    if (output_type.hasRank()) {
      if (output_type.getRank() != output_rank) {
        return op.emitOpError()
               << "output must have the same rank as the length of shape = "
               << output_rank;
      }
    }
  }

  DenseIntElementsAttr shape_value;
  if (matchPattern(shape, m_Constant(&shape_value))) {
    for (const auto shape_elem : shape_value) {
      if (shape_elem.getSExtValue() <= 0) {
        return op.emitOpError("all elements of shape must be > 0");
      }
    }

    // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
    // dimensions of `updates` and `shape` are equal.
    for (const auto &shape_it : llvm::enumerate(shape_value)) {
      int64_t i = shape_it.index();
      auto value = shape_it.value().getSExtValue();
      if (i >= outermost_dim) {
        auto corresponding_dim = i - outermost_dim + outer_dims;
        if (value != updates_type.getDimSize(corresponding_dim)) {
          return op.emitOpError()
                 << "updates.Dims(" << i
                 << ") == " << updates_type.getDimSize(corresponding_dim)
                 << ", but shape[" << i << "] == " << value;
        }
      }
    }

    // Checks if the output has the shape specified by `shape`.
    if (output_type.hasStaticShape()) {
      for (const auto &shape_it : llvm::enumerate(shape_value)) {
        int i = shape_it.index();
        auto value = shape_it.value().getSExtValue();
        if (output_type.getDimSize(i) != value) {
          return op.emitOpError()
                 << "output shape [" << output_type.getShape()
                 << "] must be equal to the value of shape " << shape_value;
        }
      }
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};

  // This function is performance critical for op fusion patterns, e.g.
  // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d.
  // So a few specializations are provided to evaluate the math operation
  // more efficiently.

  // Specialization for f32 type.
  if (getType().cast<ShapedType>().getElementType().isF32()) {
    return ConstFoldBinaryOp<FloatAttr, float>(
        getType(), operands[0], operands[1],
        [](float a, float b) { return a * b; });
  }

  // Specialization for bf16 type.
  if (getType().cast<ShapedType>().getElementType().isBF16()) {
    return ConstFoldBinaryOp<FloatAttr, Eigen::bfloat16>(
        getType(), operands[0], operands[1],
        [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; });
  }

  // Generic fallback with APFloat.
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a * b; },
      [](APInt a, APInt b) { return a * b; });
}

int64_t MulOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a / b; },
      [](APInt a, APInt b) { return a.sdiv(b); });
}

int64_t DivOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

// TODO(b/133486129): Implement shape inference for pack

mlir::LogicalResult PackOp::verify() {
  PackOp op = *this;
  // TODO(antiagainst): Implement other checks as in
  // tensorflow/lite/kernels/pack.cc

  if (op.getOperation()->getNumOperands() != op.values_count())
    return op.emitOpError("input count should match 'values_count' attribute");

  Value operand0 = op.getOperand(0);
  auto input_type = operand0.getType().cast<ShapedType>();

  // Check axis bounds.
  if (input_type.hasRank()) {
    int32_t axis_value = op.axis();
    if (axis_value < 0) axis_value += input_type.getRank() + 1;
    if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
      return op.emitOpError()
             << "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
             << "got rank = " << input_type.getRank()
             << ", and axis = " << op.axis();
  }

  // Make sure all inputs have the same shape and element type.
  // TODO(b/135032063): Simplify once fixed.
  for (Type operand_type : op.getOperandTypes()) {
    if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
      return op.emitOpError("operands should be of the same type. got ")
             << input_type << ", " << operand_type;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PReluOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult PReluOp::verify() {
  PReluOp op = *this;
  auto input_type = op.input().getType().cast<ShapedType>();
  auto alpha_type = op.alpha().getType().cast<ShapedType>();
  auto output_type = op.output().getType().cast<ShapedType>();

  if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
    if (input_type.getRank() != alpha_type.getRank() + 1) {
      return op.emitOpError("'alpha' should have one less rank than 'input'.");
    }

    // Check if alpha is broadcastable.
    for (int i = 0; i < alpha_type.getRank(); i++) {
      if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
          alpha_type.getDimSize(i) != 1) {
        return op.emitOpError(
            llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
      }
    }
  }

  if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
    if (input_type.getRank() != output_type.getRank()) {
      return op.emitOpError("'input' and 'output' should have the same rank.");
    }

    // Check if input and output shapes are the same.
    for (int i = 0; i < input_type.getRank(); i++) {
      if (input_type.getDimSize(i) != output_type.getDimSize(i)) {
        return op.emitOpError(
            "'input' and 'output' should have the same shape.");
      }
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

namespace {
// This pattern matches and merges a tfl.reshape under the following
// condition:
// * The input's defining op is another tfl.reshape.
// TODO(antiagainst): This pattern probably should be moved to the peephole
// category, after we have the infra for peephole passes.
struct RemoveAdjacentReshape : public RewritePattern {
  explicit RemoveAdjacentReshape(MLIRContext *context)
      : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}

  LogicalResult match(Operation *op) const override {
    auto thisOp = cast<ReshapeOp>(op);
    auto prevOp = thisOp.getOperand(0).getDefiningOp();
    return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    auto thisOp = cast<ReshapeOp>(op);
    auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());

    // Replace
    //   %1 = "tfl.reshape"(%0, %shape0)
    //   %2 = "tfl.reshape"(%1, %shape1)
    // With
    //   %2 = "tfl.reshape"(%0, %shape1)
    rewriter.replaceOpWithNewOp<ReshapeOp>(
        op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
  }
};

// The kernel expects a 1-D tensor for the shape operand if it is present. If
// all the dimensions except the last one are '1's, the shape tensor will be
// reshaped to a 1-D tensor.
// Note that this pattern doesn't check or change the content of the shape
// tensor.
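// For example (illustrative types): a shape constant of type
// tensor<1x1x4xi32> would be rewritten to an equivalent constant of type
// tensor<4xi32> holding the same four values.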
struct ConvertShapeTo1D : public OpRewritePattern<ReshapeOp> {
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp reshape,
                                PatternRewriter &rewriter) const override {
    if (!reshape.shape().hasOneUse()) return failure();

    DenseIntElementsAttr shape;
    if (!matchPattern(reshape.shape(), m_Constant(&shape))) {
      return failure();
    }
    // It is already a 1-D constant, no change.
    auto old_shape = shape.getType().getShape();
    if (old_shape.size() == 1) {
      return failure();
    }
    // Verify that all dimensions except the last one have length one.
    for (auto it = ++old_shape.rbegin(); it != old_shape.rend(); ++it) {
      if (*it != 1) {
        reshape->emitError(
            "Non-vector shape input is used, might cause runtime error");
        return failure();
      }
    }
    auto new_shape = shape.reshape(RankedTensorType::get(
        {*old_shape.rbegin()}, shape.getType().getElementType()));
    rewriter.replaceOpWithNewOp<TFL::ConstOp>(reshape.shape().getDefiningOp(),
                                              new_shape);
    return success();
  }
};

bool InputOutputHasSameShape(mlir::Type input_type, mlir::Type output_type) {
  auto input_shaped_type = input_type.dyn_cast_or_null<ShapedType>();
  if (!input_shaped_type || !input_shaped_type.hasStaticShape()) return false;

  auto output_shaped_type = output_type.dyn_cast_or_null<ShapedType>();
  if (!output_shaped_type || !output_shaped_type.hasStaticShape()) return false;

  return input_shaped_type == output_shaped_type;
}

}  // end anonymous namespace

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  // Remove an identity reshape when both the result and input shapes are
  // static and equal.
  auto result_type = getType().cast<ShapedType>();
  auto input_type = getOperand(0).getType().cast<ShapedType>();
  if (InputOutputHasSameShape(input_type, result_type)) return input();

  // Constant folding.
  if (auto dense_elements = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
    // If the result type isn't static, try to derive it from the second
    // (shape) operand.
    if (!result_type.hasStaticShape()) {
      auto shape_elements = operands[1].dyn_cast_or_null<DenseElementsAttr>();
      if (!shape_elements) return nullptr;

      SmallVector<int64_t, 4> shape_data;
      for (const auto &it : shape_elements.getValues<APInt>()) {
        shape_data.push_back(it.getSExtValue());
      }
      result_type =
          RankedTensorType::get(shape_data, input_type.getElementType());
    }
    return dense_elements.reshape(result_type);
  }

  return nullptr;
}

void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<RemoveAdjacentReshape, ConvertShapeTo1D>(context);
}

using ReshapeErrorHandler =
    llvm::function_ref<LogicalResult(const llvm::Twine &)>;

LogicalResult GetReshapeOutputType(Value input, Value shape,
                                   ReshapeErrorHandler error_handler,
                                   TensorType &output_ty) {
  auto input_ty = input.getType().cast<TensorType>();
  auto element_ty = input_ty.getElementType();
  output_ty = UnrankedTensorType::get(element_ty);

  auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
  if (!shape_ty) return success();
  if (shape_ty.getRank() != 1)
    return error_handler(llvm::formatv(
        "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));

  DenseIntElementsAttr shape_attr;
  if (!matchPattern(shape, m_Constant(&shape_attr))) {
    // If only the shape of `shape` is known, return a ranked but dynamic
    // output shape.
    if (shape_ty.hasStaticShape()) {
      llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
                                                  ShapedType::kDynamicSize);
      output_ty = RankedTensorType::get(dynamic_shape, element_ty);
    }
    return success();
  }

  // Detect if the reshape output shape is folded.
  bool shape_ty_zero_dim = false;
  int unknown_index = -1;
  // The product of the constant shape elements, excluding the unknown
  // dimension.
  int64_t shape_ty_size = 1;
  llvm::SmallVector<int64_t, 8> output_ty_shape;
  output_ty_shape.reserve(shape_attr.getNumElements());
  for (const auto &dim : llvm::enumerate(shape_attr.getValues<APInt>())) {
    const int64_t size = dim.value().getSExtValue();
    if (size == ShapedType::kDynamicSize) {
      if (unknown_index != -1)
        return error_handler(llvm::formatv(
            "requires 'shape' to have at most one dynamic dimension, but got "
            "multiple dynamic dimensions at indices {0} and {1}. You need to "
            "set the unspecified size(s) to avoid this problem, for example, "
            "setting the batch size in a Keras model or setting unspecified "
            "input size(s) to fixed ones.",
            unknown_index, dim.index()));

      unknown_index = dim.index();
    } else if (size == 0) {
      shape_ty_zero_dim = true;
    } else if (size > 0) {
      shape_ty_size *= size;
    } else {
      return error_handler(
          llvm::formatv("requires 'shape' to have dimensions greater than -1, "
                        "but got {0} at index {1}",
                        size, dim.index()));
    }
    output_ty_shape.push_back(size);
  }

  if (!input_ty.hasStaticShape()) {
    output_ty = RankedTensorType::get(output_ty_shape, element_ty);
    return success();
  }

  // Compute the value of the unknown dimension.
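  // For example (illustrative shapes): an input of type tensor<2x3x4xf32>
  // (24 elements) reshaped with shape = [4, -1] yields an unknown dimension
  // of 24 / 4 = 6, i.e. an output of type tensor<4x6xf32>.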
  if (unknown_index != -1) {
    // Compute the number of elements in the tensor shape.
    int64_t input_ty_size = 1;
    bool input_ty_zero_dim = false;
    for (const auto &dim : input_ty.getShape()) {
      if (dim > 0 || !shape_ty_zero_dim) {
        input_ty_size *= dim;
      } else {
        input_ty_zero_dim = true;
      }
    }

    const int64_t missing_dim = input_ty_size / shape_ty_size;
    if (!input_ty_zero_dim && shape_ty_size * missing_dim != input_ty_size)
      return error_handler(
          llvm::formatv("requires 'input' number of elements be a multiple of "
                        "{0}, but got {1}",
                        shape_ty_size, input_ty_size));

    // Set the unknown dimension such that the total number of elements
    // remains constant.
    output_ty_shape[unknown_index] = missing_dim;
  }

  output_ty = RankedTensorType::get(output_ty_shape, element_ty);

  return success();
}

mlir::LogicalResult ReshapeOp::verify() {
  ReshapeOp op = *this;
  auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
    return op.emitOpError() << message;
  };
  TensorType expected_ty;
  if (failed(GetReshapeOutputType(op.input(), op.shape(), error_handler,
                                  expected_ty)))
    return failure();

  auto output_ty = op.getType().dyn_cast<RankedTensorType>();
  if (!output_ty) return success();
  auto input_ty = op.input().getType().cast<TensorType>();
  if (output_ty.hasStaticShape() && input_ty.hasStaticShape()) {
    const int64_t output_ty_size = output_ty.getNumElements();
    const int64_t input_ty_size = input_ty.getNumElements();
    if (input_ty_size != output_ty_size)
      return op.emitOpError() << "requires 'output' number of elements to "
                                 "match 'input' number of elements, but got "
                              << output_ty_size << " and " << input_ty_size;
  }

  if (!TF::AreCastCompatible({output_ty, expected_ty}))
    return op.emitOpError()
           << "requires 'output' type " << output_ty
           << " to be cast compatible with expected type " << expected_ty;

  return success();
}

LogicalResult ReshapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attr, RegionRange,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  ReshapeOpAdaptor op(operands, attr);
  const Value input = op.input();
  const Value shape = op.shape();

  auto error_handler = [&](const llvm::Twine &message) -> LogicalResult {
    // A dummy error handler.
    // Errors when computing the output shape will be raised in the
    // ReshapeOp::verify call.
    return failure();
  };
  TensorType output_type;
  if (GetReshapeOutputType(input, shape, error_handler, output_type)
          .succeeded()) {
    inferredReturnTypes.assign({output_type});
    return success();
  }
  Type result_type;
  result_type = UnrankedTensorType::get(
      input.getType().cast<ShapedType>().getElementType());
  inferredReturnTypes.assign({result_type});
  return success();
}

bool ReshapeOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

// Removes a redundant unpack-pack pair.
// If an unpack op is followed by a pack op, we can remove the pack op; if the
// unpack op is consumed only by the pack op, it will be removed as well.
// An example illustration is:
//                  Unpack [5, 8, 9], axis = 1
//                /       \
//            value  ...  value [5, 9], 8 values in total
//              \           /
//                 pack,   axis = 1
//                   |
//               value   [5, 8, 9]
//
//   This can actually be simplified into just:
//
//           =>   Value [5, 8, 9]
// TODO(b/133341698): Move to tablegen when variadic is supported.
struct RemoveRedundantUnpackPack : public RewritePattern {
  explicit RemoveRedundantUnpackPack(MLIRContext *context)
      : RewritePattern(PackOp::getOperationName(), 2, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TFL::PackOp pack_op = cast<TFL::PackOp>(op);
    Operation *first_input = pack_op.getOperand(0).getDefiningOp();
    if (!first_input) return failure();
    auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
    if (!input_unpack_op) return failure();

    // The unpack & pack should have the same axis & num inputs/outputs.
    if (pack_op.axis() != input_unpack_op.axis() ||
        pack_op.values_count() != input_unpack_op.num())
      return failure();

    const int total_pack_inputs = pack_op.getNumOperands();
    const int num_results = input_unpack_op.getNumResults();
    if (total_pack_inputs != num_results) return failure();
    for (auto input_output :
         llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
      Value pack_input = std::get<0>(input_output);
      Value unpack_output = std::get<1>(input_output);
      // Make sure the ordering is the same for the pack op & unpack op.
      if (pack_input != unpack_output) return failure();
    }

    // Replace the pack's output with the unpack's input.
    rewriter.replaceOp(pack_op, input_unpack_op.getOperand());
    // At this point, we don't manually remove the redundant pack & unpack ops
    // (we cannot, actually), but trust the PatternRewriter to garbage collect
    // these two ops.
    return success();
  }
};

// Replace PackOp with a reshape when there is only one operand.
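// For example (illustrative IR, assuming a single-operand pack):
//   %0 = "tfl.pack"(%arg0) {axis = 0 : i32, values_count = 1 : i32}
//       : (tensor<2xf32>) -> tensor<1x2xf32>
// becomes
//   %shape = "tfl.pseudo_const"() {value = dense<[1, 2]> : tensor<2xi32>}
//       : () -> tensor<2xi32>
//   %0 = "tfl.reshape"(%arg0, %shape)
//       : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32>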
struct ReplacePackWithReshape : public RewritePattern {
  explicit ReplacePackWithReshape(MLIRContext *context)
      : RewritePattern(PackOp::getOperationName(), 2, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TFL::PackOp pack_op = cast<TFL::PackOp>(op);
    if (pack_op.getNumOperands() != 1) return failure();

    Location loc = pack_op.getLoc();
    auto output_type = pack_op.getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) return failure();

    // This works around an unnecessary i64 -> i32 cast.
    SmallVector<int32_t, 4> new_shape_array;
    for (auto size : output_type.getShape()) {
      new_shape_array.push_back(static_cast<int32_t>(size));
    }

    auto new_shape = rewriter.create<TFL::ConstOp>(
        loc, DenseIntElementsAttr::get(
                 RankedTensorType::get(new_shape_array.size(),
                                       rewriter.getIntegerType(32)),
                 new_shape_array));

    rewriter.replaceOpWithNewOp<ReshapeOp>(op, output_type,
                                           pack_op.getOperand(0), new_shape);
    return success();
  }
};

void PackOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<RemoveRedundantUnpackPack, ReplacePackWithReshape>(context);
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//

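// For example (illustrative shapes): for an input of type tensor<5x10xf32>,
// begin = [1, 2] and size = [3, -1] select a tensor<3x8xf32> slice; a size of
// -1 means "all remaining elements along that dimension".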
mlir::LogicalResult SliceOp::verify() {
  SliceOp op = *this;
  auto input_type = op.input().getType().cast<ShapedType>();
  auto begin_type = op.begin().getType().cast<ShapedType>();
  auto size_type = op.size().getType().cast<ShapedType>();
  if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
      size_type.hasStaticShape()) {
    if (input_type.getRank() != begin_type.getNumElements()) {
      return op.emitError(
          "begin tensor elements size is not equal to input tensor rank");
    }

    if (input_type.getRank() != size_type.getNumElements()) {
      return op.emitError(
          "size tensor elements size is not equal to input tensor rank");
    }
  }

  DenseIntElementsAttr begin;
  if (matchPattern(op.begin(), m_Constant(&begin))) {
    int axis = 0;
    for (const auto &begin_i : llvm::enumerate(begin)) {
      if (begin_i.value().getSExtValue() < 0) {
        return op.emitError(
            llvm::formatv("begin[{0}] cannot be negative", axis));
      }
      axis++;
    }
  }

  DenseIntElementsAttr size;
  if (matchPattern(op.size(), m_Constant(&size))) {
    int axis = 0;
    for (const auto &size_i : llvm::enumerate(size)) {
      if (size_i.value().getSExtValue() < -1) {
        return op.emitError(
            llvm::formatv("size[{0}] cannot be negative other than -1", axis));
      }
      axis++;
    }
  }

  if (begin && size && input_type.hasStaticShape()) {
    for (uint64_t i = 0, end = begin.getNumElements(); i < end; i++) {
      int begin_i = begin.getValues<APInt>()[i].getSExtValue();
      int size_i = size.getValues<APInt>()[i].getSExtValue();
      int dim_i = input_type.getShape()[i];
      if (begin_i > dim_i) {
        return op.emitOpError(llvm::formatv(
            "begin[{0}] cannot exceed dimension length: {1}", i, dim_i));
      }
      if (size_i >= 0 && begin_i + size_i > dim_i) {
        return op.emitError(llvm::formatv(
            "begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i,
            dim_i));
      }
    }
  }

  return success();
}

TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
                                             RankedTensorType value_type,
                                             Location loc, OpBuilder *builder) {
  if (input_op == nullptr) return nullptr;

  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(input_op, m_Constant(&attr))) {
    return nullptr;
  }

  auto value_shape_type = mlir::RankedTensorType::get(
      value_type.getShape(), builder->getIntegerType(32));

  SmallVector<int32_t, 4> value_i32;
  value_i32.reserve(value_type.getRank());
  for (const auto &size : attr) {
    value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
  }
  auto new_value_i32_attr =
      mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);

  return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
}

// This casts down the int64 `begin` and `size` values of a TFL slice op to
// int32. It requires `begin` and `size` to be constants.
struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
  using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
                                PatternRewriter &rewriter) const override {
    auto begin = slice_op.begin();
    auto size = slice_op.size();
    auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
    auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
    auto begin_op = begin.getDefiningOp();
    auto size_op = size.getDefiningOp();

    if (begin_op == nullptr && size_op == nullptr) return failure();

    if (begin_type == nullptr && size_type == nullptr) return failure();

    // Handle begin.
    if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
      auto new_begin = NarrowDownInt64InputValuesForOp(
          begin_op, begin_type, slice_op.getLoc(), &rewriter);
      if (new_begin != nullptr) {
        slice_op.setOperand(1, new_begin);
      }
    }

    // Handle size.
    if (size_op && size_type && size_type.getElementType().isInteger(64)) {
      auto new_size = NarrowDownInt64InputValuesForOp(
          size_op, size_type, slice_op.getLoc(), &rewriter);
      if (new_size != nullptr) {
        slice_op.setOperand(2, new_size);
      }
    }

    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<CastDonwInt64BeginEndToInt32>(context);
}

//===----------------------------------------------------------------------===//
// SqueezeOp
//===----------------------------------------------------------------------===//

OpFoldResult SqueezeOp::fold(ArrayRef<Attribute> operands) {
  auto input_ty = input().getType().dyn_cast<RankedTensorType>();
  auto result_ty = getType().dyn_cast<RankedTensorType>();

  if (!input_ty || !result_ty) return {};
  if (input_ty == result_ty) return input();
  return {};
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//

OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a - b; },
      [](APInt a, APInt b) { return a - b; });
}

int64_t SubOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// TopKOp
//===----------------------------------------------------------------------===//

static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input,
                        Value k) {
  // The output size is only known if k is a constant value. A negative
  // dimension is considered dynamic, so use -1 here if k is not constant.
  int const_k = -1;
  ElementsAttr cst;
  if (matchPattern(k, m_Constant(&cst)))
    // These casts should all be valid due to how Tensor constants are stored.
    // TODO(jpienaar): This should use a helper function.
    const_k = cst.getValues<IntegerAttr>()[0].getValue().getSExtValue();

  auto val_type = input.getType().cast<TensorType>();
  // If the value is unranked, then so are the results.
  if (!val_type.hasRank())
    return TFL::TopKV2Op::build(
        *builder, result, UnrankedTensorType::get(val_type.getElementType()),
        UnrankedTensorType::get(builder->getIntegerType(32)), input, k);

  // The resultant shape is value.shape[:-1] + [k].
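  // For example (illustrative shapes): input tensor<3x8xf32> with a constant
  // k = 2 yields values of type tensor<3x2xf32> and indices of type
  // tensor<3x2xi32>.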
  std::vector<int64_t> shape(val_type.getShape());
  shape[shape.size() - 1] = const_k;
  TFL::TopKV2Op::build(
      *builder, result, RankedTensorType::get(shape, val_type.getElementType()),
      RankedTensorType::get(shape, builder->getIntegerType(32)), input, k);
}

//===----------------------------------------------------------------------===//
// FakeQuantOp
//===----------------------------------------------------------------------===//

// Returns true if the op has a non-empty "minmax" attribute.
static inline bool HasValidMinMaxAttribute(Operation *op) {
  auto minmax = op->getAttrOfType<ArrayAttr>("minmax");
  return minmax && minmax.getValue().size() == 2;
}

namespace {

/// This pattern matches and removes a tfl.fake_quant if this op and all of
/// its users have the "minmax" attribute set.
struct DropFakeQuant : public RewritePattern {
  explicit DropFakeQuant(MLIRContext *context)
      : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}

  LogicalResult match(Operation *op) const override {
    // We only match the op with a valid "minmax" attribute.
    if (!HasValidMinMaxAttribute(op)) return failure();

    // If all the users of this op have valid "minmax" attributes, it is
    // matched and can be removed.
    auto fakeQuantOp = cast<FakeQuantOp>(op);
    for (auto *operand : fakeQuantOp.getResult().getUsers())
      if (!HasValidMinMaxAttribute(operand)) return failure();

    return success();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    // Replace the matched FakeQuantOp by its primary operand.
    rewriter.replaceOp(op, op->getOperand(0));
  }
};
}  // end anonymous namespace

void FakeQuantOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<DropFakeQuant>(context);
}

//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//

// TODO(b/133486129): Implement shape inference for unpack

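// The inferred per-result shape is the input shape with the unpack axis
// removed. For example (illustrative shapes): input tensor<2x3x4xf32> with
// num = 3 and axis = 1 yields three results of type tensor<2x4xf32>.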
LogicalResult UnpackOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> loc, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  UnpackOpAdaptor op(operands, attributes);
  // TODO(jpienaar): Refactor verify
  if (failed(op.verify(loc.has_value() ? *loc : UnknownLoc::get(context))))
    return failure();

  if (operands.size() != 1) {
    return emitOptionalError(loc, "input count should be equal to 1");
  }

  const int64_t num_value = op.numAttr().getInt();
  auto input_type = operands[0].getType().dyn_cast<ShapedType>();
  if (!input_type || !input_type.hasRank()) {
    // If input is unranked, then so is output.
    inferredReturnTypes.assign(
        num_value, UnrankedTensorType::get(input_type.getElementType()));
    return success();
  }

  if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) {
    return emitOptionalError(
        loc, "number of elements in input should be larger than 0");
  }

  const int64_t rank = input_type.getRank();
  if (rank <= 0) {
    return emitOptionalError(loc, "input should be of rank larger than 0");
  }

  int64_t axis_value = op.axisAttr().getInt();
  if (axis_value < 0) {
    axis_value += rank;
  }
  if (axis_value < 0 || axis_value >= rank) {
    return emitOptionalError(
        loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
        op.axisAttr().getInt(), ", and rank = ", rank);
  }

  if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
      input_type.getDimSize(axis_value) != num_value) {
    return emitOptionalError(loc, "output count should match 'num' attribute");
  }

  auto output_shape = llvm::to_vector<4>(input_type.getShape());
  output_shape.erase(output_shape.begin() + axis_value);

  auto output_type =
      RankedTensorType::get(output_shape, input_type.getElementType());
  inferredReturnTypes.assign(num_value, output_type);

  return success();
}

bool UnpackOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (auto pair : llvm::zip(lhs, rhs)) {
    if (failed(
            mlir::verifyCompatibleShape(std::get<0>(pair), std::get<1>(pair))))
      return false;
  }
  return true;
}

//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//

// Extracts and returns the signed integer constant in a 0-rank integer tensor
// or 1-element 1-rank integer tensor if 'value' is a constant.
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
  ElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) return {};
  if (attr.getNumElements() != 1) return {};
  IntegerAttr int_attr = *attr.getValues<IntegerAttr>().begin();
  return int_attr.getValue().getSExtValue();
}

// Returns a RankedTensorType which is similar to `input_type` but replaces the
// dimension size of `dim` with `dim_size`.  For example,
// `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
// `tensor<3x2xi32>`.
static RankedTensorType SubstituteRankedTensorTypeDimSize(
    RankedTensorType input_type, int64_t dim, int64_t dim_size) {
  auto shape = input_type.getShape().vec();
  shape[dim] = dim_size;
  return RankedTensorType::get(shape, input_type.getElementType());
}

// Verifies the output tensor types of SplitOp or SplitVOp.
template <typename ExpectedOutputTypeGetter>
static LogicalResult VerifySplitOpOutputTypes(
    Operation *op, int64_t num_splits,
    ExpectedOutputTypeGetter get_expected_output_type) {
  for (int64_t i = 0; i < num_splits; ++i) {
    auto expected_output_type = get_expected_output_type(i);
    Value output = op->getResult(i);
    if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
      return op->emitOpError()
             << "output #" << i << " should be " << expected_output_type
             << " instead got " << output.getType();
  }
  return success();
}

mlir::LogicalResult SplitOp::verify() {
  SplitOp op = *this;
  int64_t num_splits = op.num_splits();
  if (op.getNumResults() != num_splits)
    return op.emitOpError("output count should match 'num_splits' attribute");

  // If 'split_dim' is not a constant, there are no other checks.
  llvm::Optional<int64_t> split_dim_opt =
      ExtractConstantIntFromTensor(op.split_dim());
  if (!split_dim_opt) return success();

  // If 'input' is not a ranked tensor, there are no other checks.
  auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return success();

  int64_t split_dim = split_dim_opt.getValue();
  const int64_t rank = input_type.getRank();
  if (split_dim < 0) split_dim += rank;
  if (split_dim < 0 || split_dim >= rank)
    return op.emitOpError("'split_dim' should be in [-rank, rank)");

  // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
  // there are no other checks.
  const int64_t dim_size = input_type.getDimSize(split_dim);
  if (ShapedType::isDynamic(dim_size)) return success();

  if (dim_size % num_splits != 0)
    return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");

  // Verifies output tensor types.
  RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
      input_type, split_dim, dim_size / num_splits);
  return VerifySplitOpOutputTypes(
      op.getOperation(), num_splits,
      [expected_output_type](int64_t) { return expected_output_type; });
}

mlir::LogicalResult SplitVOp::verify() {
  SplitVOp op = *this;
  int64_t num_splits = op.num_splits();
  if (op.getNumResults() != num_splits)
    return op.emitOpError("output count should match 'num_splits' attribute");

  // If 'split_dim' is not a constant, there are no other checks.
  llvm::Optional<int64_t> split_dim_opt =
      ExtractConstantIntFromTensor(op.split_dim());
  if (!split_dim_opt) return success();

  // If 'input' is not a ranked tensor, there are no other checks.
  auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return success();

  int64_t split_dim = split_dim_opt.getValue();
  const int64_t rank = input_type.getRank();
  if (split_dim < 0) split_dim += rank;
  if (split_dim < 0 || split_dim >= rank)
    return op.emitOpError("'split_dim' should be in [-rank, rank)");

  // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
  // there are no other checks.
  const int64_t dim_size = input_type.getDimSize(split_dim);
  if (ShapedType::isDynamic(dim_size)) return success();

  // If 'size_splits' is not a constant, there are no other checks.
  ElementsAttr size_splits_attr;
  if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
    return success();

  if (size_splits_attr.getNumElements() != num_splits) {
    auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
    RankedTensorType expected_size_splits_type =
        RankedTensorType::get({num_splits}, size_splits_type.getElementType());
    return op.emitOpError("'size_splits' should be ")
           << expected_size_splits_type;
  }

  // Normalizes and verifies 'size_splits'.
  // Note: TensorFlow allows one -1 element in 'size_splits'.  The -1 element
  // means the rest of the dimension size.
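  // For example (illustrative values): with dim_size = 10 and
  // size_splits = [3, -1, 2], the -1 element is normalized to
  // 10 - (3 + 2) = 5, giving size_splits = [3, 5, 2].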
2437   llvm::SmallVector<int64_t, 4> size_splits;
2438   size_splits.reserve(num_splits);
2439 
2440   int64_t negative_size_split_loc = -1;
2441   int64_t total_size_splits = 0;
2442 
2443   for (int64_t i = 0; i < num_splits; ++i) {
2444     auto size_split_attr = size_splits_attr.getValues<IntegerAttr>()[i];
2445     int64_t size_split = size_split_attr.getValue().getSExtValue();
2446     size_splits.push_back(size_split);
2447     if (size_split >= 0) {
2448       total_size_splits += size_split;
2449       continue;
2450     }
2451     if (size_split < -1)
2452       return op.emitOpError(
2453           "elements of 'size_splits' should be greater than or equal to -1");
2454     if (negative_size_split_loc != -1)
2455       return op.emitOpError("'size_splits' can only have one -1");
2456     negative_size_split_loc = i;
2457   }
2458 
2459   if (negative_size_split_loc != -1) {
2460     if (total_size_splits > dim_size)
2461       return op.emitOpError(
2462           "sum of non-negative elements of 'size_splits' is greater than the "
2463           "dimension size of 'split_dim' axis");
2464     size_splits[negative_size_split_loc] = dim_size - total_size_splits;
2465     total_size_splits = dim_size;
2466   }
2467 
2468   if (total_size_splits != dim_size)
2469     return op.emitOpError(
2470         "sum of 'size_splits' should match the dimension size of 'split_dim' "
2471         "axis");
2472 
2473   // Verifies result tensor types.
2474   auto get_expected_output_type = [input_type, split_dim,
2475                                    &size_splits](int64_t i) {
2476     return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
2477                                              size_splits[i]);
2478   };
2479   return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
2480                                   get_expected_output_type);
2481 }
2482 
2483 //===----------------------------------------------------------------------===//
2484 // MeanOp
2485 //===----------------------------------------------------------------------===//
2486 
2487 // TODO(b/133854225): Implement shape inference to Mean
2488 
2489 //===----------------------------------------------------------------------===//
2490 // LSTMOp
2491 //===----------------------------------------------------------------------===//
2492 
verify()2493 mlir::LogicalResult LSTMOp::verify() {
2494   LSTMOp op = *this;
2495   auto operands = op.GetStatefulOperands();
2496   if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
2497     return op.emitOpError("LSTMOp expected to have two stateful operands");
2498   }
2499 
2500   const auto input_type = op.input().getType().cast<ShapedType>();
2501   // Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
2502   // doesn't have static shape, we skip the shape check below.
2503   if (!input_type.hasStaticShape()) return success();
2504   // The input should be at least 2D tensor since it will go through fully
2505   // connected layer.
2506   if (!input_type.hasRank() || input_type.getRank() < 2)
2507     return op.emitOpError(
2508         "the first input operand should have more than 2 dimensions.");
2509 
2510   const auto activation_state =
2511       op.input_activation_state().getType().cast<ShapedType>();
2512   const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
2513   const auto input_to_output_weights =
2514       op.input_to_output_weights().getType().cast<ShapedType>();
2515   const auto recurrent_to_output_weights =
2516       op.recurrent_to_output_weights().getType().cast<ShapedType>();
2517   if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
2518       input_to_output_weights.hasStaticShape() &&
2519       recurrent_to_output_weights.hasStaticShape()) {
2520     const int n_input = input_type.getDimSize(input_type.getRank() - 1);
2521     const int n_cell = input_to_output_weights.getDimSize(0);
2522     const int n_output = recurrent_to_output_weights.getDimSize(1);
2523     const int output_state_size = activation_state.getNumElements();
2524     const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
2525                                                   : input_type.getDimSize(1);
2526     const int state_size = cell_state.getNumElements();
2527 
2528     // Check that the dimensions of the inputs match.
2529     if ((output_state_size != n_batch * n_output) ||
2530         (state_size != n_batch * n_cell) ||
2531         (input_to_output_weights.getDimSize(1) != n_input) ||
2532         (recurrent_to_output_weights.getRank() != 2) ||
2533         (recurrent_to_output_weights.getDimSize(0) != n_cell) ||
2534         (input_to_output_weights.getRank() != 2)) {
2535       return op.emitOpError("inputs don't match with the dimensions.");
2536     }
2537 
2538     const bool is_layer_norm_lstm =
2539         !op.forget_layer_norm_coefficients().getType().isa<NoneType>();
2540     if (is_layer_norm_lstm) {
2541       const auto forget_layer_norm_coefficients =
2542           op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
2543       // If this LSTM has layer normalization, the input value
2544       // "forget_layer_norm_coefficients" should be a 1-D tensor of size n_cell.
2545       if (!forget_layer_norm_coefficients.hasRank() ||
2546           forget_layer_norm_coefficients.getRank() != 1 ||
2547           forget_layer_norm_coefficients.getDimSize(0) != n_cell)
2548         return op.emitOpError(
2549             "coefficient inputs are not 1-D or "
2550             "don't match the first dimension of input operand "
2551             "`input_to_output_weights`.");
2552     }
2553   }
2554 
2555   return success();
2556 }
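
// Worked example of the consistency checks above, with illustrative values:
// for an input of shape [n_batch=4, n_input=8], input_to_output_weights of
// shape [n_cell=16, 8], and recurrent_to_output_weights of shape
// [16, n_output=12], the verifier expects
//   activation_state: 4 * 12 = 48 elements, and
//   cell_state:       4 * 16 = 64 elements;
// any mismatch yields the "inputs don't match the expected dimensions" error.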
2557 
2558 namespace {
2559 
2560 // Replaces the optional bias operands with a "none" type value if the bias
2561 // values are constant zeros.
2562 struct RemoveLSTMOpZeroBias : public OpRewritePattern<LSTMOp> {
2563   using OpRewritePattern<LSTMOp>::OpRewritePattern;
2564 
matchAndRewritemlir::TFL::__anon216e30ea1711::RemoveLSTMOpZeroBias2565   LogicalResult matchAndRewrite(LSTMOp op,
2566                                 PatternRewriter &rewriter) const override {
2567     if (EqualsZero(op.input_gate_bias())) {
2568       auto none_value = rewriter.create<TFL::NoValueOp>(
2569           rewriter.getUnknownLoc(), rewriter.getNoneType(),
2570           rewriter.getUnitAttr());
2571       op.input_gate_biasMutable().assign(none_value);
2572     }
2573 
2574     if (EqualsZero(op.projection_bias())) {
2575       auto none_value = rewriter.create<TFL::NoValueOp>(
2576           rewriter.getUnknownLoc(), rewriter.getNoneType(),
2577           rewriter.getUnitAttr());
2578       op.projection_biasMutable().assign(none_value);
2579     }
2580 
2581     return success();
2582   }
2583 };
2584 
2585 }  // namespace
2586 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2587 void LSTMOp::getCanonicalizationPatterns(RewritePatternSet &results,
2588                                          MLIRContext *context) {
2589   results.add<RemoveLSTMOpZeroBias>(context);
2590 }
2591 
2592 //===----------------------------------------------------------------------===//
2593 // UnidirectionalSequenceLSTMOp
2594 //===----------------------------------------------------------------------===//
2595 
verify()2596 mlir::LogicalResult UnidirectionalSequenceLSTMOp::verify() {
2597   UnidirectionalSequenceLSTMOp op = *this;
2598   auto operands = op.GetStatefulOperands();
2599   if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
2600     return success();
2601   }
2602   return op.emitError(
2603       "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
2604 }
2605 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr attr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)2606 LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes(
2607     MLIRContext *, Optional<Location>, ValueRange operands, DictionaryAttr attr,
2608     RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
2609   Value input = operands[0];
2610   auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
2611 
2612   Value output_state = operands[18];
2613   auto output_state_type =
2614       output_state.getType().dyn_cast_or_null<RankedTensorType>();
2615 
2616   if (input_type && input_type.hasRank() && input_type.getRank() != 3) {
2617     return failure();
2618   }
2619 
2620   if (output_state_type && output_state_type.hasRank() &&
2621       output_state_type.getRank() != 2) {
2622     return failure();
2623   }
2624 
2625   if (!input_type || !input_type.hasRank() || !output_state_type ||
2626       !output_state_type.hasRank()) {
2627     // We cannot infer the output shape without knowing both the input shape
2628     // and the output state shape, so set the output type to an unranked
2629     // tensor of the input element type.
2630     Type result_type = UnrankedTensorType::get(
2631         input.getType().cast<ShapedType>().getElementType());
2632     inferredReturnTypes.assign({result_type});
2633     return success();
2634   }
2635 
2636   // Default to non-time_major.
2637   Optional<mlir::NamedAttribute> time_major_attr = attr.getNamed("time_major");
2638   bool time_majored =
2639       time_major_attr ? time_major_attr->getValue().cast<BoolAttr>().getValue()
2640                       : false;
2641 
2642   int batch =
2643       time_majored ? input_type.getDimSize(1) : input_type.getDimSize(0);
2644   int time = time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1);
2645   int n_output = output_state_type.getDimSize(1);
2646 
2647   // Build the output shape.
2648   SmallVector<int64_t, 3> output_shape;
2649   if (time_majored) {
2650     output_shape = {time, batch, n_output};
2651   } else {
2652     output_shape = {batch, time, n_output};
2653   }
2654   auto result_type =
2655       mlir::RankedTensorType::get(output_shape, input_type.getElementType());
2656 
2657   inferredReturnTypes.assign({result_type});
2658   return success();
2659 }
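
// Sketch of the inference above, with illustrative types: a non-time_major
// input of type tensor<4x10x8xf32> (batch=4, time=10) and an output state of
// type tensor<4x12xf32> (n_output=12) infer a result of type
// tensor<4x10x12xf32>; with time_major = true, the same operands infer
// tensor<10x4x12xf32>.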
2660 
isCompatibleReturnTypes(TypeRange lhs,TypeRange rhs)2661 bool UnidirectionalSequenceLSTMOp::isCompatibleReturnTypes(TypeRange lhs,
2662                                                            TypeRange rhs) {
2663   if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
2664   if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
2665   return true;
2666 }
2667 
2668 //===----------------------------------------------------------------------===//
2669 // BidirectionalSequenceLSTMOp
2670 //===----------------------------------------------------------------------===//
2671 
verify()2672 mlir::LogicalResult BidirectionalSequenceLSTMOp::verify() {
2673   BidirectionalSequenceLSTMOp op = *this;
2674   auto operands = op.GetStatefulOperands();
2675   if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
2676       operands[2] == 37 && operands[3] == 38) {
2677     return success();
2678   }
2679   return op.emitError(
2680       "BidirectionalSequenceLSTMOp expected to have four stateful operands");
2681 }
2682 
2683 //===----------------------------------------------------------------------===//
2684 // UnidirectionalSequenceRNNOp
2685 //===----------------------------------------------------------------------===//
2686 
verify()2687 mlir::LogicalResult UnidirectionalSequenceRNNOp::verify() {
2688   UnidirectionalSequenceRNNOp op = *this;
2689   auto operands = op.GetStatefulOperands();
2690   if (operands.size() == 1 && operands[0] == 4) {
2691     return success();
2692   }
2693   return op.emitError(
2694       "UnidirectionalSequenceRNNOp expected to have one stateful operand");
2695 }
2696 
2697 //===----------------------------------------------------------------------===//
2698 // SvdfOp
2699 //===----------------------------------------------------------------------===//
2700 
verify()2701 mlir::LogicalResult SVDFOp::verify() {
2702   SVDFOp op = *this;
2703   auto operands = op.GetStatefulOperands();
2704   if (operands.size() == 1 && operands[0] == 4) {
2705     return success();
2706   }
2707   return op.emitError("SvdfOp expected to have one stateful operand");
2708 }
2709 
2710 //===----------------------------------------------------------------------===//
2711 // AbsOp
2712 //===----------------------------------------------------------------------===//
2713 
fold(ArrayRef<Attribute> operands)2714 OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
2715   Type result_type = getType();
2716   // Constant folding is implemented only for f32 tensors.
2717   if (!IsF32ShapedType(result_type)) return nullptr;
2718 
2719   auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); };
2720   return ConstFoldUnaryOp(result_type, operands[0], compute);
2721 }
2722 
2723 //===----------------------------------------------------------------------===//
2724 // NegOp
2725 //===----------------------------------------------------------------------===//
2726 
fold(ArrayRef<Attribute> operands)2727 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
2728   Type result_type = getType();
2729   // Constant folding is implemented only for f32 tensors.
2730   if (!IsF32ShapedType(result_type)) return nullptr;
2731 
2732   auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
2733   return ConstFoldUnaryOp(result_type, operands[0], compute);
2734 }
2735 
2736 //===----------------------------------------------------------------------===//
2737 // SinOp
2738 //===----------------------------------------------------------------------===//
2739 
fold(ArrayRef<Attribute> operands)2740 OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
2741   Type result_type = getType();
2742   // Constant folding is implemented only for f32 tensors.
2743   if (!IsF32ShapedType(result_type)) return nullptr;
2744 
2745   auto compute = [](APFloat value) -> APFloat {
2746     float f = value.convertToFloat();
2747     float result = std::sin(f);
2748     return APFloat(result);
2749   };
2750   return ConstFoldUnaryOp(result_type, operands[0], compute);
2751 }
2752 
2753 //===----------------------------------------------------------------------===//
2754 // CosOp
2755 //===----------------------------------------------------------------------===//
2756 
fold(ArrayRef<Attribute> operands)2757 OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
2758   Type result_type = getType();
2759   // Constant folding is implemented only for f32 tensors.
2760   if (!IsF32ShapedType(result_type)) return nullptr;
2761 
2762   auto compute = [](APFloat value) -> APFloat {
2763     float f = value.convertToFloat();
2764     float result = std::cos(f);
2765     return APFloat(result);
2766   };
2767   return ConstFoldUnaryOp(result_type, operands[0], compute);
2768 }
2769 
2770 //===----------------------------------------------------------------------===//
2771 // LogOp
2772 //===----------------------------------------------------------------------===//
2773 
fold(ArrayRef<Attribute> operands)2774 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
2775   Type result_type = getType();
2776   // Constant folding is implemented only for f32 tensors.
2777   if (!IsF32ShapedType(result_type)) return nullptr;
2778 
2779   auto compute = [](APFloat value) -> APFloat {
2780     float f = value.convertToFloat();
2781     float result = std::log(f);
2782     return APFloat(result);
2783   };
2784   return ConstFoldUnaryOp(result_type, operands[0], compute);
2785 }
2786 
2787 //===----------------------------------------------------------------------===//
2788 // ShapeOp
2789 //===----------------------------------------------------------------------===//
2790 
fold(ArrayRef<Attribute> operands)2791 OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
2792   auto input_type = input().getType().cast<ShapedType>();
2793   if (!input_type.hasStaticShape()) return nullptr;
2794 
2795   ArrayRef<int64_t> shape = input_type.getShape();
2796   auto result_type = getType().cast<ShapedType>();
2797   if (result_type.getElementType().isInteger(64)) {
2798     return DenseElementsAttr::get<int64_t>(result_type, shape);
2799   } else if (result_type.getElementType().isInteger(32)) {
2800     SmallVector<int32_t, 4> shape_i32;
2801     shape_i32.reserve(shape.size());
2802     for (int64_t dim : shape) {
2803       shape_i32.push_back(dim);
2804     }
2805     return DenseElementsAttr::get<int32_t>(result_type, shape_i32);
2806   }
2807   return nullptr;
2808 }
2809 
2810 //===----------------------------------------------------------------------===//
2811 // SqrtOp
2812 //===----------------------------------------------------------------------===//
2813 
fold(ArrayRef<Attribute> operands)2814 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
2815   Type result_type = getType();
2816   // Constant folding is implemented only for f32 tensors.
2817   if (!IsF32ShapedType(result_type)) return nullptr;
2818 
2819   auto compute = [](APFloat value) -> APFloat {
2820     float f = value.convertToFloat();
2821     float result = std::sqrt(f);
2822     return APFloat(result);
2823   };
2824   return ConstFoldUnaryOp(result_type, operands[0], compute);
2825 }
2826 
2827 //===----------------------------------------------------------------------===//
2828 // RsqrtOp
2829 //===----------------------------------------------------------------------===//
2830 
fold(ArrayRef<Attribute> operands)2831 OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
2832   Type result_type = getType();
2833   // Constant folding is implemented only for f32/bf16 tensors.
2834   if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
2835     return nullptr;
2836 
2837   auto compute = [](APFloat value) -> APFloat {
2838     bool loseInfo;
2839     const llvm::fltSemantics &original_float_semantics = value.getSemantics();
2840     value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
2841                   &loseInfo);
2842     float f = value.convertToFloat();
2843     APFloat result(1.f / std::sqrt(f));
2844     result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
2845                    &loseInfo);
2846     return result;
2847   };
2848   return ConstFoldUnaryOp(result_type, operands[0], compute);
2849 }
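
// For example, folding rsqrt of the bf16 value 0.25 widens it to IEEE single,
// computes 1.f / std::sqrt(0.25f) = 2.0f, and rounds the result back to bf16,
// so the folded constant keeps the operand's original float semantics.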
2850 
2851 //===----------------------------------------------------------------------===//
2852 // SquareOp
2853 //===----------------------------------------------------------------------===//
2854 
fold(ArrayRef<Attribute> operands)2855 OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
2856   Type result_type = getType();
2857   // Constant folding is implemented only for f32 tensors.
2858   if (!IsF32ShapedType(result_type)) return nullptr;
2859 
2860   auto compute = [](APFloat value) -> APFloat { return value * value; };
2861   return ConstFoldUnaryOp(result_type, operands[0], compute);
2862 }
2863 
2864 //===----------------------------------------------------------------------===//
2865 // RankOp
2866 //===----------------------------------------------------------------------===//
2867 
fold(ArrayRef<Attribute> operands)2868 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2869   assert(operands.size() == 1);
2870   auto result_type = getType().cast<ShapedType>();
2871   if (auto elements_attr = operands[0].dyn_cast_or_null<ElementsAttr>()) {
2872     auto rank = static_cast<int32_t>(elements_attr.getType().getRank());
2873     return DenseElementsAttr::get(result_type, {rank});
2874   }
2875 
2876   // Also fold if `input` has a known rank.
2877   auto input_type = input().getType().cast<ShapedType>();
2878   // Do not fold if rank is zero because the TFLite converter doesn't
2879   // distinguish between unranked input and scalar input due to b/138865275.
2880   // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
2881   // predicate and fold the op when rank is zero.
2882   if (input_type.hasRank() && input_type.getRank() != 0) {
2883     auto rank = static_cast<int32_t>(input_type.getRank());
2884     return DenseElementsAttr::get(result_type, {rank});
2885   }
2886 
2887   return nullptr;
2888 }
2889 
2890 //===----------------------------------------------------------------------===//
2891 // ConstOp
2892 //===----------------------------------------------------------------------===//
2893 
fold(ArrayRef<Attribute> operands)2894 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
2895   assert(operands.empty() && "constant has no operands");
2896   // Return the held attribute value.
2897   return value();
2898 }
2899 
isCompatibleReturnTypes(TypeRange l,TypeRange r)2900 bool ConstOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2901   // Allow the result type to differ from the inferred type, requiring only
2902   // that the two be compatible: the inferred type comes from the element
2903   // attribute's type, while the op may have been constructed from a TF const
2904   // op or be in a partial state of shape refinement. The op will be refined
2905   // during shape inference, with casts inserted as needed to satisfy the type
2906   // constraints of consumers.
2907   return succeeded(verifyCompatibleShapes(l, r));
2908 }
2909 
2910 namespace {
2911 struct FoldPseudoConstOp : public OpRewritePattern<ConstOp> {
2912   using OpRewritePattern<ConstOp>::OpRewritePattern;
2913 
matchAndRewritemlir::TFL::__anon216e30ea2011::FoldPseudoConstOp2914   LogicalResult matchAndRewrite(ConstOp const_op,
2915                                 PatternRewriter &rewriter) const override {
2916     if (arith::ConstantOp::isBuildableWith(const_op.value(),
2917                                            const_op.getType())) {
2918       rewriter.replaceOpWithNewOp<arith::ConstantOp>(const_op,
2919                                                      const_op.value());
2920       return success();
2921     } else if (TFL::NoValueOp::isBuildableWith(const_op.value(),
2922                                                const_op.getType())) {
2923       rewriter.replaceOpWithNewOp<NoValueOp>(const_op, rewriter.getNoneType(),
2924                                              const_op.value().cast<UnitAttr>());
2925       return success();
2926     }
2927     return failure();
2928   }
2929 };
2930 
2931 }  // namespace
2932 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2933 void ConstOp::getCanonicalizationPatterns(RewritePatternSet &results,
2934                                           MLIRContext *context) {
2935   results.add<FoldPseudoConstOp>(context);
2936 }
2937 
2938 //===----------------------------------------------------------------------===//
2939 // CastOp
2940 //===----------------------------------------------------------------------===//
2941 
fold(ArrayRef<Attribute> operands)2942 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
2943   assert(operands.size() == 1);
2944   if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
2945     return input();
2946   }
2947 
2948   // For now, only casts between integer types are supported.
2949   auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2950   if (!elements_attr) {
2951     return nullptr;
2952   }
2953 
2954   auto result_element_type =
2955       getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
2956   auto operand_element_type = input()
2957                                   .getType()
2958                                   .cast<ShapedType>()
2959                                   .getElementType()
2960                                   .dyn_cast<IntegerType>();
2961   // Return nullptr if either element type is not an integer type.
2962   if (!result_element_type || !operand_element_type) {
2963     return nullptr;
2964   }
2965 
2966   const bool is_unsigned = operand_element_type.isUnsigned();
2967   const bool involves_bool = operand_element_type.getWidth() == 1 ||
2968                              result_element_type.getWidth() == 1;
2969   const int output_bitwidth = result_element_type.getWidth();
2970   // The integer cast op behaves like a C integer cast. Depending on the
2971   // operand type's signedness, we determine whether or not sign extension
2972   // is needed.
2973   auto cast = [&](APInt value) {
2974     if (involves_bool) {
2975       // Handle boolean inputs and outputs explicitly, since they don't follow
2976       // plain extension or truncation: a true input should be cast to 1, not
2977       // to the -1 that sign extension would produce for signed outputs, and
2978       // any non-zero input should be cast to true, whereas truncating an even
2979       // number to one bit would yield `false`.
2980       return APInt(result_element_type.getWidth(), value != 0);
2981     }
2982     return is_unsigned ? value.zextOrTrunc(output_bitwidth)
2983                        : value.sextOrTrunc(output_bitwidth);
2984   };
2985 
2986   return elements_attr.mapValues(result_element_type, cast);
2987 }
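
// A standalone sketch of the C-style cast semantics above (the helper below
// is purely illustrative and is not referenced elsewhere in this file):
static llvm::APInt CastIntSketch(llvm::APInt value, bool is_unsigned,
                                 bool involves_bool,
                                 unsigned output_bitwidth) {
  // e.g. casting the i32 value 2 to i1 yields true, although plain truncation
  // to one bit would yield false.
  if (involves_bool) return llvm::APInt(output_bitwidth, value != 0);
  // e.g. casting the i8 value -1 to i32 yields -1 when signed (sign
  // extension) and 255 when unsigned (zero extension).
  return is_unsigned ? value.zextOrTrunc(output_bitwidth)
                     : value.sextOrTrunc(output_bitwidth);
}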
2988 
2989 //===----------------------------------------------------------------------===//
2990 // SelectV2Op
2991 //===----------------------------------------------------------------------===//
2992 
BuildSelectV2Op(Builder * builder,OperationState & result,Value cond,Value x,Value y)2993 static void BuildSelectV2Op(Builder *builder, OperationState &result,
2994                             Value cond, Value x, Value y) {
2995   auto operand_type =
2996       OpTrait::util::getBroadcastedType(x.getType(), y.getType());
2997 
2998   if (!operand_type)
2999     emitError(result.location) << "non-broadcastable operands: " << x.getType()
3000                                << " and " << y.getType();
3001 
3002   bool has_static_cond_shape = false;
3003   bool has_static_operand_shape = false;
3004   ArrayRef<int64_t> cond_shape;
3005   ArrayRef<int64_t> operand_shape;
3006 
3007   if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
3008     if (shaped_type.hasStaticShape()) {
3009       has_static_cond_shape = true;
3010       cond_shape = shaped_type.getShape();
3011     }
3012   }
3013   if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
3014     if (shaped_type.hasStaticShape()) {
3015       has_static_operand_shape = true;
3016       operand_shape = shaped_type.getShape();
3017     }
3018   }
3019 
3020   SmallVector<int64_t, 4> broadcastedShape;
3021   if (has_static_cond_shape && has_static_operand_shape &&
3022       !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
3023                                           broadcastedShape)) {
3024     emitError(result.location) << "non-broadcastable operands: " << operand_type
3025                                << " and " << cond.getType();
3026   }
3027 
3028   result.addOperands({cond, x, y});
3029 
3030   auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
3031   if (has_static_cond_shape && has_static_operand_shape) {
3032     result.types.push_back(
3033         RankedTensorType::get(broadcastedShape, elementType));
3034   } else {
3035     result.types.push_back(UnrankedTensorType::get(elementType));
3036   }
3037 }
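
// Sketch of the shape broadcast used above (the helper below is purely
// illustrative): a `cond` of shape [2, 1] and operands of broadcasted shape
// [2, 3] produce a result of shape [2, 3].
static llvm::Optional<llvm::SmallVector<int64_t, 4>> BroadcastShapesSketch(
    llvm::ArrayRef<int64_t> cond_shape,
    llvm::ArrayRef<int64_t> operand_shape) {
  llvm::SmallVector<int64_t, 4> result;
  if (!OpTrait::util::getBroadcastedShape(cond_shape, operand_shape, result))
    return llvm::None;  // Shapes are not broadcast-compatible.
  return result;
}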
3038 
3039 //===----------------------------------------------------------------------===//
3040 // RangeOp
3041 //===----------------------------------------------------------------------===//
3042 
3043 namespace {
3044 
3045 // Computes the length of a range (1-D) tensor given `start`, `limit`, `delta`.
3046 // Template parameter `FloatOrInt` must be a standard C integer or
3047 // floating-point type.
3048 template <typename FloatOrInt>
GetLengthOfRange(FloatOrInt start,FloatOrInt limit,FloatOrInt delta)3049 int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
3050   // Refer to the implementation in
3051   // tensorflow/lite/kernels/range.cc.
3052   return std::is_integral<FloatOrInt>::value
3053              ? ((std::abs(limit - start) + std::abs(delta) - 1) /
3054                 std::abs(delta))
3055              : std::ceil(std::abs((limit - start) / delta));
3056 }
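
// For example, GetLengthOfRange(0, 10, 3) = (|10 - 0| + |3| - 1) / |3| = 4
// (elements 0, 3, 6, 9), and in the floating-point case
// GetLengthOfRange(0.0f, 1.0f, 0.3f) = ceil(|(1.0f - 0.0f) / 0.3f|) = 4
// (elements 0.0, 0.3, 0.6, 0.9).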
3057 
3058 // Builds a constant range tensor of `result_elem_type` elements.
3059 // Template parameter `FloatOrIntAttr` must be mlir::IntegerAttr or
3060 // mlir::FloatAttr.
3061 template <typename FloatOrIntAttr>
BuildConstRangeTensor(Type result_elem_type,int num_elements,FloatOrIntAttr start_attr,FloatOrIntAttr delta_attr)3062 DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
3063                                         FloatOrIntAttr start_attr,
3064                                         FloatOrIntAttr delta_attr) {
3065   using ValueType = typename FloatOrIntAttr::ValueType;  // APInt or APFloat
3066   ValueType start = start_attr.getValue();
3067   ValueType delta = delta_attr.getValue();
3068 
3069   SmallVector<ValueType, 16> new_values;
3070   new_values.reserve(num_elements);
3071   ValueType new_value = start;
3072   for (int i = 0; i < num_elements; ++i) {
3073     new_values.push_back(new_value);
3074     new_value = new_value + delta;
3075   }
3076   // Result is always a 1-D tensor.
3077   auto new_result_type =
3078       RankedTensorType::get({num_elements}, result_elem_type);
3079   return DenseElementsAttr::get(new_result_type, new_values);
3080 }
3081 }  // namespace
3082 
fold(ArrayRef<Attribute> operands)3083 OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
3084   assert(operands.size() == 3);
3085   auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
3086   auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
3087   auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
3088   if (start_tensor && limit_tensor && delta_tensor) {
3089     // Operands should all be scalars
3090     assert(start_tensor.getType().getRank() == 0 &&
3091            limit_tensor.getType().getRank() == 0 &&
3092            delta_tensor.getType().getRank() == 0);
3093     Type elem_type = getType().cast<ShapedType>().getElementType();
3094     if (elem_type.isSignlessInteger()) {
3095       auto start_attr = start_tensor.getValues<IntegerAttr>()[0];
3096       auto limit_attr = limit_tensor.getValues<IntegerAttr>()[0];
3097       auto delta_attr = delta_tensor.getValues<IntegerAttr>()[0];
3098       const int num_elements = GetLengthOfRange(
3099           start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt());
3100       return BuildConstRangeTensor(elem_type, num_elements, start_attr,
3101                                    delta_attr);
3102     } else if (elem_type.isa<FloatType>()) {
3103       auto start_attr = start_tensor.getValues<FloatAttr>()[0];
3104       auto limit_attr = limit_tensor.getValues<FloatAttr>()[0];
3105       auto delta_attr = delta_tensor.getValues<FloatAttr>()[0];
3106       const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
3107                                                 limit_attr.getValueAsDouble(),
3108                                                 delta_attr.getValueAsDouble());
3109       return BuildConstRangeTensor(elem_type, num_elements, start_attr,
3110                                    delta_attr);
3111     }
3112   }
3113 
3114   return nullptr;
3115 }
3116 
3117 //===----------------------------------------------------------------------===//
3118 // TransposeConvOp
3119 //===----------------------------------------------------------------------===//
3120 
verify()3121 mlir::LogicalResult TransposeConvOp::verify() {
3122   TransposeConvOp op = *this;
3123   ShapedType output_type = op.output().getType().cast<ShapedType>();
3124   ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
3125   if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
3126     if (output_type.getRank() != output_shape_type.getDimSize(0)) {
3127       return op.emitOpError(llvm::formatv(
3128           "expect output type has rank = {0}, got output type {1}",
3129           output_shape_type.getDimSize(0), output_type));
3130     }
3131   }
3132 
3133   DenseIntElementsAttr output_shape_elements;
3134   if (!matchPattern(op.output_shape(), m_Constant(&output_shape_elements))) {
3135     return success();
3136   }
3137 
3138   llvm::SmallVector<int64_t, 4> output_shape;
3139   output_shape.reserve(output_shape_elements.getNumElements());
3140   for (auto dim : output_shape_elements.getValues<int>()) {
3141     output_shape.push_back(dim);
3142   }
3143 
3144   auto expected_output_type =
3145       RankedTensorType::get(output_shape, output_type.getElementType());
3146   if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
3147     return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
3148                                         expected_output_type, output_type));
3149   }
3150 
3151   return success();
3152 }
3153 
GetArithmeticCount(Operation * op)3154 int64_t TransposeConvOp::GetArithmeticCount(Operation *op) {
3155   int64_t count = -1;
3156   auto transpose_conv = llvm::dyn_cast<TransposeConvOp>(op);
3157   auto input_type = transpose_conv.input()
3158                         .getType()
3159                         .dyn_cast_or_null<mlir::RankedTensorType>();
3160   auto weight_type = transpose_conv.weights()
3161                          .getType()
3162                          .dyn_cast_or_null<mlir::RankedTensorType>();
3163   if (input_type && weight_type && input_type.hasStaticShape() &&
3164       weight_type.hasStaticShape()) {
3165     // Compute op count from the seven nested loops of
3166     // tflite::reference_ops::TransposeConv():
3167     count = 2 * input_type.getNumElements() * weight_type.getDimSize(0) *
3168             weight_type.getDimSize(1) * weight_type.getDimSize(2);
3169   }
3170 
3171   return count;
3172 }
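
// For example, an input of type tensor<1x8x8x3xf32> (192 elements) and
// weights of type tensor<16x3x3x3xf32> give
//   count = 2 * 192 * 16 * 3 * 3 = 55296 ops
// (each multiply-accumulate counted as two operations).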
3173 
3174 //===----------------------------------------------------------------------===//
3175 // StridedSliceOp
3176 //===----------------------------------------------------------------------===//
3177 
verify()3178 LogicalResult StridedSliceOp::verify() {
3179   StridedSliceOp op = *this;
3180   auto ranked_input_type = op.input().getType().dyn_cast<RankedTensorType>();
3181 
3182   // If input is unranked, there is nothing else to be verified.
3183   if (!ranked_input_type) return success();
3184   int num_input_dims = ranked_input_type.getRank();
3185 
3186   if (auto begin_type = op.begin().getType().dyn_cast<RankedTensorType>()) {
3187     if (begin_type.getRank() != 1) return failure();
3188     if (begin_type.getDimSize(0) > num_input_dims) return failure();
3189   }
3190 
3191   if (auto end_type = op.end().getType().dyn_cast<RankedTensorType>()) {
3192     if (end_type.getRank() != 1) return failure();
3193     if (end_type.getDimSize(0) > num_input_dims) return failure();
3194   }
3195 
3196   if (auto strides_type = op.strides().getType().dyn_cast<RankedTensorType>()) {
3197     if (strides_type.getRank() != 1) return failure();
3198     if (strides_type.getDimSize(0) > num_input_dims) return failure();
3199   }
3200 
3201   // The kernel will reshape the input tensor with the new axes; it only
3202   // supports the reshaped tensor up to 5-D.
3203   uint32_t ellipsis_mask = op.ellipsis_mask();
3204   uint32_t new_axis_mask = op.new_axis_mask();
3205   int num_added_axis = 0;
3206   for (int i = 0; i < 8; ++i) {
3207     if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) {
3208       num_added_axis++;
3209     }
3210   }
3211   if (num_input_dims + num_added_axis > 5) return failure();
3212   return success();
3213 }
3214 
fold(ArrayRef<Attribute> operands)3215 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
3216   // Currently, folding is only supported when all masks are 0.
3217   if (begin_mask() != 0 || end_mask() != 0 || ellipsis_mask() != 0 ||
3218       new_axis_mask() != 0 || shrink_axis_mask() != 0)
3219     return {};
3220 
3221   auto input_type = input().getType().dyn_cast_or_null<RankedTensorType>();
3222   if (!input_type || !input_type.hasStaticShape()) return {};
3223 
3224   // Begin has to be all 0s.
3225   DenseIntElementsAttr begin_dense_elem_attr;
3226   if (!matchPattern(begin(), m_Constant(&begin_dense_elem_attr))) {
3227     return {};
3228   }
3229   for (auto begin_ele : begin_dense_elem_attr) {
3230     if (begin_ele.getSExtValue() != 0) {
3231       return {};
3232     }
3233   }
3234 
3235   // Strides have to be all 1s.
3236   DenseIntElementsAttr strides_dense_elem_attr;
3237   if (!matchPattern(strides(), m_Constant(&strides_dense_elem_attr))) {
3238     return {};
3239   }
3240   for (auto stride_ele : strides_dense_elem_attr) {
3241     if (stride_ele.getSExtValue() != 1) {
3242       return {};
3243     }
3244   }
3245   // End has to match the input shape.
3246   DenseIntElementsAttr end_dense_elem_attr;
3247   if (!matchPattern(end(), m_Constant(&end_dense_elem_attr))) {
3248     return {};
3249   }
3250   int i = 0;
3251   for (auto end_ele : end_dense_elem_attr) {
3252     if (end_ele.getSExtValue() != input_type.getDimSize(i)) {
3253       return {};
3254     }
3255     ++i;
3256   }
3257 
3258   return input();
3259 }
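
// In other words, the op folds to its input only for an identity slice: for
// an input of type tensor<2x3xf32>, begin = [0, 0], end = [2, 3],
// strides = [1, 1], and all masks 0, the fold returns the input unchanged;
// any other combination is left to the runtime kernel.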
3260 
3261 //===----------------------------------------------------------------------===//
3262 // TransposeOp
3263 //===----------------------------------------------------------------------===//
3264 
3265 namespace {
3266 
3267 // Computes the permutation of a constant `input_tensor` according to `perm`.
3268 // The function recursively traverses the dimensions of the output tensor in
3269 // a row-major order and writes the value in the output tensor into
3270 // `new_values`.
ComputePermutation(ElementsAttr input_tensor,ArrayRef<int32_t> perm,ArrayRef<int64_t> output_shape,int num_dimensions,int output_axis,std::vector<uint64_t> * input_indices,std::vector<Attribute> * new_values)3271 void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
3272                         ArrayRef<int64_t> output_shape, int num_dimensions,
3273                         int output_axis, std::vector<uint64_t> *input_indices,
3274                         std::vector<Attribute> *new_values) {
3275   // Refer to the implementation of `Transpose` function in
3276   // tensorflow/lite/kernels/internal/reference/reference_ops.h
3277   assert(output_axis < num_dimensions);
3278   const int input_axis = perm[output_axis];
3279   for (int i = 0; i < output_shape[output_axis]; ++i) {
3280     // Update the input indices on `input_axis`.
3281     input_indices->at(input_axis) = i;
3282     // Write the value from `input_tensor` if it is the last axis or
3283     // recurse into the next axis.
3284     const bool is_last_axis = output_axis == num_dimensions - 1;
3285     if (is_last_axis) {
3286       new_values->push_back(
3287           input_tensor.getValues<Attribute>()[*input_indices]);
3288     } else {
3289       ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3290                          output_axis + 1, input_indices, new_values);
3291     }
3292   }
3293 }
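
// Worked example: transposing the 2x3 tensor {{0, 1, 2}, {3, 4, 5}} with
// perm = {1, 0} traverses the 3x2 output in row-major order while swapping
// the input indices, producing {{0, 3}, {1, 4}, {2, 5}}.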
3294 
3295 }  // namespace
3296 
fold(ArrayRef<Attribute> operands)3297 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
3298   assert(operands.size() == 2);
3299   auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
3300   auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
3301   if (!input_tensor || !perm_tensor) return nullptr;
3302 
3303   // Do not try to fold elements attr of a quant type because
3304   // DenseElementsAttr does not support it.
3305   if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
3306     return nullptr;
3307 
3308   assert(perm_tensor.getType().getRank() == 1);
3309   const int num_dimensions = input_tensor.getType().getRank();
3310   assert(perm_tensor.getType().getNumElements() == num_dimensions);
3311 
3312   ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
3313   auto output_type = getType().cast<ShapedType>();
3314 
3315   SmallVector<int32_t, 4> perm;
3316   SmallVector<int64_t, 4> output_shape;
3317   for (int i = 0; i < num_dimensions; ++i) {
3318     perm.push_back(perm_tensor.getValues<IntegerAttr>()[i].getInt());
3319     output_shape.push_back(input_shape[perm[i]]);
3320 
3321     // Check that the derived output shape matches the static shape.
3322     assert(!output_type.hasStaticShape() ||
3323            output_type.getShape()[i] == output_shape[i]);
3324   }
3325 
3326   std::vector<Attribute> new_values;
3327   new_values.reserve(input_tensor.getType().getNumElements());
3328   std::vector<uint64_t> input_indices(num_dimensions);
3329   ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3330                      /*output_axis=*/0, &input_indices, &new_values);
3331   auto result_type =
3332       RankedTensorType::get(output_shape, output_type.getElementType());
3333   return DenseElementsAttr::get(result_type, new_values);
3334 }
3335 
verify()3336 mlir::LogicalResult TransposeOp::verify() {
3337   TransposeOp op = *this;
3338   auto input_type = op.input().getType().cast<ShapedType>();
3339   auto perm_type = op.perm().getType().cast<ShapedType>();
3340   auto output_type = op.output().getType().cast<ShapedType>();
3341   if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
3342     if (perm_type.getNumElements() != input_type.getRank()) {
3343       return op.emitOpError(
3344           "perm tensor elements size is not equal to input tensor rank");
3345     }
3346   }
3347 
3348   DenseIntElementsAttr perm;
3349   if (!matchPattern(op.perm(), m_Constant(&perm))) {
3350     return success();
3351   }
3352 
3353   int index = 0;
3354   llvm::SmallVector<int64_t, 4> axes;
3355   for (const auto &axis_int : perm.getValues<APInt>()) {
3356     const int64_t axis = axis_int.getSExtValue();
3357     if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
3358       return op.emitOpError(
3359           llvm::formatv("perm[{0}] must be in [0, rank)", index));
3360     }
3361     if (std::count(axes.begin(), axes.end(), axis) > 0) {
3362       return op.emitOpError(
3363           llvm::formatv("perm[{0}] cannot have duplicated axis", index));
3364     }
3365     axes.push_back(axis);
3366     index++;
3367   }
3368 
3369   if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
3370     llvm::SmallVector<int64_t, 4> transposed_shape;
3371     for (int64_t axis : axes) {
3372       transposed_shape.push_back(input_type.getDimSize(axis));
3373     }
3374     auto expected_output_type =
3375         RankedTensorType::get(transposed_shape, input_type.getElementType());
3376     if (failed(
3377             mlir::verifyCompatibleShape(output_type, expected_output_type))) {
3378       return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
3379                                           expected_output_type, output_type));
3380     }
3381   }
3382 
3383   // Verify the quantized axis if the type is UniformQuantizedPerAxisType. The
3384   // other verifications, which make sure the input and output have the same
3385   // quantization type, scale, and zero point, are performed by the
3386   // SameOperandsAndResultsScale trait.
3387   auto in_per_axis_qtype =
3388       QuantizedType::getQuantizedElementType(input_type)
3389           .dyn_cast_or_null<quant::UniformQuantizedPerAxisType>();
3390   auto out_per_axis_qtype =
3391       QuantizedType::getQuantizedElementType(output_type)
3392           .dyn_cast_or_null<quant::UniformQuantizedPerAxisType>();
3393   if (in_per_axis_qtype && out_per_axis_qtype) {
3394     if (out_per_axis_qtype.getQuantizedDimension() < axes.size() &&
3395         axes[out_per_axis_qtype.getQuantizedDimension()] !=
3396             in_per_axis_qtype.getQuantizedDimension()) {
3397       return op.emitOpError(
3398           "has mismatched quantized axes of input and output");
3399     }
3400   }
3401 
3402   return success();
3403 }
3404 
BuildTransposeOp(OpBuilder * builder,OperationState & result,Value input,Value perm)3405 static void BuildTransposeOp(OpBuilder *builder, OperationState &result,
3406                              Value input, Value perm) {
3407   // Output size is only known if input is ranked and perm is a constant.
3408   auto input_type = input.getType().cast<TensorType>();
3409   DenseIntElementsAttr perm_const;
3410   if (!input_type.hasRank() || !matchPattern(perm, m_Constant(&perm_const)) ||
3411       perm_const.empty()) {
3412     TFL::TransposeOp::build(
3413         *builder, result, UnrankedTensorType::get(input_type.getElementType()),
3414         input, perm);
3415     return;
3416   }
3417 
3418   const auto perm_value_it = perm_const.value_begin<APInt>();
3419 
3420   const ArrayRef<int64_t> input_shape = input_type.getShape();
3421   SmallVector<int64_t, 4> output_shape(input_shape.size());
3422 
3423   for (int i = 0; i < output_shape.size(); ++i) {
3424     const APInt perm_val = perm_value_it[i];
3425     output_shape[i] = input_shape[perm_val.getSExtValue()];
3426   }
3427 
3428   auto element_type = input_type.getElementType();
3429   // For a UniformQuantizedPerAxisType element type, the quantized dimension
3430   // should be remapped according to the transpose.
3431   auto per_axis_qtype =
3432       QuantizedType::getQuantizedElementType(input_type)
3433           .dyn_cast_or_null<quant::UniformQuantizedPerAxisType>();
3434   if (per_axis_qtype) {
3435     int32_t quantized_dimension = per_axis_qtype.getQuantizedDimension();
3436     for (int i = 0; i < output_shape.size(); ++i) {
3437       const APInt perm_val = perm_value_it[i];
3438       if (perm_val.getSExtValue() == quantized_dimension) {
3439         quantized_dimension = i;
3440         break;
3441       }
3442     }
3443     element_type = quant::UniformQuantizedPerAxisType::get(
3444         per_axis_qtype.getFlags(), per_axis_qtype.getStorageType(),
3445         per_axis_qtype.getExpressedType(), per_axis_qtype.getScales(),
3446         per_axis_qtype.getZeroPoints(), quantized_dimension,
3447         per_axis_qtype.getStorageTypeMin(), per_axis_qtype.getStorageTypeMax());
3448   }
3449 
3450   TFL::TransposeOp::build(*builder, result,
3451                           RankedTensorType::get(output_shape, element_type),
3452                           input, perm);
3453 }
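
// For example, transposing a per-axis quantized tensor of shape [2, 3, 4]
// with quantized_dimension = 2 by perm = {2, 0, 1} moves that axis to the
// front, so the built result type has shape [4, 2, 3] with
// quantized_dimension = 0; the scales and zero points are carried over
// unchanged.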
3454 
3455 //===----------------------------------------------------------------------===//
3456 // IfOp
3457 //===----------------------------------------------------------------------===//
3458 
3459 /// Given the region at `index`, or the parent operation if `index` is None,
3460 /// return the successor regions. These are the regions that may be selected
3461 /// during the flow of control. `operands` is a set of optional attributes that
3462 /// correspond to a constant value for each operand, or null if that operand is
3463 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)3464 void IfOp::getSuccessorRegions(Optional<unsigned> index,
3465                                ArrayRef<Attribute> operands,
3466                                SmallVectorImpl<RegionSuccessor> &regions) {
3467   // The `then` and the `else` region branch back to the parent operation.
3468   if (index.has_value()) {
3469     regions.push_back(RegionSuccessor(getResults()));
3470     return;
3471   }
3472 
3473   // Don't consider the else region if it is empty.
3474   Region *else_reg = &else_region();
3475   if (else_reg->empty()) else_reg = nullptr;
3476 
3477   // Otherwise, the successor is dependent on the condition.
3478   bool condition;
3479   if (auto cond_attr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
3480     condition = cond_attr.getValue().isOneValue();
3481   } else {
3482     // If the condition isn't constant, both regions may be executed.
3483     regions.push_back(RegionSuccessor(&then_region()));
3484     // If the else region does not exist, it is not a viable successor.
3485     if (else_reg) regions.push_back(RegionSuccessor(else_reg));
3486     return;
3487   }
3488 
3489   // Add the successor regions using the condition.
3490   regions.push_back(RegionSuccessor(condition ? &then_region() : else_reg));
3491 }
3492 
3493 //===----------------------------------------------------------------------===//
3494 // PolyCallOp
3495 //===----------------------------------------------------------------------===//
3496 
3497 namespace {
3498 // Canonicalize converted TF ops into PolymorphicCall op so different
3499 // representations are preserved.
3500 struct PolyCallResultOperandsMatchAndImplicitCapture
3501     : public OpRewritePattern<PolyCallOp> {
3502   using OpRewritePattern<PolyCallOp>::OpRewritePattern;
3503 
matchAndRewritemlir::TFL::__anon216e30ea2411::PolyCallResultOperandsMatchAndImplicitCapture3504   LogicalResult matchAndRewrite(PolyCallOp while_op,
3505                                 PatternRewriter &rewriter) const override {
3506     // TODO: Implement this canonicalization; until then, report failure so
3507     // the rewrite driver does not loop on an unchanged op.
    return failure();
3508   }
3509 };
3510 
3511 }  // namespace
3512 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3513 void PolyCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
3514                                              MLIRContext *context) {
3515   results.add<PolyCallResultOperandsMatchAndImplicitCapture>(context);
3516 }
3517 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)3518 void PolyCallOp::getSuccessorRegions(
3519     Optional<unsigned> index, ArrayRef<Attribute> operands,
3520     SmallVectorImpl<RegionSuccessor> &regions) {
3521   // Defaults to first region for TFLite execution.
3522 }
3523 
3524 //===----------------------------------------------------------------------===//
3525 // WhileOp
3526 //===----------------------------------------------------------------------===//
3527 
verify()3528 LogicalResult WhileOp::verify() {
3529   WhileOp op = *this;
3530   if (op.getNumOperands() != op.getNumResults())
3531     return op.emitOpError(llvm::formatv(
3532         "number of operands does not match number of results ({0} != {1})",
3533         op.getNumOperands(), op.getNumResults()));
3534   if (op.cond().front().getNumArguments() !=
3535       op.body().front().getNumArguments())
3536     return op.emitOpError(llvm::formatv(
3537         "number of arguments in condition function does not match number of "
3538         "arguments in body function ({0} != {1})",
3539         op.cond().front().getNumArguments(),
3540         op.body().front().getNumArguments()));
3541   // Verify shapes are compatible.
3542   for (auto it : llvm::zip(op.cond().front().getArgumentTypes(),
3543                            op.body().front().getArgumentTypes())) {
3544     if (failed(mlir::verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
3545       return op->emitOpError(llvm::formatv(
3546           "condition function's argument type does not match body "
3547           "function's argument type ({0} != {1})",
3548           std::get<0>(it), std::get<1>(it)));
3549   }
3550 
3551   return success();
3552 }
3553 
3554 namespace {
3555 // Canonicalize the While op so that results and operands match, and external
3556 // values are captured implicitly rather than passed via block args.
3557 struct WhileResultOperandsMatchAndImplicitCapture
3558     : public OpRewritePattern<WhileOp> {
3559   using OpRewritePattern<WhileOp>::OpRewritePattern;
3560 
matchAndRewritemlir::TFL::__anon216e30ea2511::WhileResultOperandsMatchAndImplicitCapture3561   LogicalResult matchAndRewrite(WhileOp while_op,
3562                                 PatternRewriter &rewriter) const override {
3563     // Replace values simply passed through the body with extern values
3564     // (in both body and condition regions as well as while result). The
3565     // block arguments of body and while match and so the corresponding cond
3566     // argument can be easily found.
3567     bool unchanged = true;
3568     auto &body_block = while_op.body().front();
3569     auto &cond_block = while_op.cond().front();
3570     auto &yield = *body_block.getTerminator();
3571     for (auto ba : body_block.getArguments()) {
3572       int arg_no = ba.getArgNumber();
3573       // Skip removing resources that are not read-only variables.
3574       if (getElementTypeOrSelf(ba.getType()).isa<TF::ResourceType>()) {
3575         bool has_read_only_variables = true;
3576         for (auto user : ba.getUsers()) {
3577           // Terminator ops, e.g. the tfl::yield op, should be ignored, since
3578           // the argument can be used to yield the `body` function result,
3579           // which says nothing about whether the given argument is a
3580           // read-only variable or not.
3581           if (user->hasTrait<OpTrait::IsTerminator>()) continue;
3582           if (!llvm::isa<mlir::TF::ReadVariableOp>(user)) {
3583             has_read_only_variables = false;
3584             break;
3585           }
3586         }
3587         if (!has_read_only_variables) continue;
3588       }
3589       if (ba == yield.getOperand(arg_no)) {
3590         unchanged = false;
3591         auto value = while_op.getOperand(arg_no);
3592         ba.replaceAllUsesWith(value);
3593         cond_block.getArgument(arg_no).replaceAllUsesWith(value);
3594 
3595         // This could be relaxed and casts inserted.
3596         if (while_op.getResult(arg_no).getType() == value.getType())
3597           while_op.getResult(arg_no).replaceAllUsesWith(value);
3598       }
3599     }
3600 
3601     // The While op's operand and result types need to match.
3602     SmallVector<Value, 4> new_operands;
3603     SmallVector<Value, 4> new_body_yield;
3604     SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
3605     llvm::SmallVector<Type, 4> types;
3606     new_operands.reserve(while_op.getNumOperands());
3607     new_body_yield.reserve(while_op.getNumOperands());
3608     types.reserve(while_op.getNumOperands());
3609 
3610     // Remove block arguments not used in either cond or body. This leaves the
3611     // block arguments of body and cond matching still.
3612     int arg_index = 0;
3613     for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
3614          ++while_index) {
3615       auto value = while_op.getOperand(while_index);
3616       if (body_block.getArgument(arg_index).use_empty() &&
3617           cond_block.getArgument(arg_index).use_empty() &&
3618           // Note: since we are not erasing results, need to use while_index
3619           // to check if the corresponding result is unused.
3620           while_op.getResult(while_index).use_empty()) {
3621         unchanged = false;
3622         body_block.eraseArgument(arg_index);
3623         cond_block.eraseArgument(arg_index);
3624 
3625         // Mark operand for removal.
3626         removed_operand[while_index] = true;
3627       } else {
3628         new_operands.push_back(value);
3629         new_body_yield.push_back(yield.getOperand(while_index));
3630         auto type = while_op.getResult(while_index).getType();
3631         types.push_back(type);
3632         ++arg_index;
3633       }
3634     }
3635 
3636     // Done if no values removed from blocks and operands & results match.
3637     if (unchanged) return failure();
3638 
3639     // Replace with new While with matching operands and results.
3640     Operation *op = while_op.getOperation();
3641     Operation *new_op = rewriter.insert(
3642         Operation::create(op->getLoc(), op->getName(), types, new_operands,
3643                           op->getAttrs(), {}, /*numRegions=*/2));
3644 
3645     for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
3646     int new_index = 0;
3647     for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
3648       if (removed_operand[op_index]) continue;
3649       op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
3650       ++new_index;
3651     }
3652     rewriter.eraseOp(op);
3653 
3654     Block &new_body_block = cast<WhileOp>(new_op).body().front();
3655     rewriter.setInsertionPointToEnd(&new_body_block);
3656     rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
3657                                          new_body_yield);
3658 
3659     return success();
3660   }
3661 };
3662 
3663 }  // namespace
3664 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3665 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3666                                           MLIRContext *context) {
3667   results.add<WhileResultOperandsMatchAndImplicitCapture>(context);
3668 }
3669 
getLoopBody()3670 Region &WhileOp::getLoopBody() { return body(); }
3671 
isDefinedOutsideOfLoop(Value value)3672 bool WhileOp::isDefinedOutsideOfLoop(Value value) {
3673   // TODO(jpienaar): This is overly conservative and initially disables
3674   // anything other than constant hoisting.
3675   return false;
3676 }
3677 
3678 //===----------------------------------------------------------------------===//
3679 // LogisticOp
3680 //===----------------------------------------------------------------------===//
3681 
GetArithmeticCount(Operation * op)3682 int64_t LogisticOp::GetArithmeticCount(Operation *op) {
3683   int64_t count;
3684   // As a very rough ballpark, the cost of evaluating a math function
3685   // such as tanh or logistic is about 32 multiplications and about as many
3686   // additions/subtractions (just a power-of-two order-of-magnitude estimate
3687   // based on the actual implementations used in the runtime code).
3688   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3689     return 64 * count;
3690 
3691   return -1;
3692 }
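
// For example, a logistic op whose first output has 1000 elements is costed
// at 64 * 1000 = 64000 operations (roughly 32 multiplications plus 32
// additions/subtractions per element).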
3693 
3694 //===----------------------------------------------------------------------===//
3695 // LogSoftmaxOp
3696 //===----------------------------------------------------------------------===//
3697 
GetArithmeticCount(Operation * op)3698 int64_t LogSoftmaxOp::GetArithmeticCount(Operation *op) {
3699   int64_t count;
3700   // As a very rough ballpark, the cost of evaluating a math function
3701   // such as tanh or logistic is about 32 multiplications and about as many
3702   // additions/subtractions (just a power-of-two order-of-magnitude estimate
3703   // based on the actual implementations used in the runtime code).
3704   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3705     return 64 * count;
3706 
3707   return -1;
3708 }
3709 
3710 //===----------------------------------------------------------------------===//
3711 // SoftmaxOp
3712 //===----------------------------------------------------------------------===//
3713 
GetArithmeticCount(Operation * op)3714 int64_t SoftmaxOp::GetArithmeticCount(Operation *op) {
3715   int64_t count;
3716   // As a very rough ballpark, the cost of evaluating a math function
3717   // such as tanh or logistic is about 32 multiplications and about as many
3718   // additions/subtractions (just a power-of-two order-of-magnitude estimate
3719   // based on the actual implementations used in the runtime code).
3720   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3721     return 64 * count;
3722 
3723   return -1;
3724 }
3725 
3726 //===----------------------------------------------------------------------===//
3727 // TanhOp
3728 //===----------------------------------------------------------------------===//
3729 
GetArithmeticCount(Operation * op)3730 int64_t TanhOp::GetArithmeticCount(Operation *op) {
3731   int64_t count;
3732   // As a very rough ballpark, the cost of evaluating a math function
3733   // such as tanh or logistic is about 32 multiplications and about as many
3734   // additions/subtractions (just a power-of-two order-of-magnitude estimate
3735   // based on the actual implementations used in the runtime code).
3736   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3737     return 64 * count;
3738 
3739   return -1;
3740 }

//===----------------------------------------------------------------------===//
// AddNOp
//===----------------------------------------------------------------------===//

int64_t AddNOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
    // AddN cost is roughly the same cost as N-1 Adds.
    const int64_t num_adds = op->getNumOperands() - 1;
    return num_adds * count;
  }

  return -1;
}
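
// Worked example (illustrative, not from the source): an AddN op with four
// operands writing a 256-element output is counted as (4 - 1) * 256 = 768
// additions, since summing N tensors takes N-1 elementwise adds.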

//===----------------------------------------------------------------------===//
// AveragePool2DOp
//===----------------------------------------------------------------------===//

int64_t AveragePool2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
    auto avg_pool = llvm::dyn_cast<AveragePool2DOp>(op);
    return avg_pool.filter_height() * avg_pool.filter_width() * count;
  }

  return -1;
}
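
// Worked example (illustrative, not from the source): a 3x3 average pool
// producing a 1x16x16x8 output (count = 2048 elements) is estimated at
// 3 * 3 * 2048 = 18432 ops, one accumulation per filter tap per output
// element. MaxPool2D below uses the same filter-area heuristic.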

//===----------------------------------------------------------------------===//
// MaxPool2DOp
//===----------------------------------------------------------------------===//

int64_t MaxPool2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
    auto max_pool = llvm::dyn_cast<MaxPool2DOp>(op);
    return max_pool.filter_height() * max_pool.filter_width() * count;
  }

  return -1;
}

//===----------------------------------------------------------------------===//
// L2NormalizationOp
//===----------------------------------------------------------------------===//

int64_t L2NormalizationOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  // Computing the squared L2 norm is N multiply-adds, so 2N ops; the single
  // inverse-sqrt is negligible; then we multiply each value by the resulting
  // multiplier, an extra N ops. In total, 3N ops.
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
    return 3 * count;
  }

  return -1;
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
  if (InputOutputHasSameShape(input().getType(), output().getType()))
    return input();

  return {};
}
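
// Example (illustrative; SSA names and operand values assumed): an all-zero
// padding leaves the shape unchanged, so the op folds to its input:
//   %0 = "tfl.pad"(%x, %zero_paddings)
//       : (tensor<2x3xf32>, tensor<2x2xi32>) -> tensor<2x3xf32>
// folds to %x, because the static input and output shapes are identical.
// PadV2Op below folds the same way.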

//===----------------------------------------------------------------------===//
// PadV2Op
//===----------------------------------------------------------------------===//

OpFoldResult PadV2Op::fold(ArrayRef<Attribute> operands) {
  if (InputOutputHasSameShape(input().getType(), output().getType()))
    return input();

  return {};
}

//===----------------------------------------------------------------------===//
// NoValueOp
//===----------------------------------------------------------------------===//

OpFoldResult NoValueOp::fold(ArrayRef<Attribute> operands) {
  return valueAttr();
}

bool NoValueOp::isBuildableWith(Attribute value, Type type) {
  return value.isa<UnitAttr>() && type.isa<NoneType>();
}
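
// Sketch (illustrative; exact textual form assumed): NoValueOp materializes a
// unit attribute as a `none`-typed SSA value, roughly
//   %none = "tfl.no_value"() {value} : () -> none
// which lets optional operands be represented by an explicit placeholder.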

YieldOp ControlNodeOp::GetYield() {
  return llvm::cast<YieldOp>(GetBody().back());
}

// Checks whether a TFL.control_node wraps a single operation whose results
// are perfectly forwarded to the wrapper's yield.
bool ControlNodeOp::WrapsSinglePerfectlyForwardedOp() {
  auto body = GetBody().without_terminator();
  if (!hasSingleElement(body)) return false;

  Operation &controlled_op = *body.begin();
  YieldOp yield = GetYield();
  return controlled_op.getNumResults() == yield.getNumOperands() &&
         std::equal(controlled_op.getResults().begin(),
                    controlled_op.getResults().end(),
                    yield.getOperands().begin());
}

mlir::LogicalResult ControlNodeOp::verify() {
  ControlNodeOp control_node = *this;
  if (!control_node.GetBody().args_empty())
    return control_node.emitOpError() << "expects body without any arguments";

  Operation &yield = control_node.GetBody().back();
  if (!isa<YieldOp>(yield))
    return yield.emitOpError()
           << "invalid TFL.control_node terminator, yield expected";

  // Ensure that the types of the terminator's operands and of the
  // control_node's results match.
  const int result_count =
      control_node.getNumResults() - 1;  // 1 for the control token
  const int num_operands = yield.getNumOperands();
  if (num_operands != result_count)
    return yield.emitOpError()
           << "has " << yield.getNumOperands()
           << " operands, but control_node returns " << result_count;
  for (const int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
    if (control_node.getResult(operand_idx).getType() !=
        yield.getOperand(operand_idx).getType())
      return yield.emitOpError() << "operand #" << operand_idx
                                 << " type does not match control_node result";
  }
  return success();
}

void ControlNodeOp::print(OpAsmPrinter &p) {
  if (getNumOperands()) {
    // These are always control operands; no explicit types are needed.
    p << '(';
    p.printOperands(getOperands());
    p << ')';
  }
  // Check if we can print the short "controls" form: that is, if the
  // control_node contains a single operation and the results of this operation
  // are perfectly forwarded to the yield.
  if (getOperation()->getAttrs().empty() && WrapsSinglePerfectlyForwardedOp()) {
    Operation &controlled_op = GetBody().front();
    // The "controls" syntax only encodes a single location.
    YieldOp yield_op = GetYield();
    // In order to round-trip correctly, we can only use this syntax when all
    // the locations are identical.
    if (controlled_op.getLoc() == getLoc() && yield_op.getLoc() == getLoc()) {
      p << " controls ";
      p.printGenericOp(&controlled_op);
      return;
    }
  }
  p << ' ';
  p.printRegion(getOperation()->getRegion(0));
  p.printOptionalAttrDict(getOperation()->getAttrs());
}
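
// Example (illustrative; SSA names and the wrapped op assumed): with a single
// perfectly-forwarded operation and matching locations, the printer emits the
// short form, roughly
//   %res, %ctl = tfl.control_node controls "tfl.neg"(%x)
//       : (tensor<4xf32>) -> tensor<4xf32>
// instead of the full region form with an explicit tfl.yield terminator; the
// parser below reconstructs the region and yield from this short form.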

ParseResult ControlNodeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the body region.
  llvm::SMLoc loc = parser.getCurrentLocation();
  Type control_type = ControlType::get(parser.getBuilder().getContext());

  // Parse optional argument list (control dependencies only).
  SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos;
  if (parser.parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
    return failure();
  if (!op_infos.empty()) {
    SmallVector<Type, 2> types(op_infos.size(), control_type);
    if (parser.resolveOperands(op_infos, types, loc, result.operands))
      return failure();
  }

  Region &body = *result.addRegion();

  if (succeeded(parser.parseOptionalKeyword("controls"))) {
    // If we parse the short version of the control node, we have an operation
    // in the generic form that follows the "controls" keyword. Parse it inside
    // the region and forward all of its results as-is to the yield operation.
    body.push_back(new Block);
    Block &block = body.back();
    Operation *controlled_op =
        parser.parseGenericOperation(&block, block.begin());
    if (!controlled_op) return failure();
    OpBuilder builder(parser.getBuilder().getContext());
    builder.setInsertionPointToEnd(&block);
    builder.create<YieldOp>(controlled_op->getLoc(),
                            controlled_op->getResults());
    result.location = controlled_op->getLoc();
  } else if (parser.parseRegion(body)) {
    return failure();
  }

  ControlNodeOp::ensureTerminator(body, parser.getBuilder(), result.location);

  // Get the results type for the control node from the terminator operands.
  Operation &yield = body.back().back();
  result.types.reserve(yield.getNumOperands() + 1);
  result.types.append(yield.operand_type_begin(), yield.operand_type_end());
  result.types.push_back(control_type);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// ConstBytesAttr
//===----------------------------------------------------------------------===//

Attribute ConstBytesAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseColon()) {
    return nullptr;
  }

  std::string data;
  if (parser.parseString(&data)) {
    return nullptr;
  }
  if (data.size() < 2 || data.substr(0, 2) != "0x") {
    parser.emitError(parser.getNameLoc(), "Hex string doesn't start with `0x`");
    return nullptr;
  }

  std::string bytes_data = absl::HexStringToBytes(data.substr(2));
  return ConstBytesAttr::get(parser.getBuilder().getContext(), bytes_data);
}

void ConstBytesAttr::print(mlir::AsmPrinter &printer) const {
  StringRef bytes_str = getValue();
  printer << " : \"0x" << llvm::toHex(bytes_str) << "\"";
}
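
// Round-trip sketch (illustrative): the attribute prints its raw bytes as an
// upper-case hex string after a colon, e.g. the three bytes {0x01, 0x02, 0xAB}
// print as `: "0x0102AB"`, and parse() decodes the same form back into the
// byte string via absl::HexStringToBytes.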

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"

static FailureOr<SmallVector<int32_t>> parseI32Array(AsmParser &parser) {
  SmallVector<int32_t> elements;
  auto elementParser = [&]() {
    int32_t element;
    if (failed(parser.parseInteger(element))) return failure();
    elements.push_back(element);
    return success();
  };
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                     elementParser))
    return failure();
  return elements;
}
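
// Usage sketch (illustrative): parseI32Array consumes a square-bracketed,
// comma-separated integer list such as `[1, -2, 3]` from the assembly stream
// and returns SmallVector<int32_t>{1, -2, 3}; any malformed element fails the
// whole parse.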

}  // namespace TFL
}  // namespace mlir

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_dialect.cc.inc"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_enums.cc.inc"
#define GET_ATTRDEF_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_attrdefs.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"

namespace mlir {
namespace TFL {

#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"

Operation *TFLDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                           Type type, Location loc) {
  // If this is a constant bytes attribute or the result type doesn't match the
  // attribute type, then generate a tfl.pseudo_const.
  if (value.isa<ConstBytesAttr>() ||
      (value.isa<ElementsAttr>() &&
       value.cast<ElementsAttr>().getType() != type))
    return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  if (NoValueOp::isBuildableWith(value, type))
    return builder.create<NoValueOp>(loc, type, value.cast<UnitAttr>());
  return nullptr;
}
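
// Dispatch sketch (illustrative): when folding produces an ElementsAttr whose
// type matches the requested result type, an arith.constant is materialized;
// a type mismatch (or a ConstBytesAttr) yields a tfl.pseudo_const instead,
// and a UnitAttr paired with a `none` result type becomes a tfl.no_value.
// Returning nullptr tells MLIR the attribute cannot be materialized here.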

}  // namespace TFL
}  // namespace mlir