1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 #include <string>
24
25 #include "absl/strings/escaping.h"
26 #include "third_party/eigen3/Eigen/Core"
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/Optional.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/FormatVariadic.h"
37 #include "llvm/Support/Threading.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
40 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
41 #include "mlir/IR/Attributes.h" // from @llvm-project
42 #include "mlir/IR/Builders.h" // from @llvm-project
43 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
44 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
45 #include "mlir/IR/Location.h" // from @llvm-project
46 #include "mlir/IR/MLIRContext.h" // from @llvm-project
47 #include "mlir/IR/Matchers.h" // from @llvm-project
48 #include "mlir/IR/OpImplementation.h" // from @llvm-project
49 #include "mlir/IR/PatternMatch.h" // from @llvm-project
50 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
51 #include "mlir/IR/Types.h" // from @llvm-project
52 #include "mlir/Support/LLVM.h" // from @llvm-project
53 #include "mlir/Support/LogicalResult.h" // from @llvm-project
54 #include "mlir/Transforms/FoldUtils.h" // from @llvm-project
55 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
56 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
57 #include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h"
58 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
59 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
60 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
61 #include "tensorflow/core/framework/kernel_shape_util.h"
62
63 namespace mlir {
64 namespace TFL {
65 namespace {
66
67 ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
68 OperationState &result) {
69 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
70 Type type;
71 // If the operand list is enclosed in parentheses, then we have the generic form
72 // (see the fallback in `printOneResultOp`).
73 SMLoc loc = parser.getCurrentLocation();
74 if (!parser.parseOptionalLParen()) {
75 if (parser.parseOperandList(ops) || parser.parseRParen() ||
76 parser.parseOptionalAttrDict(result.attributes) ||
77 parser.parseColon() || parser.parseType(type))
78 return failure();
79 auto fnType = type.dyn_cast<FunctionType>();
80 if (!fnType) {
81 parser.emitError(loc, "expected function type");
82 return failure();
83 }
84 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
85 return failure();
86 result.addTypes(fnType.getResults());
87 return success();
88 }
89 return failure(parser.parseOperandList(ops) ||
90 parser.parseOptionalAttrDict(result.attributes) ||
91 parser.parseColonType(type) ||
92 parser.resolveOperands(ops, type, result.operands) ||
93 parser.addTypeToList(type, result.types));
94 }
95
96 void printOneResultOp(Operation *op, OpAsmPrinter &p) {
97 assert(op->getNumResults() == 1 && "op should have one result");
98
99 // If not all the operand and result types are the same, just use the
100 // generic assembly form to avoid omitting information in printing.
101 auto resultType = op->getResult(0).getType();
102 if (llvm::any_of(op->getOperandTypes(),
103 [&](Type type) { return type != resultType; })) {
104 p.printGenericOp(op, /*printOpName=*/false);
105 return;
106 }
107
108 p << ' ';
109 p.printOperands(op->getOperands());
110 p.printOptionalAttrDict(op->getAttrs());
111 // Now we can output only one type for all operands and the result.
112 p << " : " << resultType;
113 }
114
115 Operation *getDefiningBroadcastArgsOp(Value operand) {
116 auto *defining_op = operand.getDefiningOp();
117 if (!llvm::dyn_cast_or_null<TF::BroadcastToOp>(defining_op) &&
118 !llvm::dyn_cast_or_null<TFL::BroadcastToOp>(defining_op)) {
119 return nullptr;
120 }
121
122 Value broadcast_shape = defining_op->getOperand(
123 1); // Broadcasted shape operand of BroadcastTo op.
124 Operation *parent_of_defining_op = broadcast_shape.getDefiningOp();
125 if (!llvm::dyn_cast_or_null<TF::BroadcastArgsOp>(parent_of_defining_op) &&
126 !llvm::dyn_cast_or_null<TFL::BroadcastArgsOp>(parent_of_defining_op)) {
127 return nullptr;
128 }
129 return parent_of_defining_op;
130 }
131 } // namespace
132
133 // Returns true when the given type lists contain a single element of shaped
134 // type with compatible shapes (unranked shape is compatible with any ranked
135 // shape, ranked shapes are compatible if their respective dimensions are
136 // compatible, dynamic dimensions are compatible with any size, static
137 // dimensions must be equal to be compatible) and identical element types.
138 bool VerifyCompatibleShapesSameElementType(TypeRange lhs, TypeRange rhs) {
139 if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
140 if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
141 auto lhsShaped = lhs[0].cast<ShapedType>();
142 auto rhsShaped = rhs[0].cast<ShapedType>();
143 return lhsShaped.getElementType() == rhsShaped.getElementType();
144 }
145
146 // Returns true when the given operands have the same shape or a
147 // broadcastable shape within the given rank. If any of the given shapes is
148 // non-static and the maximum rank is within the given rank, this method
149 // returns true.
150 bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
151 Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
152 if (indices.empty()) return true;
153
154 // First, check whether any of the inputs has an unknown rank.
155 bool has_unknown_shape_input = false;
156 bool has_same_shape = true;
157 bool reach_first_known_shape = false;
158 int64_t max_rank = -1;
159
160 ArrayRef<int64_t> pivot_shape;
161 SmallVector<int64_t, 4> current_shape;
162 SmallVector<int64_t, 4> result_shape;
163
164 for (unsigned index : indices) {
165 ShapedType shaped_type =
166 op->getOperand(index).getType().dyn_cast<ShapedType>();
167 if (!shaped_type || !shaped_type.hasRank()) {
168 // Marks that we have an unknown rank input.
169 has_unknown_shape_input = true;
170 continue;
171 }
172 max_rank = std::max(max_rank, shaped_type.getRank());
173 if (!shaped_type.hasStaticShape()) {
174 // Marks that we have an unknown shape input.
175 has_unknown_shape_input = true;
176 continue;
177 }
178
179 ArrayRef<int64_t> shape = shaped_type.getShape();
180 if (!reach_first_known_shape) {
181 pivot_shape = shape;
182 current_shape.assign(shape.begin(), shape.end());
183 reach_first_known_shape = true;
184 continue;
185 }
186
187 if (!pivot_shape.equals(shape)) {
188 has_same_shape = false;
189 }
190 // Checks whether all the inputs are broadcastable, since they do not all
191 // have the same shape.
192 if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
193 result_shape)) {
194 return false;
195 }
196 current_shape = result_shape;
197 }
198
199 // If all shapes are known and identical, CPU kernels are able to handle
200 // inputs regardless of dimension size.
201 if (!has_unknown_shape_input) {
202 return has_same_shape || max_rank <= max_bcast_rank;
203 }
204
205 // It will treat the unknown shape inputs as acceptable inputs for model
206 // compatibility if all known ranks are no bigger than the allowed broadcast
207 // maximum rank.
208 if (max_rank <= max_bcast_rank) {
209 return true;
210 }
211
212 // Checks whether all operands are broadcast by BroadcastTo ops whose shape
213 // is calculated from the same BroadcastArgs op. In that case, all operands
214 // will have the same shape.
215 Operation *broadcast_args_pivot = nullptr;
216 for (unsigned index : indices) {
217 Operation *parent_broadcast_args =
218 getDefiningBroadcastArgsOp(op->getOperand(index));
219 if (parent_broadcast_args == nullptr) {
220 return false;
221 }
222
223 if (broadcast_args_pivot == nullptr) {
224 broadcast_args_pivot = parent_broadcast_args;
225 continue;
226 }
227
228 if (broadcast_args_pivot != parent_broadcast_args) {
229 return false;
230 }
231 }
232 return true;
233 }
234
235 // Return true when the given element_type is QI8.
236 bool IsQI8Type(Type element_type) {
237 auto quantized_type = element_type.dyn_cast<QuantizedType>();
238 return quantized_type != nullptr &&
239 quantized_type.getStorageTypeIntegralWidth() == 8 &&
240 quantized_type.isSigned();
241 }
242
243 // Return true when the given element_type is QUI8.
244 bool IsQUI8Type(Type element_type) {
245 auto quantized_type = element_type.dyn_cast<QuantizedType>();
246 return quantized_type != nullptr &&
247 quantized_type.getStorageTypeIntegralWidth() == 8 &&
248 !quantized_type.isSigned();
249 }
250
251 // Return true when the given element_type is QI16.
252 bool IsQI16Type(Type element_type) {
253 auto quantized_type = element_type.dyn_cast<QuantizedType>();
254 return quantized_type != nullptr &&
255 quantized_type.getStorageTypeIntegralWidth() == 16 &&
256 quantized_type.isSigned();
257 }
258
259 // Return true when the given element_type is I32.
260 bool IsI32Type(Type element_type) {
261 return element_type.isInteger(32) && !element_type.isUnsignedInteger();
262 }
263
264 // Return true when the given element_type is I64.
265 bool IsI64Type(Type element_type) {
266 return element_type.isInteger(64) && !element_type.isUnsignedInteger();
267 }
268
269 // Return true if the value is a splat tensor constant zero.
270 bool EqualsZero(Value value) {
271 DenseElementsAttr constant;
272 if (!matchPattern(value, m_Constant(&constant)) || !constant.isSplat()) {
273 return false;
274 }
275
276 Type element_type = value.getType().cast<ShapedType>().getElementType();
277 if (element_type.isa<FloatType>()) {
278 return constant.getSplatValue<APFloat>().isZero();
279 } else {
280 return false;
281 }
282 }
283
284 // Replaces the bias operand with a "none" type value if the bias value is
285 // constant zero.
286 // `ConcreteOpType` must be a concrete MLIR op class that has an optional
287 // bias operand named 'bias'.
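// For example, a tfl.fully_connected whose bias is a constant splat of 0.0 is
// rewritten so that its bias operand becomes a tfl.no_value.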
288 template <typename ConcreteOpType>
289 struct RemoveOptionalZeroBias : public OpRewritePattern<ConcreteOpType> {
290 using OpRewritePattern<ConcreteOpType>::OpRewritePattern;
291
292 LogicalResult matchAndRewrite(ConcreteOpType op,
293 PatternRewriter &rewriter) const override {
294 if (EqualsZero(op.bias())) {
295 auto none_value = rewriter.create<TFL::NoValueOp>(
296 rewriter.getUnknownLoc(), rewriter.getNoneType(),
297 rewriter.getUnitAttr());
298 op.biasMutable().assign(none_value);
299 }
300
301 return success();
302 }
303 };
304
305 // Return true if the given Add operation has the CPU kernel supported shapes.
306 bool VerifyAddOpShapeConstraints(AddOp op) {
307 auto element_type = getElementTypeOrSelf(op.output().getType());
308
309 // Allows F32, QI8, QUI8, I32 and I64 outputs when the operands have valid
310 // shapes: broadcastable shapes up to four dimensions or identical shapes.
311 if (element_type.isF32() || IsQI8Type(element_type) ||
312 IsQUI8Type(element_type) || IsI32Type(element_type) ||
313 IsI64Type(element_type)) {
314 return VerifyOperandsHaveSameShapesOrBroadcastableShape(
315 /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
316 /*max_bcast_rank=*/4);
317 }
318
319 // Allows QI16 output when operands have the same shape.
320 if (IsQI16Type(element_type)) {
321 return succeeded(
322 mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
323 }
324 return false;
325 }
326
327 // Return true if the given Sub operation has the CPU kernel supported shapes.
328 bool VerifySubOpShapeConstraints(SubOp op) {
329 auto element_type = getElementTypeOrSelf(op.output().getType());
330
331 // Allows F32, I32, I64, QUI8 and QI16 outputs when the operands have valid
332 // shapes: broadcastable shapes up to five dimensions or identical shapes.
333 if (element_type.isF32() || IsI32Type(element_type) ||
334 IsI64Type(element_type) || IsQUI8Type(element_type) ||
335 IsQI16Type(element_type)) {
336 return VerifyOperandsHaveSameShapesOrBroadcastableShape(
337 /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
338 /*max_bcast_rank=*/5);
339 }
340
341 // Allows QI8 output when the operands have valid shapes: broadcastable
342 // shapes up to four dimensions or identical shapes.
343 if (IsQI8Type(element_type)) {
344 return VerifyOperandsHaveSameShapesOrBroadcastableShape(
345 /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
346 /*max_bcast_rank=*/4);
347 }
348 return false;
349 }
350
351 // Return true if the given Mul operation has the CPU kernel supported shapes.
352 bool VerifyMulOpShapeConstraints(MulOp op) {
353 auto element_type = getElementTypeOrSelf(op.output().getType());
354
355 // Allows QI8 and QUI8 outputs when the operands have valid shapes:
356 // broadcastable shapes up to four dimensions or identical shapes. If the lhs
357 // element type is QI16, only operands with compatible shapes are allowed.
358 if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
359 if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
360 return succeeded(
361 mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
362 }
363 return VerifyOperandsHaveSameShapesOrBroadcastableShape(
364 /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
365 /*max_bcast_rank=*/4);
366 }
367
368 // Allows I32, I64, QI16, complex and F32 outputs when the operands have
369 // valid shapes: broadcastable shapes up to four dimensions or identical shapes.
370 if (IsI32Type(element_type) || IsI64Type(element_type) ||
371 IsQI16Type(element_type) || element_type.isa<ComplexType>() ||
372 element_type.isF32()) {
373 return VerifyOperandsHaveSameShapesOrBroadcastableShape(
374 /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
375 /*max_bcast_rank=*/4);
376 }
377 return false;
378 }
379
380 //===----------------------------------------------------------------------===//
381 // TensorFlowLiteDialect
382 //===----------------------------------------------------------------------===//
383
384 struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
385 using DialectInlinerInterface::DialectInlinerInterface;
386
387 //===--------------------------------------------------------------------===//
388 // Analysis Hooks
389 //===--------------------------------------------------------------------===//
390
391 // Allow all call operations to be inlined.
392 bool isLegalToInline(Operation *call, Operation *callable,
393 bool wouldBeCloned) const final {
394 return true;
395 }
396 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
397 BlockAndValueMapping &) const final {
398 // No TFLite op restricts inlining today, revise as needed in the future.
399 return true;
400 }
401 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
402 BlockAndValueMapping &valueMapping) const final {
403 return isa<WhileOp>(dest->getParentOp());
404 }
405 };
406
407 struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
408 using DialectFoldInterface::DialectFoldInterface;
409
410 // Registered hook to check if the given region, which is attached to an
411 // operation that is *not* isolated from above (i.e. no internal regions
412 // reference values defined in an enclosing region), should be used when
413 // materializing constants.
414 // In the TFLite dialect we materialize inside a while region, as this is
415 // slightly more efficient computationally.
416 bool shouldMaterializeInto(Region *region) const final {
417 return isa<WhileOp>(region->getParentOp());
418 }
419 };
420
421 void TFLDialect::printType(Type type, DialectAsmPrinter &os) const {
422 if (type.isa<ControlType>()) {
423 os << "control";
424 return;
425 }
426 os << "<unknown TFL type>";
427 }
428
429 Type TFLDialect::parseType(DialectAsmParser &parser) const {
430 StringRef data_type;
431 if (parser.parseKeyword(&data_type)) return Type();
432 if (data_type == "control") return ControlType::get(getContext());
433 parser.emitError(parser.getNameLoc()) << "unknown TFL type: " << data_type;
434 return nullptr;
435 }
436
437 void TFLDialect::initialize() {
438 addOperations<
439 #define GET_OP_LIST
440 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
441 >();
442 addAttributes<
443 #define GET_ATTRDEF_LIST
444 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_attrdefs.cc.inc"
445 >();
446 addInterfaces<TensorFlowLiteInlinerInterface,
447 TensorFlowLiteDialectFoldInterface>();
448 addTypes<ControlType>();
449 }
450
451 //===----------------------------------------------------------------------===//
452 // Common support logic
453 //===----------------------------------------------------------------------===//
454
455 namespace {
456
457 // Returns true if the dimensions in `a` is a suffix of the ones in `b`.
458 // For example, dimensions {2}, {1, 2}, and {3, 1, 2} are all suffixes to
459 // {5, 4, 3, 1, 2}, while {1}, {5, 4}, and {1, 3, 2} are all not.
460 inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
461 if (a.size() > b.size()) return false;
462
463 return std::equal(a.rbegin(), a.rend(), b.rbegin());
464 }
465
466 // Returns true if it is a shaped type of f32 elements.
467 inline bool IsF32ShapedType(Type t) {
468 if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
469 return shaped_type.getElementType().isF32();
470 }
471 return false;
472 }
473
474 // Returns true if it is a shaped type of bf16 elements.
475 inline bool IsBF16ShapedType(Type t) {
476 if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
477 return shaped_type.getElementType().isBF16();
478 }
479 return false;
480 }
481
482 // Returns a new shape with rank 'new_dims', padded with ones on the left
483 // if needed.
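// For example, GetPaddedShape({3, 2}, 4) returns {1, 1, 3, 2}.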
484 inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
485 int new_dims) {
486 std::vector<int64_t> new_shape(new_dims, 1);
487 std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
488 return new_shape;
489 }
490
491 // Helper method that, given a 'current_index' representing an index in the
492 // broadcasted tensor, returns the index into the flat original tensor.
493 // 'shape' is the original shape padded with ones to match the result shape.
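// For example, with a padded 'shape' of {1, 3} and 'current_index' {2, 1} in a
// broadcasted {4, 3} tensor, the flat index is (1 % 3) * 1 + (2 % 1) * 3 = 1.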
494 int64_t GetElementIndex(const std::vector<int64_t> &shape,
495 const std::vector<int64_t> &current_index) {
496 int64_t ind = 0;
497 int64_t mul = 1;
498 for (int i = shape.size() - 1; i >= 0; --i) {
499 ind += (current_index[i] % shape[i]) * mul;
500 mul *= shape[i];
501 }
502 return ind;
503 }
504
505 // Helper method that increments the index represented in 'current_index_ptr'
506 // within the shape 'result_shape'.
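// For example, with 'result_shape' {2, 3}, the index {0, 2} increments to {1, 0}.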
507 void IncrementIndex(ArrayRef<int64_t> result_shape,
508 std::vector<int64_t> *current_index_ptr) {
509 std::vector<int64_t> &current_index = *current_index_ptr;
510 for (int i = result_shape.size() - 1; i >= 0; --i) {
511 current_index[i]++;
512 if (current_index[i] == result_shape[i]) {
513 current_index[i] = 0;
514 } else {
515 break;
516 }
517 }
518 }
519
520 /// Performs const folding `calculate` with broadcast behavior on the two
521 /// attributes `operand1` and `operand2` and returns the result if possible.
522 /// This function assumes both operands have been verified to have value
523 /// attributes of broadcastable types.
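/// For example, folding an addition of a splat tensor<2x2xf32> of 1.0 with a
/// dense constant [[1.0, 2.0], [3.0, 4.0]] produces [[2.0, 3.0], [4.0, 5.0]].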
524 template <class AttrElementT,
525 class ElementValueT = typename AttrElementT::ValueType,
526 class CalculationT =
527 llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
528 Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
529 DenseElementsAttr rhs,
530 const CalculationT &calculate) {
531 auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
532 .dyn_cast_or_null<ShapedType>();
533 if (!type) {
534 return {};
535 }
536
537 const bool rhs_is_splat = rhs.isSplat();
538 const bool lhs_is_splat = lhs.isSplat();
539
540 // If both of them are splat, compute and return.
541 if (lhs_is_splat && rhs_is_splat) {
542 auto element_result = AttrElementT::get(
543 type.getElementType(), calculate(lhs.getSplatValue<ElementValueT>(),
544 rhs.getSplatValue<ElementValueT>()));
545 if (!element_result) return {};
546
547 return DenseElementsAttr::get(type, element_result);
548 }
549
550 auto num_elements = type.getNumElements();
551
552 SmallVector<ElementValueT, 16> new_values;
553 new_values.reserve(num_elements);
554 const auto result_shape = type.getShape();
555 std::vector<int64_t> current_index(type.getRank(), 0);
556 // Create the new shape with ones padded to the left.
557 const std::vector<int64_t> lhs_new_shape =
558 GetPaddedShape(lhs.getType().getShape(), type.getRank());
559 const std::vector<int64_t> rhs_new_shape =
560 GetPaddedShape(rhs.getType().getShape(), type.getRank());
561
562 auto lhs_old_values = lhs.getValues<ElementValueT>();
563 auto rhs_old_values = rhs.getValues<ElementValueT>();
564
565 // Apply the calculation to each pair of corresponding values in the dense
566 // elements attributes.
567 for (int64_t i = 0; i < num_elements; ++i) {
568 // current_index represents the index
569 // in the N-dimension tensor. GetElementIndex returns
570 // the index in the flat representation of the original tensor
571 // to use.
572 const int64_t lhs_index =
573 lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
574 const int64_t rhs_index =
575 rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);
576
577 new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index),
578 *(rhs_old_values.begin() + rhs_index)));
579 IncrementIndex(result_shape, &current_index);
580 }
581 return DenseElementsAttr::get(type, ArrayRef<ElementValueT>(new_values));
582 }
583
584 /// Performs const folding `calculate` with broadcast behavior on the two
585 /// attributes `operand1` and `operand2` and returns the result if possible.
586 /// This function assumes the two operands are verified to have value
587 /// attributes of broadcastable types.
588 template <class AttrElementT,
589 class ElementValueT = typename AttrElementT::ValueType,
590 class CalculationT =
591 llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
592 Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
593 Attribute operand2, const CalculationT &calculate) {
594 if (operand1.dyn_cast_or_null<DenseElementsAttr>() &&
595 operand2.dyn_cast_or_null<DenseElementsAttr>()) {
596 return ConstFoldBinaryOpDenseDense<AttrElementT, ElementValueT>(
597 result_type, operand1.cast<DenseElementsAttr>(),
598 operand2.cast<DenseElementsAttr>(), calculate);
599 }
600
601 // TODO: support other attribute kinds
602
603 return {};
604 }
605
606 /// Performs const folding with broadcast behavior on the two attributes in
607 /// `operands` and returns the result if possible.
608 /// Depending on the given `result_type`, either `float_calculate` or
609 /// `int_calculate` is chosen to perform the calculation.
610 Attribute ConstFoldBinaryOp(
611 Type result_type, ArrayRef<Attribute> operands,
612 llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
613 llvm::function_ref<APInt(APInt, APInt)> int_calculate) {
614 // Note: All types are wrapped in tensor types in TFLite. E.g., f32 is
615 // represented as tensor<f32>. So we are only handling tensor types here.
616 auto type = result_type.dyn_cast<ShapedType>();
617 if (!type) return {};
618
619 auto elemType = type.getElementType();
620
621 if (elemType.isa<FloatType>())
622 return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
623 float_calculate);
624
625 if (elemType.isSignlessInteger())
626 return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
627 int_calculate);
628
629 return {};
630 }
631
632 /// Performs const folding of the attribute `operand` and returns the result
633 /// if possible.
634 /// The function currently asserts that `result_type` is an f32 or bf16 tensor
635 /// type.
636 /// TODO: Extend this function to handle integral tensor for ops like
637 /// "tfl.logical_not".
638 Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
639 llvm::function_ref<APFloat(APFloat)> calculate) {
640 assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
641 auto result_shape_type = result_type.cast<ShapedType>();
642
643 if (!result_shape_type.hasStaticShape()) return {};
644
645 if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
646 SmallVector<APFloat, 16> new_values;
647 const int num_elements = result_shape_type.getNumElements();
648 new_values.reserve(num_elements);
649
650 for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
651 new_values.push_back(calculate(old_value));
652 }
653
654 return DenseElementsAttr::get(result_shape_type, new_values);
655 }
656
657 return {};
658 }
659
660 void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
661 Value rhs) {
662 auto result_type =
663 OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
664 if (!result_type)
665 emitError(result.location)
666 << "non-broadcastable operands: " << lhs.getType() << " and "
667 << rhs.getType();
668 result.addOperands({lhs, rhs});
669 // Comparison binary ops always return i1 tensor.
670 if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
671 auto result_shape = shaped_type.getShape();
672 result.types.push_back(
673 RankedTensorType::get(result_shape, builder->getI1Type()));
674 } else {
675 result.types.push_back(UnrankedTensorType::get(builder->getI1Type()));
676 }
677 }
678
679 void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
680 Value lhs, Value rhs,
681 StringAttr fused_activation_function) {
682 auto result_type =
683 OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
684
685 if (!result_type)
686 emitError(result.location)
687 << "non-broadcastable operands: " << lhs.getType() << " and "
688 << rhs.getType();
689
690 result.addOperands({lhs, rhs});
691 result.addAttribute("fused_activation_function", fused_activation_function);
692 result.types.push_back(result_type);
693 }
694
695 } // end anonymous namespace
696
697 //===----------------------------------------------------------------------===//
698 // AddOp
699 //===----------------------------------------------------------------------===//
700
701 OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
702 // TODO(b/142478136): Handle fused ops.
703 if (fused_activation_function() != "NONE") return {};
704 return ConstFoldBinaryOp(
705 getType(), operands, [](APFloat a, APFloat b) { return a + b; },
706 [](APInt a, APInt b) { return a + b; });
707 }
708
709 int64_t AddOp::GetArithmeticCount(Operation *op) {
710 int64_t count;
711 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
712
713 return -1;
714 }
715
716 //===----------------------------------------------------------------------===//
717 // ConcatenationOp
718 //===----------------------------------------------------------------------===//
719 // TODO(ashwinm): Implement shape inference for Concatenation
720
721 namespace {
722
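// Returns the concatenation axis as a non-negative value; a negative axis is
// wrapped by adding the output rank (e.g. axis = -1 on a rank 3 output maps to 2).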
723 int64_t GetConcatenationOpAxis(ConcatenationOp op) {
724 auto output_type = op.output().getType().cast<RankedTensorType>();
725 int32_t axis = op.axis();
726 if (axis < 0) axis += output_type.getRank();
727 return axis;
728 }
729
730 // Verify operand types and the result type:
731 //
732 // 1. Operand type ranks must be equal to the output type rank.
733 //
734 // 2. Operand dimension sizes (except dimension `axis`) must be equal to
735 // previously seen dimension sizes of the same dimension.
736 //
737 // 3. Sum of operand dimension sizes of the `axis` dimension must be equal to
738 // the dimension size of the `axis` dimension of output.
739 //
740 // Note: If an operand has unranked tensor type or has dynamic dimension size,
741 // those dimensions will be skipped.
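// For example, concatenating operands of types tensor<2x3xf32> and
// tensor<2x5xf32> along axis = 1 requires an output of type tensor<2x8xf32>.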
742 LogicalResult VerifyConcatenationOpTypes(Operation *op,
743 RankedTensorType output_type,
744 ArrayRef<TensorType> operand_types,
745 int64_t axis) {
746 const int64_t output_rank = output_type.getRank();
747
748 constexpr int64_t kDynamicSize = -1;
749 SmallVector<int64_t, 4> result_dim_sizes_loc(output_rank, -1);
750 SmallVector<int64_t, 4> result_dim_sizes(output_type.getShape().begin(),
751 output_type.getShape().end());
752 result_dim_sizes[axis] = 0;
753
754 auto FormatLoc = [&result_dim_sizes_loc](int64_t dim) {
755 const int64_t loc = result_dim_sizes_loc[dim];
756 if (loc == -1) return std::string("output");
757 return llvm::formatv("operand #{0}", loc).str();
758 };
759
760 for (const auto &operand : llvm::enumerate(operand_types)) {
761 auto operand_type = operand.value().dyn_cast<RankedTensorType>();
762 if (!operand_type) {
763 result_dim_sizes[axis] = kDynamicSize;
764 continue;
765 }
766
767 const int64_t operand_rank = operand_type.getRank();
768 if (operand_rank != output_rank)
769 return op->emitOpError() << "rank of operand #" << operand.index()
770 << " must be equal to rank of output, expected "
771 << output_rank << ", got " << operand_rank;
772
773 for (int64_t dim = 0; dim < output_rank; ++dim) {
774 const int64_t operand_dim_size = operand_type.getDimSize(dim);
775 const int64_t result_dim_size = result_dim_sizes[dim];
776
777 if (dim == axis) {
778 if (ShapedType::isDynamic(operand_dim_size) ||
779 ShapedType::isDynamic(result_dim_size)) {
780 result_dim_sizes[axis] = kDynamicSize;
781 } else {
782 result_dim_sizes[axis] += operand_dim_size;
783 }
784 continue;
785 }
786
787 if (ShapedType::isDynamic(operand_dim_size)) continue;
788
789 if (ShapedType::isDynamic(result_dim_size)) {
790 result_dim_sizes[dim] = operand_dim_size;
791 result_dim_sizes_loc[dim] = operand.index();
792 continue;
793 }
794
795 if (result_dim_size != operand_dim_size)
796 return op->emitOpError()
797 << "dimension size of dimension #" << dim << " of operand #"
798 << operand.index() << " must be equal to "
799 << "dimension size of dimension #" << dim << " of "
800 << FormatLoc(dim) << ", expected " << result_dim_size << ", got "
801 << operand_dim_size;
802 }
803 }
804
805 const int64_t output_concated_dim_size = output_type.getDimSize(axis);
806 if (!ShapedType::isDynamic(output_concated_dim_size) &&
807 !ShapedType::isDynamic(result_dim_sizes[axis]) &&
808 result_dim_sizes[axis] != output_concated_dim_size)
809 return op->emitOpError()
810 << "dimension size of dimension #" << axis << " of output "
811 << "must be equal to the sum of dimension sizes of dimension #"
812 << axis << ", expected " << result_dim_sizes[axis] << ", got "
813 << output_concated_dim_size;
814
815 return success();
816 }
817
818 // Returns true when all operands are instances of DenseElementsAttr and the
819 // output type has a static shape.
820 bool IsConcatenationOpConstFoldable(ConcatenationOp op,
821 ArrayRef<Attribute> operands,
822 RankedTensorType output_type,
823 int64_t axis) {
824 if (operands.empty()) return false;
825 if (!output_type.hasStaticShape()) return false;
826 if (axis < 0) return false;
827
828 return llvm::all_of(operands, [](Attribute operand) {
829 return operand && operand.isa<DenseElementsAttr>();
830 });
831 }
832
833 DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
834 RankedTensorType output_type,
835 int64_t axis) {
836 const auto outer_dims = output_type.getShape().take_front(axis);
837 const int64_t outer_size = std::accumulate(
838 outer_dims.begin(), outer_dims.end(), 1, std::multiplies<int64_t>());
839
840 const auto base_inner_dims = output_type.getShape().drop_front(axis + 1);
841 const int64_t base_inner_size =
842 std::accumulate(base_inner_dims.begin(), base_inner_dims.end(), 1,
843 std::multiplies<int64_t>());
844
845 // Splits each input operand into outer_size pieces and combines them in
846 // round-robin ordering.
847 std::vector<Attribute> out_attrs(output_type.getNumElements());
848 int64_t out = 0;
849 for (int64_t outer = 0; outer < outer_size; ++outer) {
850 for (auto op : operands) {
851 auto typed_attr = op.cast<TypedAttr>();
852 const int64_t dim_size =
853 typed_attr.getType().cast<RankedTensorType>().getDimSize(axis);
854 const int64_t inner_size = dim_size * base_inner_size;
855
856 auto input_attrs = op.cast<DenseElementsAttr>().getValues<Attribute>();
857 auto input_iter = input_attrs.begin() + outer * inner_size;
858 for (int64_t inner = 0; inner < inner_size; ++inner)
859 out_attrs[out++] = *input_iter++;
860 }
861 }
862
863 return DenseElementsAttr::get(output_type, out_attrs);
864 }
865
866 } // end anonymous namespace
867
868 LogicalResult ConcatenationOp::verify() {
869 ConcatenationOp op = *this;
870 auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
871
872 // If the output type is unranked, there is nothing else to be verified.
873 if (!output_type) return success();
874
875 const int64_t axis = GetConcatenationOpAxis(op);
876 if (axis < 0 || axis >= output_type.getRank())
877 return op.emitOpError("concatenation dimension must be in [-rank, rank)");
878
879 SmallVector<TensorType, 4> operand_types;
880 for (Value operand : op.values())
881 operand_types.push_back(operand.getType().cast<TensorType>());
882
883 return VerifyConcatenationOpTypes(op.getOperation(), output_type,
884 operand_types, axis);
885 }
886
887 OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
888 if (fused_activation_function() == "NONE") {
889 if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
890 const int64_t axis = GetConcatenationOpAxis(*this);
891 if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
892 return ConstFoldConcatenateOpDense(operands, output_type, axis);
893 }
894 }
895
896 // Remove all empty values.
897 SmallVector<Value, 4> non_empty_values;
898 for (Value value : this->values()) {
899 const auto shaped_type = value.getType().cast<ShapedType>();
900 if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
901 continue;
902 }
903 non_empty_values.push_back(value);
904 }
905
906 // If no operand was empty, there is nothing to fold.
907 if (non_empty_values.size() == getNumOperands()) return nullptr;
908
909 // If only one input is non-empty, just return it as the result of folding.
910 if (non_empty_values.size() == 1) {
911 return non_empty_values[0];
912 }
913
914 // Otherwise, build a new concatenation op with non-empty values.
915 mlir::OpBuilder builder(getOperation());
916 auto new_concat = builder.create<TFL::ConcatenationOp>(
917 getLoc(), getType(), non_empty_values,
918 builder.getIntegerAttr(builder.getIntegerType(32), axis()),
919 builder.getStringAttr(fused_activation_function()));
920 return new_concat.getResult();
921 }
922
923 //===----------------------------------------------------------------------===//
924 // CustomOp
925 //===----------------------------------------------------------------------===//
926
927 // TODO(b/241745316): Confirm that this is always valid
928 mlir::LogicalResult CustomOp::verify() {
929 // Currently, this is always valid as it is a wrapper around a StringRef of 0
930 // or more characters.
931 return success();
932 }
933
934 //===----------------------------------------------------------------------===//
935 // CustomTfOp
936 //===----------------------------------------------------------------------===//
937
938 LogicalResult CustomTfOp::inferReturnTypes(
939 MLIRContext *, Optional<Location> location, ValueRange operands,
940 DictionaryAttr attr, RegionRange ranges,
941 SmallVectorImpl<Type> &inferredReturnTypes) {
942 CustomTfOpAdaptor op(operands, attr, ranges);
943
944 if (op.getRegions().empty()) return success();
945 auto *real_op = &op.body().front().front();
946 if (llvm::isa<TF::FakeQuantWithMinMaxArgsOp, TF::FakeQuantWithMinMaxVarsOp,
947 TF::FakeQuantWithMinMaxVarsPerChannelOp>(real_op)) {
948 Value input = *operands.begin();
949 inferredReturnTypes.assign({input.getType()});
950 }
951 return success();
952 }
953
954 bool CustomTfOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
955 if (lhs.empty()) return true;
956 if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
957 if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
958 return true;
959 }
960
961 //===----------------------------------------------------------------------===//
962 // Gather op
963 //===----------------------------------------------------------------------===//
964
965 LogicalResult GatherOp::verify() {
966 GatherOp op = *this;
967 ShapedType params_type = op.params().getType().cast<ShapedType>();
968 // TFLite gather kernel supports 1D string input only.
969 if (params_type.getElementType().isa<mlir::TF::StringType>()) {
970 if (params_type.hasRank() && params_type.getRank() != 1) {
971 return op.emitOpError(
972 "expect 1d input when the given type is string, got ")
973 << params_type;
974 }
975 }
976 return mlir::success();
977 }
978
979 //===----------------------------------------------------------------------===//
980 // BroadcastToOp
981 //===----------------------------------------------------------------------===//
982
983 // Canonicalizes BroadcastToOp to ReshapeOp if the input and output have the
984 // same number of elements.
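// For example, a tfl.broadcast_to from tensor<4xf32> to shape [1, 4] keeps the
// element count at 4 and can be rewritten as a reshape.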
985 struct ConvertBroadcastToReshape : public OpRewritePattern<BroadcastToOp> {
986 using OpRewritePattern<BroadcastToOp>::OpRewritePattern;
987
988 LogicalResult matchAndRewrite(BroadcastToOp op,
989 PatternRewriter &rewriter) const override {
990 auto input_type = op.input().getType().cast<ShapedType>();
991 auto output_type = op.getType().cast<ShapedType>();
992 if (!input_type.hasStaticShape() || !output_type.hasStaticShape() ||
993 input_type.getNumElements() != output_type.getNumElements()) {
994 return failure();
995 }
996 // The Reshape op only supports an I32 new-shape operand, so always add a
997 // cast to I32 to make sure the introduced Reshape op is valid.
998 auto result_type = RankedTensorType::get(
999 op.shape().getType().cast<RankedTensorType>().getShape(),
1000 rewriter.getI32Type());
1001 auto cast_op =
1002 rewriter.create<TFL::CastOp>(op->getLoc(), result_type, op.shape());
1003
1004 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.input(),
1005 cast_op);
1006 return success();
1007 }
1008 };
1009
1010 void BroadcastToOp::getCanonicalizationPatterns(RewritePatternSet &results,
1011 MLIRContext *context) {
1012 results.add<ConvertBroadcastToReshape>(context);
1013 }
1014
1015 //===----------------------------------------------------------------------===//
1016 // FullyConnectedOp
1017 //===----------------------------------------------------------------------===//
1018
1019 LogicalResult FullyConnectedOp::verify() {
1020 FullyConnectedOp op = *this;
1021 ShapedType input_type = op.input().getType().cast<ShapedType>();
1022 ShapedType filter_type = op.filter().getType().cast<ShapedType>();
1023 if (filter_type.hasRank() && filter_type.getRank() != 2) {
1024 return op.emitOpError("expect 2d filter, got ") << filter_type;
1025 }
1026
1027 if (!input_type.hasStaticShape() || !filter_type.hasStaticShape()) {
1028 return mlir::success();
1029 }
1030
1031 // The input's number of elements must be a multiple of the filter's z_in dimension.
1032 const int z_in = filter_type.getDimSize(1);
1033 const int num_input_elements = input_type.getNumElements();
1034 if (z_in != 0 && num_input_elements % z_in != 0) {
1035 return op.emitOpError(llvm::formatv(
1036 "expect 'input' num_elements % {0} == 0, got input type ", z_in))
1037 << input_type;
1038 }
1039
1040 // TODO(jpienaar): Include more shape verification for SHUFFLED4x16INT8
1041 // format.
1042 if (op.weights_format() == "DEFAULT") {
1043 ShapedType output_type =
1044 (*op.output().begin()).getType().cast<ShapedType>();
1045 if (!output_type.hasStaticShape()) {
1046 return mlir::success();
1047 }
1048
1049 const int num_output_elements = output_type.getNumElements();
1050 const int z_out = filter_type.getDimSize(0);
1051 if (num_output_elements % z_out != 0) {
1052 return op.emitOpError(llvm::formatv(
1053 "expect 'output' num_elements % {0} == 0, got ", z_out))
1054 << output_type;
1055 }
1056
1057 if (z_in != 0 && num_input_elements / z_in != num_output_elements / z_out) {
1058 return op.emitOpError(
1059 "num_input_elements / z_in != num_output_elements / z_out");
1060 }
1061 }
1062
1063 return mlir::success();
1064 }
1065
1066 LogicalResult FullyConnectedOp::fold(ArrayRef<Attribute> operands,
1067 SmallVectorImpl<OpFoldResult> &results) {
1068 assert(operands.size() == 3);
1069
1070 // Folding not implemented with any activation function or any weight type
1071 // besides the default.
1072 if (fused_activation_function() != "NONE") return failure();
1073 if (weights_format() != "DEFAULT") return failure();
1074
1075 // Bias tensor is optional.
1076 const bool has_bias = !(!bias() || bias().getType().isa<NoneType>());
1077
1078 // Get the tensors.
1079 DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
1080 if (!matchPattern(input(), m_Constant(&input_tensor)) ||
1081 !matchPattern(filter(), m_Constant(&weights_tensor)) ||
1082 (has_bias && !matchPattern(bias(), m_Constant(&bias_tensor)))) {
1083 return failure();
1084 }
1085
1086 // Get the tensor types.
1087 const auto input_type = input_tensor.getType().cast<ShapedType>();
1088 const auto weights_type = weights_tensor.getType().cast<ShapedType>();
1089 const auto bias_type =
1090 has_bias ? bias_tensor.getType().cast<ShapedType>() : ShapedType{};
1091
1092 const auto output_type = getType(0).cast<ShapedType>();
1093
1094 // Folding only implemented for float tensors.
1095 if (!input_type.getElementType().isF32() ||
1096 !weights_type.getElementType().isF32() ||
1097 !output_type.getElementType().isF32() ||
1098 (has_bias && !bias_type.getElementType().isF32())) {
1099 return failure();
1100 }
1101
1102 // Folding only implemented for static shapes
1103 if (!input_type.hasStaticShape() || !weights_type.hasStaticShape() ||
1104 (has_bias && !bias_type.hasStaticShape())) {
1105 return failure();
1106 }
1107
1108 // Folding only implemented for 1D input, 2D weights and 1D bias
1109 if (input_type.getShape().size() != 1 ||
1110 weights_type.getShape().size() != 2 ||
1111 (has_bias && bias_type.getShape().size() != 1)) {
1112 return failure();
1113 }
1114
1115 // Get the sizes
1116 const auto input_size = input_type.getNumElements();
1117 const auto output_size = output_type.getNumElements();
1118
1119 // Get iterators to the tensors.
1120 const auto input_values_it = input_tensor.getValues<float>().begin();
1121 const auto weights_values_ptr = weights_tensor.getValues<float>().begin();
1122 auto weights_row_it = weights_values_ptr;
1123 // The 'else' case could be nullptr, but the types don't match.
1124 auto bias_values_it =
1125 has_bias ? bias_tensor.getValues<float>().begin() : input_values_it;
1126
1127 // Do the actual folding, one output at a time.
1128 std::vector<float> result_values;
1129 result_values.reserve(output_size);
1130
1131 for (int i = 0; i < output_size; ++i) {
1132 // Dot product with Kahan/Neumaier summation to minimize numeric errors.
1133 float sum = has_bias ? *bias_values_it : 0.0f;
1134 float compensation = 0.0f;
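// 'compensation' accumulates the low-order bits lost when 'addend' is added to
// 'sum'; adding it back after the loop recovers most of the rounding error.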
1135 for (int j = 0; j < input_size; ++j) {
1136 const float addend = input_values_it[j] * weights_row_it[j];
1137 const float new_sum = sum + addend;
1138 // DO NOT enable -funsafe-math-optimizations here.
1139 // There is a test detecting unsafe optimizations.
1140 // Unsafe math optimizations can reorder float formulas, and set the
1141 // compensation to constant 0. The formula must be evaluated as written
1142 // for the algorithm to work.
1143 // (Note: -ffast-math is a superset of -funsafe-math-optimizations.)
1144 if (std::abs(sum) >= std::abs(addend)) {
1145 compensation += (sum - new_sum) + addend;
1146 } else {
1147 compensation += (addend - new_sum) + sum;
1148 }
1149 sum = new_sum;
1150 }
1151 result_values.push_back(sum + compensation);
1152 weights_row_it += input_size;
1153 bias_values_it++;
1154 }
1155
1156 // Set result tensor
1157 const auto folded =
1158 DenseElementsAttr::get(output_type, ArrayRef<float>(result_values));
1159 results.assign({folded});
1160
1161 return success();
1162 }
1163
1164 void FullyConnectedOp::getCanonicalizationPatterns(RewritePatternSet &results,
1165 MLIRContext *context) {
1166 results.add<RemoveOptionalZeroBias<FullyConnectedOp>>(context);
1167 }
1168
1169 int64_t FullyConnectedOp::GetArithmeticCount(Operation *op) {
1170 int64_t count;
1171 if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
1172 op, &count))
1173 return count;
1174
1175 return -1;
1176 }
1177
1178 //===----------------------------------------------------------------------===//
1179 // Conv2DOp
1180 //===----------------------------------------------------------------------===//
1181
1182 void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
1183 MLIRContext *context) {
1184 // TODO(b/180121750): Enable the pattern after the integration tests are
1185 // fixed.
1186 // results.add<RemoveOptionalZeroBias<Conv2DOp>>(context);
1187 }
1188
1189 static LogicalResult ComputeConvWindowedOutputSize(
1190 int64_t input_size, int64_t filter_size, int64_t dilation_rate,
1191 int64_t stride, tensorflow::Padding padding, int64_t *output_size) {
1192 int64_t pad_low;
1193 int64_t pad_high;
1194
1195 tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1196 input_size, filter_size, dilation_rate, stride, padding, output_size,
1197 &pad_low, &pad_high);
1198 // Return failure if expected_output_size could not be calculated.
1199 if (!status.ok()) return failure();
1200 return success();
1201 }
1202
1203 LogicalResult Conv2DOp::inferReturnTypes(
1204 MLIRContext *, Optional<Location> location, ValueRange operands,
1205 DictionaryAttr attr, RegionRange,
1206 SmallVectorImpl<Type> &inferredReturnTypes) {
1207 Conv2DOpAdaptor op(operands, attr);
1208
1209 const Value input = op.input();
1210 const Value filter = op.filter();
1211
1212 const RankedTensorType input_ty =
1213 input.getType().dyn_cast_or_null<RankedTensorType>();
1214 const RankedTensorType filter_ty =
1215 filter.getType().dyn_cast_or_null<RankedTensorType>();
1216 // If both the input type and the filter type are ranked, we need to check
1217 // that their ranks are valid.
1218 if ((input_ty && input_ty.hasRank() && input_ty.getRank() != 4) ||
1219 (filter_ty && filter_ty.hasRank() && filter_ty.getRank() != 4)) {
1220 return emitOptionalError(location, "Invalid ranks");
1221 }
1222
1223 // If either input or filter is unranked, we will just return unranked output
1224 // shape.
1225 if (!input_ty || !filter_ty || !input_ty.hasRank() || !filter_ty.hasRank()) {
1226 Type result_type;
1227 result_type = UnrankedTensorType::get(
1228 input.getType().cast<ShapedType>().getElementType());
1229 inferredReturnTypes.assign({result_type});
1230 return success();
1231 }
1232
1233 auto stride_h = op.stride_hAttr().getInt();
1234 auto stride_w = op.stride_wAttr().getInt();
1235 auto dilation_h = op.dilation_h_factorAttr().getInt();
1236 auto dilation_w = op.dilation_w_factorAttr().getInt();
1237
1238 // We don't have EXPLICIT PADDING in TFLite.
1239 auto paddings = op.padding();
1240 tensorflow::Padding padding;
1241 auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1242 if (!padding_is_valid.ok()) {
1243 return emitOptionalError(location, "invalid padding format provided");
1244 }
1245
1246 // The output always has rank 4. All dimensions are initialized to
1247 // dynamic size and can be partially inferred.
1248 // TFL's conv2d is always NHWC format & the filter is OHWI.
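// For example, a 1x32x32x3 input with a 16x3x3x3 filter, stride 1, dilation 1
// and 'VALID' padding infers a 1x30x30x16 output.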
1249 SmallVector<int64_t, 4> return_shape(4, ShapedType::kDynamicSize);
1250 return_shape[0] = input_ty.getDimSize(0);
1251 return_shape[3] = filter_ty.getDimSize(0);
1252
1253 // Spatial dimensions can be inferred only when both input and filter are
1254 // ranked because we need to get their spatial dimensions.
1255
1256 // Height.
1257 if (!input_ty.isDynamicDim(1) && !filter_ty.isDynamicDim(1)) {
1258 int64_t output_height;
1259 if (failed(ComputeConvWindowedOutputSize(
1260 input_ty.getDimSize(1), filter_ty.getDimSize(1), dilation_h,
1261 stride_h, padding, &output_height))) {
1262 return failure();
1263 }
1264 return_shape[1] = output_height;
1265 }
1266
1267 // Width.
1268 if (!input_ty.isDynamicDim(2) && !filter_ty.isDynamicDim(2)) {
1269 int64_t output_width;
1270 if (failed(ComputeConvWindowedOutputSize(
1271 input_ty.getDimSize(2), filter_ty.getDimSize(2), dilation_w,
1272 stride_w, padding, &output_width))) {
1273 return failure();
1274 }
1275 return_shape[2] = output_width;
1276 }
1277
1278 auto result_type =
1279 mlir::RankedTensorType::get(return_shape, input_ty.getElementType());
1280
1281 inferredReturnTypes.assign({result_type});
1282 return success();
1283 }
1284
1285 bool Conv2DOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
1286 if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
1287 if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
1288 return true;
1289 }
1290
1291 int64_t Conv2DOp::GetArithmeticCount(Operation *op) {
1292 int64_t count;
1293 if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
1294 op, &count))
1295 return count;
1296
1297 return -1;
1298 }
1299
1300 //===----------------------------------------------------------------------===//
1301 // DepthwiseConv2DOp
1302 //===----------------------------------------------------------------------===//
1303
1304 void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
1305 MLIRContext *context) {
1306 // TODO(b/180121750): Enable the pattern after the integration tests are
1307 // fixed.
1308 // results.add<RemoveOptionalZeroBias<DepthwiseConv2DOp>>(context);
1309 }
1310
1311 int64_t DepthwiseConv2DOp::GetArithmeticCount(Operation *op) {
1312 int64_t count;
1313 if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
1314 op, &count))
1315 return count;
1316
1317 return -1;
1318 }
1319
1320 //===----------------------------------------------------------------------===//
1321 // GatherOp
1322 //===----------------------------------------------------------------------===//
1323
1324 static void BuildGatherOp(OpBuilder *builder, OperationState &result,
1325 Value params, Value indices, IntegerAttr axis,
1326 IntegerAttr batch_dims) {
1327 auto params_type = params.getType().cast<TensorType>();
1328 auto indices_type = indices.getType().cast<TensorType>();
1329
1330 // If params/indices is unranked, then output is unranked.
1331 if (!params_type.hasRank() || !indices_type.hasRank())
1332 return TFL::GatherOp::build(
1333 *builder, result, UnrankedTensorType::get(params_type.getElementType()),
1334 params, indices, axis, batch_dims);
1335
1336 int64_t params_rank = params_type.getRank();
1337 int64_t indices_rank = indices_type.getRank();
1338
1339 // params rank is guaranteed to be at least 1.
1340 // Produces an output tensor with shape:
1341 // params.shape[:axis] + indices.shape + params.shape[axis + 1:]
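// For example, gathering from params of shape [4, 5, 6] with indices of shape
// [2, 3] along axis = 1 (batch_dims = 0) produces an output of shape [4, 2, 3, 6].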
1342 std::vector<int64_t> shape(params_type.getShape());
1343 int64_t axis_i = axis.getInt();
1344
1345 // For neg axis values, we wrap around params, e.g. axis = -1 => params[:-1]
1346 if (axis_i < 0) {
1347 axis_i += params_rank;
1348 }
1349
1350 // params must be at least rank axis + 1
1351 if (params_rank < axis_i + 1) {
1352 emitError(result.location, "params must be at least rank axis + 1");
1353 }
1354
1355 int64_t batch_dims_i = batch_dims.getInt();
1356 if (batch_dims_i < 0) {
1357 batch_dims_i += indices_rank;
1358 }
1359
1360 if (batch_dims_i > axis_i) {
1361 emitError(result.location,
1362 "axis should be bigger than or equal to batch_dims");
1363 }
1364 if (batch_dims_i >= params_rank || batch_dims_i > indices_rank) {
1365 emitError(result.location,
1366 "batch_dims must be smaller than params' rank and smaller than "
1367 "or equal to indices'rank");
1368 }
1369 for (int i = 0; i < batch_dims_i; ++i) {
1370 if (indices_type.getShape()[i] != params_type.getShape()[i]) {
1371 emitError(result.location,
1372 "batch dimensions of params must be equal to batch dimensions "
1373 "of indices");
1374 }
1375 }
1376
1377 if ((indices_rank == 0) || (indices_rank == batch_dims_i)) {
1378 // Scalar indices (output is rank(params) - 1).
1379 // Erase shape[axis]
1380 shape.erase(shape.begin() + axis_i);
1381 } else if (indices_rank == 1) {
1382 // Vector indices (output is rank(params)).
1383 // Copy indices.shape into params.shape[axis]
1384 std::copy(std::begin(indices_type.getShape()),
1385 std::end(indices_type.getShape()), std::begin(shape) + axis_i);
1386 } else {
1387 // Higher rank indices (output is rank(params) + rank(indices) - 1).
1388 shape.resize(params_rank + indices_rank - 1 - batch_dims_i);
1389 // Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
1390 std::copy(std::begin(params_type.getShape()) + axis_i + 1,
1391 std::end(params_type.getShape()),
1392 std::begin(shape) + axis_i + indices_rank - batch_dims_i);
1393
1394 // Copy indices.shape into params.shape[axis]
1395 std::copy(std::begin(indices_type.getShape()) + batch_dims_i,
1396 std::end(indices_type.getShape()), std::begin(shape) + axis_i);
1397 }
1398
1399 TFL::GatherOp::build(
1400 *builder, result,
1401 RankedTensorType::get(shape, params_type.getElementType()), params,
1402 indices, axis, batch_dims);
1403 }
1404
1405 //===----------------------------------------------------------------------===//
1406 // ScatterNdOp
1407 //===----------------------------------------------------------------------===//
1408
1409 mlir::LogicalResult ScatterNdOp::verify() {
1410 ScatterNdOp op = *this;
1411 auto indices = op.indices();
1412 auto updates = op.updates();
1413 auto shape = op.shape();
1414 auto output = op.output();
1415
1416 auto updates_type = updates.getType().cast<ShapedType>();
1417 auto indices_type = indices.getType().cast<ShapedType>();
1418
1419 if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) {
1420 return success();
1421 }
1422
1423 // Checks if the shape of `updates` is a tensor of shape
1424 // `indices.shape[:-1] + shape[indices.shape[-1]:]`, as described in
1425 // ScatterNd op description.
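  // For illustration (hypothetical shapes): with indices : tensor<4x2xi32> and
  // shape = [8, 8, 3], indices.shape[:-1] = [4] and shape[2:] = [3], so
  // `updates` is expected to be, e.g., a tensor<4x3xf32>.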
1426
1427 auto outer_dims = indices_type.getRank() - 1;
1428 auto outermost_dim = indices_type.getDimSize(outer_dims);
1429 // Checks whether the first `outer_dims` dimensions of `indices` and
1430 // `updates` are equal.
1431 for (auto i = 0; i < outer_dims; i++) {
1432 if (indices_type.getDimSize(i) != updates_type.getDimSize(i)) {
1433 return op.emitOpError()
1434 << "indices.Dims(" << i << ") == " << indices_type.getDimSize(i)
1435 << ", but updates.Dims(" << i
1436 << ") == " << updates_type.getDimSize(i);
1437 }
1438 }
1439
1440 auto output_type = output.getType().cast<ShapedType>();
1441 auto shape_type = shape.getType().cast<ShapedType>();
1442 if (shape_type.hasStaticShape()) {
1443 // Check the rank of `shape`.
1444 auto output_rank = outermost_dim + updates_type.getRank() - outer_dims;
1445 if (shape_type.getDimSize(0) != output_rank) {
1446 return op.emitOpError()
1447 << "shape must be a vector of length " << output_rank;
1448 }
1449 if (output_type.hasRank()) {
1450 if (output_type.getRank() != output_rank) {
1451 return op.emitOpError()
1452 << "output must have the same rank with the length of shape = "
1453 << output_rank;
1454 }
1455 }
1456 }
1457
1458 DenseIntElementsAttr shape_value;
1459 if (matchPattern(shape, m_Constant(&shape_value))) {
1460 for (const auto shape_elem : shape_value) {
1461 if (shape_elem.getSExtValue() <= 0) {
1462 return op.emitOpError("all elements of shape must be > 0");
1463 }
1464 }
1465
1466 // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
1467 // dimensions of `updates` and `shape` are equal.
1468 for (const auto &shape_it : llvm::enumerate(shape_value)) {
1469 int64_t i = shape_it.index();
1470 auto value = shape_it.value().getSExtValue();
1471 if (i >= outermost_dim) {
1472 auto corresponding_dim = i - outermost_dim + outer_dims;
1473 if (value != updates_type.getDimSize(corresponding_dim)) {
1474 return op.emitOpError()
1475 << "updates.Dims(" << i
1476 << ") == " << updates_type.getDimSize(corresponding_dim)
1477 << ", but shape[" << i << "] == " << value;
1478 }
1479 }
1480 }
1481
1482 // Checks if the output has the shape specified by `shape`.
1483 if (output_type.hasStaticShape()) {
1484 for (const auto &shape_it : llvm::enumerate(shape_value)) {
1485 int i = shape_it.index();
1486 auto value = shape_it.value().getSExtValue();
1487 if (output_type.getDimSize(i) != value) {
1488 return op.emitOpError()
1489 << "output shape [" << output_type.getShape()
1490 << "] must be equal to the value of shape " << shape_value;
1491 }
1492 }
1493 }
1494 }
1495 return success();
1496 }
1497
1498 //===----------------------------------------------------------------------===//
1499 // MulOp
1500 //===----------------------------------------------------------------------===//
1501
1502 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1503 // TODO(b/142478136): Handle fused ops.
1504 if (fused_activation_function() != "NONE") return {};
1505
1506 // This function is performance critical for op fusion patterns, e.g.
1507 // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d.
1508 // So a few specializations are provided to evaluate the math operation
1509 // more efficiently.
1510
1511 // Specialization for f32 type.
1512 if (getType().cast<ShapedType>().getElementType().isF32()) {
1513 return ConstFoldBinaryOp<FloatAttr, float>(
1514 getType(), operands[0], operands[1],
1515 [](float a, float b) { return a * b; });
1516 }
1517
1518 // Specialization for bf16 type.
1519 if (getType().cast<ShapedType>().getElementType().isBF16()) {
1520 return ConstFoldBinaryOp<FloatAttr, Eigen::bfloat16>(
1521 getType(), operands[0], operands[1],
1522 [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; });
1523 }
1524
1525 // Generic fallback with APFloat
1526 return ConstFoldBinaryOp(
1527 getType(), operands, [](APFloat a, APFloat b) { return a * b; },
1528 [](APInt a, APInt b) { return a * b; });
1529 }
1530
1531 int64_t MulOp::GetArithmeticCount(Operation *op) {
1532 int64_t count;
1533 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1534
1535 return -1;
1536 }
1537
1538 //===----------------------------------------------------------------------===//
1539 // DivOp
1540 //===----------------------------------------------------------------------===//
1541
1542 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1543 // TODO(b/142478136): Handle fused ops.
1544 if (fused_activation_function() != "NONE") return {};
1545 return ConstFoldBinaryOp(
1546 getType(), operands, [](APFloat a, APFloat b) { return a / b; },
1547 [](APInt a, APInt b) { return a.sdiv(b); });
1548 }
1549
1550 int64_t DivOp::GetArithmeticCount(Operation *op) {
1551 int64_t count;
1552 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1553
1554 return -1;
1555 }
1556
1557 //===----------------------------------------------------------------------===//
1558 // PackOp
1559 //===----------------------------------------------------------------------===//
1560
1561 // TODO(b/133486129): Implement shape inference for pack
1562
1563 mlir::LogicalResult PackOp::verify() {
1564 PackOp op = *this;
1565 // TODO(antiagainst): Implement other checks as in
1566 // tensorflow/lite/kernels/pack.cc
1567
1568 if (op.getOperation()->getNumOperands() != op.values_count())
1569 return op.emitOpError("input count should match 'values_count' attribute");
1570
1571 Value operand0 = op.getOperand(0);
1572 auto input_type = operand0.getType().cast<ShapedType>();
1573
1574 // Check axis bounds.
1575 if (input_type.hasRank()) {
1576 int32_t axis_value = op.axis();
1577 if (axis_value < 0) axis_value += input_type.getRank() + 1;
1578 if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
1579 return op.emitOpError()
1580 << "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
1581 << "got rank = " << input_type.getRank()
1582 << ", and axis = " << op.axis();
1583 }
1584
1585 // Make sure all inputs have the same shape and element type.
1586 // TODO(b/135032063): Simplify once fixed.
1587 for (Type operand_type : op.getOperandTypes()) {
1588 if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
1589 return op.emitOpError("operands should be of the same type. got ")
1590 << input_type << ", " << operand_type;
1591 }
1592
1593 return success();
1594 }
1595
1596 //===----------------------------------------------------------------------===//
1597 // PReluOp
1598 //===----------------------------------------------------------------------===//
1599
1600 mlir::LogicalResult PReluOp::verify() {
1601 PReluOp op = *this;
1602 auto input_type = op.input().getType().cast<ShapedType>();
1603 auto alpha_type = op.alpha().getType().cast<ShapedType>();
1604 auto output_type = op.output().getType().cast<ShapedType>();
1605
1606 if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
1607 if (input_type.getRank() != alpha_type.getRank() + 1) {
1608 return op.emitOpError("'alpha' should have one less rank than 'input'.");
1609 }
1610
1611 // Check if alpha is broadcastable
1612 for (int i = 0; i < alpha_type.getRank(); i++) {
1613 if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
1614 alpha_type.getDimSize(i) != 1) {
1615 return op.emitOpError(
1616 llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
1617 }
1618 }
1619 }
1620
1621 if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
1622 if (input_type.getRank() != output_type.getRank()) {
1623 return op.emitOpError("'input' and 'output' should have the same rank.");
1624 }
1625
1626 // Check if input and output shapes are same
1627 for (int i = 0; i < input_type.getRank(); i++) {
1628 if (input_type.getDimSize(i) != output_type.getDimSize(i)) {
1629 return op.emitOpError(
1630 "'input' and 'output' should have the same shape.");
1631 }
1632 }
1633 }
1634 return success();
1635 }
1636
1637 //===----------------------------------------------------------------------===//
1638 // ReshapeOp
1639 //===----------------------------------------------------------------------===//
1640
1641 namespace {
1642 // This pattern matches and merges a tfl.reshape under the following
1643 // condition:
1644 // * The input's defining op is another tfl.reshape.
1645 // TODO(antiagainst): This pattern probably should be moved to the peephole
1646 // category, after we have the infra for peephole passes.
1647 struct RemoveAdjacentReshape : public RewritePattern {
1648   explicit RemoveAdjacentReshape(MLIRContext *context)
1649 : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}
1650
1651   LogicalResult match(Operation *op) const override {
1652 auto thisOp = cast<ReshapeOp>(op);
1653 auto prevOp = thisOp.getOperand(0).getDefiningOp();
1654 return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
1655 }
1656
1657   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
1658 auto thisOp = cast<ReshapeOp>(op);
1659 auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
1660
1661 // Replace
1662 // %1 = "tfl.reshape"(%0, %shape0)
1663 // %2 = "tfl.reshape"(%1, %shape1)
1664 // With
1665 // %2 = "tfl.reshape"(%0, %shape1)
1666 rewriter.replaceOpWithNewOp<ReshapeOp>(
1667 op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
1668 }
1669 };
1670
1671 // The kernel expects a 1-D tensor for the shape operand if it is present. If all
1672 // the dimensions are '1's except the last dimension, it will be reshaped to a
1673 // 1-D tensor.
1674 // Note that this pattern doesn't check or change the content of the shape
1675 // tensor.
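// For example (illustrative): a constant shape operand of type tensor<1x1x3xi32>
// would be rewritten to an equivalent constant of type tensor<3xi32> with the
// same values.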
1676 struct ConvertShapeTo1D : public OpRewritePattern<ReshapeOp> {
1677 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1678
1679   LogicalResult matchAndRewrite(ReshapeOp reshape,
1680 PatternRewriter &rewriter) const override {
1681 if (!reshape.shape().hasOneUse()) return failure();
1682
1683 DenseIntElementsAttr shape;
1684 if (!matchPattern(reshape.shape(), m_Constant(&shape))) {
1685 return failure();
1686 }
1687 // It is already a 1-D constant, no change.
1688 auto old_shape = shape.getType().getShape();
1689 if (old_shape.size() == 1) {
1690 return failure();
1691 }
1692 // Verify all the leading dimensions are length one, except the last one.
1693 for (auto it = ++old_shape.rbegin(); it != old_shape.rend(); ++it) {
1694 if (*it != 1) {
1695 reshape->emitError(
1696 "Non-vector shape input is used, might cause runtime error");
1697 return failure();
1698 }
1699 }
1700 auto new_shape = shape.reshape(RankedTensorType::get(
1701 {*old_shape.rbegin()}, shape.getType().getElementType()));
1702 rewriter.replaceOpWithNewOp<TFL::ConstOp>(reshape.shape().getDefiningOp(),
1703 new_shape);
1704 return success();
1705 }
1706 };
1707
1708 bool InputOutputHasSameShape(mlir::Type input_type, mlir::Type output_type) {
1709 auto input_shaped_type = input_type.dyn_cast_or_null<ShapedType>();
1710 if (!input_shaped_type || !input_shaped_type.hasStaticShape()) return false;
1711
1712 auto output_shaped_type = output_type.dyn_cast_or_null<ShapedType>();
1713 if (!output_shaped_type || !output_shaped_type.hasStaticShape()) return false;
1714
1715 return input_shaped_type == output_shaped_type;
1716 }
1717
1718 } // end anonymous namespace
1719
1720 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
1721 // Remove identity reshape with both static result and input shape.
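  // E.g. (illustrative): reshaping a tensor<2x3xf32> value to tensor<2x3xf32>
  // folds to the input value itself.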
1722 auto result_type = getType().cast<ShapedType>();
1723 auto input_type = getOperand(0).getType().cast<ShapedType>();
1724 if (InputOutputHasSameShape(input_type, result_type)) return input();
1725
1726 // Constant folding
1727 if (auto dense_elements = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
1728 // If the result type isn't static, tries to derive the result type from
1729 // the #2 operand.
1730 if (!result_type.hasStaticShape()) {
1731 auto shape_elements = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1732 if (!shape_elements) return nullptr;
1733
1734 SmallVector<int64_t, 4> shape_data;
1735 for (const auto &it : shape_elements.getValues<APInt>()) {
1736 shape_data.push_back(it.getSExtValue());
1737 }
1738 result_type =
1739 RankedTensorType::get(shape_data, input_type.getElementType());
1740 }
1741 return dense_elements.reshape(result_type);
1742 }
1743
1744 return nullptr;
1745 }
1746
1747 void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1748 MLIRContext *context) {
1749 results.add<RemoveAdjacentReshape, ConvertShapeTo1D>(context);
1750 }
1751
1752 using ReshapeErrorHandler =
1753 llvm::function_ref<LogicalResult(const llvm::Twine &)>;
1754
1755 LogicalResult GetReshapeOutputType(Value input, Value shape,
1756 ReshapeErrorHandler error_handler,
1757 TensorType &output_ty) {
1758 auto input_ty = input.getType().cast<TensorType>();
1759 auto element_ty = input_ty.getElementType();
1760 output_ty = UnrankedTensorType::get(element_ty);
1761
1762 auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
1763 if (!shape_ty) return success();
1764 if (shape_ty.getRank() != 1)
1765 return error_handler(llvm::formatv(
1766 "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
1767
1768 DenseIntElementsAttr shape_attr;
1769 if (!matchPattern(shape, m_Constant(&shape_attr))) {
1770     // If only the shape of `shape` is known, return ranked but dynamic output
1771 // shape.
1772 if (shape_ty.hasStaticShape()) {
1773 llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
1774 ShapedType::kDynamicSize);
1775 output_ty = RankedTensorType::get(dynamic_shape, element_ty);
1776 }
1777 return success();
1778 }
1779
1780 // Detect if reshape output shape is folded.
1781 bool shape_ty_zero_dim = false;
1782 int unknown_index = -1;
1783 // The product of constant shape argument excluding unknown dimension.
1784 int64_t shape_ty_size = 1;
1785 llvm::SmallVector<int64_t, 8> output_ty_shape;
1786 output_ty_shape.reserve(shape_attr.getNumElements());
1787 for (const auto &dim : llvm::enumerate(shape_attr.getValues<APInt>())) {
1788 const int64_t size = dim.value().getSExtValue();
1789 if (size == ShapedType::kDynamicSize) {
1790 if (unknown_index != -1)
1791 return error_handler(llvm::formatv(
1792 "requires 'shape' to have at most one dynamic dimension, but got "
1793 "multiple dynamic dimensions at indices {0} and {1}. You need to "
1794           "set up the unspecified size(s) to avoid this problem, for example, "
1795 "setting batch size in keras model or setting unspecified input "
1796 "size(s) with fixed ones.",
1797 unknown_index, dim.index()));
1798
1799 unknown_index = dim.index();
1800 } else if (size == 0) {
1801 shape_ty_zero_dim = true;
1802 } else if (size > 0) {
1803 shape_ty_size *= size;
1804 } else {
1805 return error_handler(
1806 llvm::formatv("requires 'shape' to have dimensions greater than -1, "
1807 "but got {0} at index {1}",
1808 size, dim.index()));
1809 }
1810 output_ty_shape.push_back(size);
1811 }
1812
1813 if (!input_ty.hasStaticShape()) {
1814 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1815 return success();
1816 }
1817
1818 // Compute the value of the unknown dimension.
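  // E.g. (illustrative): a static input with 24 elements and shape = [2, -1, 3]
  // gives shape_ty_size = 6 and missing_dim = 24 / 6 = 4, so the inferred
  // output shape is [2, 4, 3].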
1819 if (unknown_index != -1) {
1820 // Compute number of elements in tensor shape.
1821 int64_t input_ty_size = 1;
1822 bool input_ty_zero_dim = false;
1823 for (const auto &dim : input_ty.getShape()) {
1824 if (dim > 0 || !shape_ty_zero_dim) {
1825 input_ty_size *= dim;
1826 } else {
1827 input_ty_zero_dim = true;
1828 }
1829 }
1830
1831 const int64_t missing_dim = input_ty_size / shape_ty_size;
1832 if (!input_ty_zero_dim && shape_ty_size * missing_dim != input_ty_size)
1833 return error_handler(
1834 llvm::formatv("requires 'input' number of elements be a multiple of "
1835 "{0}, but got {1}",
1836 shape_ty_size, input_ty_size));
1837
1838 // Set the unknown dimension such that total number of elements remain
1839 // constant.
1840 output_ty_shape[unknown_index] = missing_dim;
1841 }
1842
1843 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1844
1845 return success();
1846 }
1847
1848 mlir::LogicalResult ReshapeOp::verify() {
1849 ReshapeOp op = *this;
1850 auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
1851 return op.emitOpError() << message;
1852 };
1853 TensorType expected_ty;
1854 if (failed(GetReshapeOutputType(op.input(), op.shape(), error_handler,
1855 expected_ty)))
1856 return failure();
1857
1858 auto output_ty = op.getType().dyn_cast<RankedTensorType>();
1859 if (!output_ty) return success();
1860 auto input_ty = op.input().getType().cast<TensorType>();
1861 if (output_ty.hasStaticShape() && input_ty.hasStaticShape()) {
1862 const int64_t output_ty_size = output_ty.getNumElements();
1863 const int64_t input_ty_size = input_ty.getNumElements();
1864 if (input_ty_size != output_ty_size)
1865 return op.emitOpError() << "requires 'output' number of elements to "
1866 "match 'input' number of elements, but got "
1867 << output_ty_size << " and " << input_ty_size;
1868 }
1869
1870 if (!TF::AreCastCompatible({output_ty, expected_ty}))
1871 return op.emitOpError()
1872 << "requires 'output' type " << output_ty
1873 << " to be cast compatible with expected type " << expected_ty;
1874
1875 return success();
1876 }
1877
1878 LogicalResult ReshapeOp::inferReturnTypes(
1879 MLIRContext *context, Optional<Location> location, ValueRange operands,
1880 DictionaryAttr attr, RegionRange,
1881 SmallVectorImpl<Type> &inferredReturnTypes) {
1882 ReshapeOpAdaptor op(operands, attr);
1883 const Value input = op.input();
1884 const Value shape = op.shape();
1885
1886 auto error_handler = [&](const llvm::Twine &message) -> LogicalResult {
1887 // A dummy error handler.
1888 // Errors when computing the output shape will be raised in
1889 // ReshapeOp::verify call.
1890 return failure();
1891 };
1892 TensorType output_type;
1893 if (GetReshapeOutputType(input, shape, error_handler, output_type)
1894 .succeeded()) {
1895 inferredReturnTypes.assign({output_type});
1896 return success();
1897 }
1898 Type result_type;
1899 result_type = UnrankedTensorType::get(
1900 input.getType().cast<ShapedType>().getElementType());
1901 inferredReturnTypes.assign({result_type});
1902 return success();
1903 }
1904
1905 bool ReshapeOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
1906 if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
1907 if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
1908 return true;
1909 }
1910
1911 //===----------------------------------------------------------------------===//
1912 // PackOp
1913 //===----------------------------------------------------------------------===//
1914
1915 // Remove a redundant unpack-pack pair.
1916 // If an unpack op is followed by a pack op, we can remove the pack op; if the
1917 // unpack op is only consumed by the pack op, it will be removed as well.
1918 // An example illustration is:
1919 // Unpack [5, 8, 9], axis = 1
1920 // / \
1921 // value ... value [5, 9], 8 values in total
1922 // \ /
1923 // pack, axis = 1
1924 // |
1925 // value [5, 8, 9]
1926 //
1927 // This can actually be simplified into just:
1928 //
1929 // => Value [5, 8, 9]
1930 // TODO(b/133341698): Move to tablegen when variadic is supported.
1931 struct RemoveRedundantUnpackPack : public RewritePattern {
1932   explicit RemoveRedundantUnpackPack(MLIRContext *context)
1933 : RewritePattern(PackOp::getOperationName(), 2, context) {}
1934
1935   LogicalResult matchAndRewrite(Operation *op,
1936 PatternRewriter &rewriter) const override {
1937 TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1938 Operation *first_input = pack_op.getOperand(0).getDefiningOp();
1939 if (!first_input) return failure();
1940 auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
1941 if (!input_unpack_op) return failure();
1942
1943 // The unpack & pack should have the same axis & num inputs/outputs.
1944 if (pack_op.axis() != input_unpack_op.axis() ||
1945 pack_op.values_count() != input_unpack_op.num())
1946 return failure();
1947
1948 const int total_pack_inputs = pack_op.getNumOperands();
1949 const int num_results = input_unpack_op.getNumResults();
1950 if (total_pack_inputs != num_results) return failure();
1951 for (auto input_output :
1952 llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
1953 Value pack_input = std::get<0>(input_output);
1954 Value unpack_output = std::get<1>(input_output);
1955 // Make sure the ordering is the same for the pack op & unpack op.
1956 if (pack_input != unpack_output) return failure();
1957 }
1958
1959 // Replace the pack's output to the unpack's input.
1960 rewriter.replaceOp(pack_op, input_unpack_op.getOperand());
1961 // At this point, we don't manually remove the redundant pack op & unpack op
1962     // (we cannot actually), but trust the PatternRewriter to garbage collect
1963 // these two ops.
1964 return success();
1965 }
1966 };
1967
1968 // Replace PackOp with a reshape when there is only one operand.
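// E.g. (illustrative): "tfl.pack"(%arg0) {axis = 0 : i32, values_count = 1 : i32}
// : (tensor<2xf32>) -> tensor<1x2xf32> can be expressed as a tfl.reshape of
// %arg0 to tensor<1x2xf32>.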
1969 struct ReplacePackWithReshape : public RewritePattern {
1970   explicit ReplacePackWithReshape(MLIRContext *context)
1971 : RewritePattern(PackOp::getOperationName(), 2, context) {}
1972   LogicalResult matchAndRewrite(Operation *op,
1973 PatternRewriter &rewriter) const override {
1974 TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1975 if (pack_op.getNumOperands() != 1) return failure();
1976
1977 Location loc = pack_op.getLoc();
1978 auto output_type = pack_op.getType().cast<ShapedType>();
1979 if (!output_type.hasStaticShape()) return failure();
1980
1981     // This is to work around the unnecessary i64 -> i32 cast.
1982 SmallVector<int32_t, 4> new_shape_array;
1983 for (auto size : output_type.getShape()) {
1984 new_shape_array.push_back(static_cast<int32_t>(size));
1985 }
1986
1987 auto new_shape = rewriter.create<TFL::ConstOp>(
1988 loc, DenseIntElementsAttr::get(
1989 RankedTensorType::get(new_shape_array.size(),
1990 rewriter.getIntegerType(32)),
1991 new_shape_array));
1992
1993 rewriter.replaceOpWithNewOp<ReshapeOp>(op, output_type,
1994 pack_op.getOperand(0), new_shape);
1995 return success();
1996 }
1997 };
1998
1999 void PackOp::getCanonicalizationPatterns(RewritePatternSet &results,
2000 MLIRContext *context) {
2001 results.add<RemoveRedundantUnpackPack, ReplacePackWithReshape>(context);
2002 }
2003
2004 //===----------------------------------------------------------------------===//
2005 // SliceOp
2006 //===----------------------------------------------------------------------===//
2007
2008 mlir::LogicalResult SliceOp::verify() {
2009 SliceOp op = *this;
2010 auto input_type = op.input().getType().cast<ShapedType>();
2011 auto begin_type = op.begin().getType().cast<ShapedType>();
2012 auto size_type = op.size().getType().cast<ShapedType>();
2013 if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
2014 size_type.hasStaticShape()) {
2015 if (input_type.getRank() != begin_type.getNumElements()) {
2016 return op.emitError(
2017 "begin tensor elements size is not equal to input tensor rank");
2018 }
2019
2020 if (input_type.getRank() != size_type.getNumElements()) {
2021 return op.emitError(
2022 "size tensor elements size is not equal to input tensor rank");
2023 }
2024 }
2025
2026 DenseIntElementsAttr begin;
2027 if (matchPattern(op.begin(), m_Constant(&begin))) {
2028 int axis = 0;
2029 for (const auto &begin_i : llvm::enumerate(begin)) {
2030 if (begin_i.value().getSExtValue() < 0) {
2031 return op.emitError(
2032 llvm::formatv("begin[{0}] cannot be negative", axis));
2033 }
2034 axis++;
2035 }
2036 }
2037
2038 DenseIntElementsAttr size;
2039 if (matchPattern(op.size(), m_Constant(&size))) {
2040 int axis = 0;
2041 for (const auto &size_i : llvm::enumerate(size)) {
2042 if (size_i.value().getSExtValue() < -1) {
2043 return op.emitError(
2044 llvm::formatv("size[{0}] cannot be negative other than -1", axis));
2045 }
2046 axis++;
2047 }
2048 }
2049
2050 if (begin && size && input_type.hasStaticShape()) {
2051 for (uint64_t i = 0, end = begin.getNumElements(); i < end; i++) {
2052 int begin_i = begin.getValues<APInt>()[i].getSExtValue();
2053 int size_i = size.getValues<APInt>()[i].getSExtValue();
2054 int dim_i = input_type.getShape()[i];
2055 if (begin_i > dim_i) {
2056 return op.emitOpError(llvm::formatv(
2057 "begin[{0}] cannot exceed dimension length: {1}", i, dim_i));
2058 }
2059 if (size_i >= 0 && begin_i + size_i > dim_i) {
2060 return op.emitError(llvm::formatv(
2061 "begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i,
2062 dim_i));
2063 }
2064 }
2065 }
2066
2067 return success();
2068 }
2069
2070 TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
2071 RankedTensorType value_type,
2072 Location loc, OpBuilder *builder) {
2073 if (input_op == nullptr) return nullptr;
2074
2075 mlir::DenseIntElementsAttr attr;
2076 if (!matchPattern(input_op, m_Constant(&attr))) {
2077 return nullptr;
2078 }
2079
2080 auto value_shape_type = mlir::RankedTensorType::get(
2081 value_type.getShape(), builder->getIntegerType(32));
2082
2083 SmallVector<int32_t, 4> value_i32;
2084 value_i32.reserve(value_type.getRank());
2085 for (const auto &size : attr) {
2086 value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
2087 }
2088 auto new_value_i32_attr =
2089 mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);
2090
2091 return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
2092 }
2093
2094 // This casts down int64 begin/size values for the TFL slice op.
2095 // It requires that begin & size are constants.
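// E.g. (illustrative): a constant `begin` operand of type tensor<2xi64> with
// values [0, 1] is replaced by a tensor<2xi32> constant with the same values.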
2096 struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
2097 using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
2098
2099   LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
2100 PatternRewriter &rewriter) const override {
2101 auto begin = slice_op.begin();
2102 auto size = slice_op.size();
2103 auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
2104 auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
2105 auto begin_op = begin.getDefiningOp();
2106 auto size_op = size.getDefiningOp();
2107
2108 if (begin_op == nullptr && size_op == nullptr) return failure();
2109
2110 if (begin_type == nullptr && size_type == nullptr) return failure();
2111
2112 // Handle begin.
2113 if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
2114 auto new_begin = NarrowDownInt64InputValuesForOp(
2115 begin_op, begin_type, slice_op.getLoc(), &rewriter);
2116 if (new_begin != nullptr) {
2117 slice_op.setOperand(1, new_begin);
2118 }
2119 }
2120
2121 // Handle size.
2122 if (size_op && size_type && size_type.getElementType().isInteger(64)) {
2123 auto new_size = NarrowDownInt64InputValuesForOp(
2124 size_op, size_type, slice_op.getLoc(), &rewriter);
2125 if (new_size != nullptr) {
2126 slice_op.setOperand(2, new_size);
2127 }
2128 }
2129
2130 return success();
2131 }
2132 };
2133
2134 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2135 MLIRContext *context) {
2136 results.add<CastDonwInt64BeginEndToInt32>(context);
2137 }
2138
2139 //===----------------------------------------------------------------------===//
2140 // SqueezeOp
2141 //===----------------------------------------------------------------------===//
2142
2143 OpFoldResult SqueezeOp::fold(ArrayRef<Attribute> operands) {
2144 auto input_ty = input().getType().dyn_cast<RankedTensorType>();
2145 auto result_ty = getType().dyn_cast<RankedTensorType>();
2146
2147 if (!input_ty || !result_ty) return {};
2148 if (input_ty == result_ty) return input();
2149 return {};
2150 }
2151
2152 //===----------------------------------------------------------------------===//
2153 // SubOp
2154 //===----------------------------------------------------------------------===//
2155
2156 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
2157 // TODO(b/142478136): Handle fused ops.
2158 if (fused_activation_function() != "NONE") return {};
2159 return ConstFoldBinaryOp(
2160 getType(), operands, [](APFloat a, APFloat b) { return a - b; },
2161 [](APInt a, APInt b) { return a - b; });
2162 }
2163
2164 int64_t SubOp::GetArithmeticCount(Operation *op) {
2165 int64_t count;
2166 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
2167
2168 return -1;
2169 }
2170
2171 //===----------------------------------------------------------------------===//
2172 // TopKOp
2173 //===----------------------------------------------------------------------===//
2174
2175 static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input,
2176 Value k) {
2177   // Output size is only known if k is a constant value. A negative dimension is
2178   // considered dynamic, so use -1 here if k is not a constant value.
2179 int const_k = -1;
2180 ElementsAttr cst;
2181 if (matchPattern(k, m_Constant(&cst)))
2182 // These casts should all be valid due to how Tensor constants are stored.
2183 // TODO(jpienaar): This should use a helper function.
2184 const_k = cst.getValues<IntegerAttr>()[0].getValue().getSExtValue();
2185
2186 auto val_type = input.getType().cast<TensorType>();
2187   // If value is unranked, then so are the results.
2188 if (!val_type.hasRank())
2189 return TFL::TopKV2Op::build(
2190 *builder, result, UnrankedTensorType::get(val_type.getElementType()),
2191 UnrankedTensorType::get(builder->getIntegerType(32)), input, k);
2192
2193 // Resultant shape is value.shape[:-1] + [k]
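  // E.g. (illustrative): input of type tensor<3x8xf32> with constant k = 2
  // yields values of type tensor<3x2xf32> and indices of type tensor<3x2xi32>.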
2194 std::vector<int64_t> shape(val_type.getShape());
2195 shape[shape.size() - 1] = const_k;
2196 TFL::TopKV2Op::build(
2197 *builder, result, RankedTensorType::get(shape, val_type.getElementType()),
2198 RankedTensorType::get(shape, builder->getIntegerType(32)), input, k);
2199 }
2200
2201 //===----------------------------------------------------------------------===//
2202 // FakeQuantOp
2203 //===----------------------------------------------------------------------===//
2204
2205 // Returns true if the op has a non-empty "minmax" attribute.
2206 static inline bool HasValidMinMaxAttribute(Operation *op) {
2207 auto minmax = op->getAttrOfType<ArrayAttr>("minmax");
2208 return minmax && minmax.getValue().size() == 2;
2209 }
2210
2211 namespace {
2212
2213 /// This pattern matches and removes a tfl.fake_quant if the op itself and all
2214 /// of its users have the "minmax" attribute set.
2215 struct DropFakeQuant : public RewritePattern {
2216   explicit DropFakeQuant(MLIRContext *context)
2217 : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}
2218
2219   LogicalResult match(Operation *op) const override {
2220 // We only match the op with valid "minmax" attribute.
2221 if (!HasValidMinMaxAttribute(op)) return failure();
2222
2223 // If all the users of this op have valid "minmax" attributes, it is matched
2224 // and can be removed.
2225 auto fakeQuantOp = cast<FakeQuantOp>(op);
2226 for (auto *operand : fakeQuantOp.getResult().getUsers())
2227 if (!HasValidMinMaxAttribute(operand)) return failure();
2228
2229 return success();
2230 }
2231
2232   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
2233 // Replace the matched FakeQuantOp by its primary operand.
2234 rewriter.replaceOp(op, op->getOperand(0));
2235 }
2236 };
2237 } // end anonymous namespace
2238
2239 void FakeQuantOp::getCanonicalizationPatterns(RewritePatternSet &results,
2240 MLIRContext *context) {
2241 results.add<DropFakeQuant>(context);
2242 }
2243
2244 //===----------------------------------------------------------------------===//
2245 // UnpackOp
2246 //===----------------------------------------------------------------------===//
2247
2248 // TODO(b/133486129): Implement shape inference for unpack
2249
2250 LogicalResult UnpackOp::inferReturnTypes(
2251 MLIRContext *context, Optional<Location> loc, ValueRange operands,
2252 DictionaryAttr attributes, RegionRange regions,
2253 SmallVectorImpl<Type> &inferredReturnTypes) {
2254 UnpackOpAdaptor op(operands, attributes);
2255 // TODO(jpienaar): Refactor verify
2256 if (failed(op.verify(loc.has_value() ? *loc : UnknownLoc::get(context))))
2257 return failure();
2258
2259 if (operands.size() != 1) {
2260 return emitOptionalError(loc, "input count should be equal to 1");
2261 }
2262
2263 const int64_t num_value = op.numAttr().getInt();
2264 auto input_type = operands[0].getType().dyn_cast<ShapedType>();
2265 if (!input_type || !input_type.hasRank()) {
2266 // If input is unranked, then so is output.
2267 inferredReturnTypes.assign(
2268 num_value, UnrankedTensorType::get(input_type.getElementType()));
2269 return success();
2270 }
2271
2272 if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) {
2273 return emitOptionalError(
2274 loc, "number of elements in input should be larger than 0");
2275 }
2276
2277 const int64_t rank = input_type.getRank();
2278 if (rank <= 0) {
2279 return emitOptionalError(loc, "input should be of rank larger than 0");
2280 }
2281
2282 int64_t axis_value = op.axisAttr().getInt();
2283 if (axis_value < 0) {
2284 axis_value += rank;
2285 }
2286 if (axis_value < 0 || axis_value >= rank) {
2287 return emitOptionalError(
2288 loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
2289 op.axisAttr().getInt(), ", and rank = ", rank);
2290 }
2291
2292 if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
2293 input_type.getDimSize(axis_value) != num_value) {
2294 return emitOptionalError(loc, "output count should match 'num' attribute");
2295 }
2296
2297 auto output_shape = llvm::to_vector<4>(input_type.getShape());
2298 output_shape.erase(output_shape.begin() + axis_value);
2299
2300 auto output_type =
2301 RankedTensorType::get(output_shape, input_type.getElementType());
2302 inferredReturnTypes.assign(num_value, output_type);
2303
2304 return success();
2305 }
2306
2307 bool UnpackOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
2308 if (lhs.size() != rhs.size()) return false;
2309 for (auto pair : llvm::zip(lhs, rhs)) {
2310 if (failed(
2311 mlir::verifyCompatibleShape(std::get<0>(pair), std::get<1>(pair))))
2312 return false;
2313 }
2314 return true;
2315 }
2316
2317 //===----------------------------------------------------------------------===//
2318 // SplitOp
2319 //===----------------------------------------------------------------------===//
2320
2321 // Extracts and returns the signed integer constant in a 0-rank integer tensor
2322 // or 1-element 1-rank integer tensor if 'value' is a constant.
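// E.g. (illustrative): a constant of type tensor<i32> with value 3, or of type
// tensor<1xi64> with value [3], both yield 3; a non-constant value yields an
// empty Optional.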
2323 static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
2324 ElementsAttr attr;
2325 if (!matchPattern(value, m_Constant(&attr))) return {};
2326 if (attr.getNumElements() != 1) return {};
2327 IntegerAttr int_attr = *attr.getValues<IntegerAttr>().begin();
2328 return int_attr.getValue().getSExtValue();
2329 }
2330
2331 // Returns a RankedTensorType which is similar to `input_type` but replaces the
2332 // dimension size of `dim` with `dim_size`. For example,
2333 // `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
2334 // `tensor<3x2xi32>`.
2335 static RankedTensorType SubstituteRankedTensorTypeDimSize(
2336 RankedTensorType input_type, int64_t dim, int64_t dim_size) {
2337 auto shape = input_type.getShape().vec();
2338 shape[dim] = dim_size;
2339 return RankedTensorType::get(shape, input_type.getElementType());
2340 }
2341
2342 // Verifies the output tensor types of SplitOp or SplitVOp.
2343 template <typename ExpectedOutputTypeGetter>
2344 static LogicalResult VerifySplitOpOutputTypes(
2345 Operation *op, int64_t num_splits,
2346 ExpectedOutputTypeGetter get_expected_output_type) {
2347 for (int64_t i = 0; i < num_splits; ++i) {
2348 auto expected_output_type = get_expected_output_type(i);
2349 Value output = op->getResult(i);
2350 if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
2351 return op->emitOpError()
2352 << "output #" << i << " should be " << expected_output_type
2353 << " instead got " << output.getType();
2354 }
2355 return success();
2356 }
2357
2358 mlir::LogicalResult SplitOp::verify() {
2359 SplitOp op = *this;
2360 int64_t num_splits = op.num_splits();
2361 if (op.getNumResults() != num_splits)
2362 return op.emitOpError("output count should match 'num_splits' attribute");
2363
2364 // If 'split_dim' is not a constant, there are no other checks.
2365 llvm::Optional<int64_t> split_dim_opt =
2366 ExtractConstantIntFromTensor(op.split_dim());
2367 if (!split_dim_opt) return success();
2368
2369 // If 'input' is not a ranked tensor, there are no other checks.
2370 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2371 if (!input_type) return success();
2372
2373 int64_t split_dim = split_dim_opt.getValue();
2374 const int64_t rank = input_type.getRank();
2375 if (split_dim < 0) split_dim += rank;
2376 if (split_dim < 0 || split_dim >= rank)
2377 return op.emitOpError("'split_dim' should be in [-rank, rank)");
2378
2379 // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
2380 // there are no other checks.
2381 const int64_t dim_size = input_type.getDimSize(split_dim);
2382 if (ShapedType::isDynamic(dim_size)) return success();
2383
2384 if (dim_size % num_splits != 0)
2385 return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");
2386
2387 // Verifies output tensor types.
2388 RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
2389 input_type, split_dim, dim_size / num_splits);
2390 return VerifySplitOpOutputTypes(
2391 op.getOperation(), num_splits,
2392 [expected_output_type](int64_t) { return expected_output_type; });
2393 }
2394
2395 mlir::LogicalResult SplitVOp::verify() {
2396 SplitVOp op = *this;
2397 int64_t num_splits = op.num_splits();
2398 if (op.getNumResults() != num_splits)
2399 return op.emitOpError("output count should match 'num_splits' attribute");
2400
2401 // If 'split_dim' is not a constant, there are no other checks.
2402 llvm::Optional<int64_t> split_dim_opt =
2403 ExtractConstantIntFromTensor(op.split_dim());
2404 if (!split_dim_opt) return success();
2405
2406 // If 'input' is not a ranked tensor, there are no other checks.
2407 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2408 if (!input_type) return success();
2409
2410 int64_t split_dim = split_dim_opt.getValue();
2411 const int64_t rank = input_type.getRank();
2412 if (split_dim < 0) split_dim += rank;
2413 if (split_dim < 0 || split_dim >= rank)
2414 return op.emitOpError("'split_dim' should be in [-rank, rank)");
2415
2416 // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
2417 // there are no other checks.
2418 const int64_t dim_size = input_type.getDimSize(split_dim);
2419 if (ShapedType::isDynamic(dim_size)) return success();
2420
2421 // If 'size_splits' is not a constant, there are no other checks.
2422 ElementsAttr size_splits_attr;
2423 if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
2424 return success();
2425
2426 if (size_splits_attr.getNumElements() != num_splits) {
2427 auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
2428 RankedTensorType expected_size_splits_type =
2429 RankedTensorType::get({num_splits}, size_splits_type.getElementType());
2430 return op.emitOpError("'size_splits' should be ")
2431 << expected_size_splits_type;
2432 }
2433
2434 // Normalizes and verifies 'size_splits'.
2435 // Note: TensorFlow allows one -1 element in 'size_splits'. The -1 element
2436 // means the rest of the dimension size.
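  // E.g. (illustrative): for a 'split_dim' dimension of size 10 and
  // size_splits = [3, -1, 2], the -1 entry is normalized to 10 - (3 + 2) = 5.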
2437 llvm::SmallVector<int64_t, 4> size_splits;
2438 size_splits.reserve(num_splits);
2439
2440 int64_t negative_size_split_loc = -1;
2441 int64_t total_size_splits = 0;
2442
2443 for (int64_t i = 0; i < num_splits; ++i) {
2444 auto size_split_attr = size_splits_attr.getValues<IntegerAttr>()[i];
2445 int64_t size_split = size_split_attr.getValue().getSExtValue();
2446 size_splits.push_back(size_split);
2447 if (size_split >= 0) {
2448 total_size_splits += size_split;
2449 continue;
2450 }
2451 if (size_split < -1)
2452 return op.emitOpError(
2453 "elements of 'size_splits' should be greater than or equal to -1");
2454 if (negative_size_split_loc != -1)
2455 return op.emitOpError("'size_splits' can only have one -1");
2456 negative_size_split_loc = i;
2457 }
2458
2459 if (negative_size_split_loc != -1) {
2460 if (total_size_splits > dim_size)
2461 return op.emitOpError(
2462 "sum of non-negative elements of 'size_splits' is greater than the "
2463 "dimension size of 'split_dim' axis");
2464 size_splits[negative_size_split_loc] = dim_size - total_size_splits;
2465 total_size_splits = dim_size;
2466 }
2467
2468 if (total_size_splits != dim_size)
2469 return op.emitOpError(
2470 "sum of 'size_splits' should match the dimension size of 'split_dim' "
2471 "axis");
2472
2473 // Verifies result tensor types.
2474 auto get_expected_output_type = [input_type, split_dim,
2475 &size_splits](int64_t i) {
2476 return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
2477 size_splits[i]);
2478 };
2479 return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
2480 get_expected_output_type);
2481 }
2482
2483 //===----------------------------------------------------------------------===//
2484 // MeanOp
2485 //===----------------------------------------------------------------------===//
2486
2487 // TODO(b/133854225): Implement shape inference to Mean
2488
2489 //===----------------------------------------------------------------------===//
2490 // LSTMOp
2491 //===----------------------------------------------------------------------===//
2492
2493 mlir::LogicalResult LSTMOp::verify() {
2494 LSTMOp op = *this;
2495 auto operands = op.GetStatefulOperands();
2496 if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
2497 return op.emitOpError("LSTMOp expected to have two stateful operands");
2498 }
2499
2500 const auto input_type = op.input().getType().cast<ShapedType>();
2501 // Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
2502 // doesn't have static shape, we skip the shape check below.
2503 if (!input_type.hasStaticShape()) return success();
2504   // The input should be at least a 2-D tensor since it will go through a fully
2505   // connected layer.
2506   if (!input_type.hasRank() || input_type.getRank() < 2)
2507     return op.emitOpError(
2508         "the first input operand should have at least 2 dimensions.");
2509
2510 const auto activation_state =
2511 op.input_activation_state().getType().cast<ShapedType>();
2512 const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
2513 const auto input_to_output_weights =
2514 op.input_to_output_weights().getType().cast<ShapedType>();
2515 const auto recurrent_to_output_weights =
2516 op.recurrent_to_output_weights().getType().cast<ShapedType>();
2517 if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
2518 input_to_output_weights.hasStaticShape() &&
2519 recurrent_to_output_weights.hasStaticShape()) {
2520 const int n_input = input_type.getDimSize(input_type.getRank() - 1);
2521 const int n_cell = input_to_output_weights.getDimSize(0);
2522 const int n_output = recurrent_to_output_weights.getDimSize(1);
2523 const int output_state_size = activation_state.getNumElements();
2524 const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
2525 : input_type.getDimSize(1);
2526 const int state_size = cell_state.getNumElements();
2527
2528 // Check if the dimension of the inputs matches.
2529 if ((output_state_size != n_batch * n_output) ||
2530 (state_size != n_batch * n_cell) ||
2531 (input_to_output_weights.getDimSize(1) != n_input) ||
2532 (recurrent_to_output_weights.getRank() != 2) ||
2533 (recurrent_to_output_weights.getDimSize(0) != n_cell) ||
2534 (input_to_output_weights.getRank() != 2)) {
2535 return op.emitOpError("inputs don't match with the dimensions.");
2536 }
2537
2538 const bool is_layer_norm_lstm =
2539 !op.forget_layer_norm_coefficients().getType().isa<NoneType>();
2540 if (is_layer_norm_lstm) {
2541 const auto forget_layer_norm_coefficients =
2542 op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
2543 // If this lstm has layer normalization, this input value,
2544 // "forget_layer_norm_coefficients" should be a 1D tensor.
2545 if (!forget_layer_norm_coefficients.hasRank() ||
2546 forget_layer_norm_coefficients.getRank() != 1 ||
2547 forget_layer_norm_coefficients.getDimSize(0) != n_cell)
2548 return op.emitOpError(
2549 "coefficient inputs have more than 2 dimensions or "
2550 "don't match the dimension with input operand "
2551 "`input_to_output_weights`.");
2552 }
2553 }
2554
2555 return success();
2556 }
2557
2558 namespace {
2559
2560 // Replaces the optional bias operands with a "none" type value if the bias
2561 // values are constant zeros.
2562 struct RemoveLSTMOpZeroBias : public OpRewritePattern<LSTMOp> {
2563 using OpRewritePattern<LSTMOp>::OpRewritePattern;
2564
2565   LogicalResult matchAndRewrite(LSTMOp op,
2566 PatternRewriter &rewriter) const override {
2567 if (EqualsZero(op.input_gate_bias())) {
2568 auto none_value = rewriter.create<TFL::NoValueOp>(
2569 rewriter.getUnknownLoc(), rewriter.getNoneType(),
2570 rewriter.getUnitAttr());
2571 op.input_gate_biasMutable().assign(none_value);
2572 }
2573
2574 if (EqualsZero(op.projection_bias())) {
2575 auto none_value = rewriter.create<TFL::NoValueOp>(
2576 rewriter.getUnknownLoc(), rewriter.getNoneType(),
2577 rewriter.getUnitAttr());
2578 op.projection_biasMutable().assign(none_value);
2579 }
2580
2581 return success();
2582 }
2583 };
2584
2585 } // namespace
2586
2587 void LSTMOp::getCanonicalizationPatterns(RewritePatternSet &results,
2588 MLIRContext *context) {
2589 results.add<RemoveLSTMOpZeroBias>(context);
2590 }
2591
2592 //===----------------------------------------------------------------------===//
2593 // UnidirectionalSequenceLSTMOp
2594 //===----------------------------------------------------------------------===//
2595
2596 mlir::LogicalResult UnidirectionalSequenceLSTMOp::verify() {
2597 UnidirectionalSequenceLSTMOp op = *this;
2598 auto operands = op.GetStatefulOperands();
2599 if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
2600 return success();
2601 }
2602 return op.emitError(
2603 "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
2604 }
2605
2606 LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes(
2607 MLIRContext *, Optional<Location>, ValueRange operands, DictionaryAttr attr,
2608 RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
2609 Value input = operands[0];
2610 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
2611
2612 Value output_state = operands[18];
2613 auto output_state_type =
2614 output_state.getType().dyn_cast_or_null<RankedTensorType>();
2615
2616 if (input_type && input_type.hasRank() && input_type.getRank() != 3) {
2617 return failure();
2618 }
2619
2620 if (output_state_type && output_state_type.hasRank() &&
2621 output_state_type.getRank() != 2) {
2622 return failure();
2623 }
2624
2625 if (!input_type || !input_type.hasRank() || !output_state_type ||
2626 !output_state_type.hasRank()) {
2627 // We cannot infer the output shape since we don't know the input shape or
2628 // the output state shape. We will set the output shape as unranked.
2629 Type result_type;
2630 result_type = UnrankedTensorType::get(
2631 input.getType().cast<ShapedType>().getElementType());
2632 inferredReturnTypes.assign({result_type});
2633 return success();
2634 }
2635
2636 // Default to non-time_major.
2637 Optional<mlir::NamedAttribute> time_major_attr = attr.getNamed("time_major");
2638 bool time_majored =
2639 time_major_attr ? time_major_attr->getValue().cast<BoolAttr>().getValue()
2640 : false;
2641
2642 int batch =
2643 time_majored ? input_type.getDimSize(1) : input_type.getDimSize(0);
2644 int time = time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1);
2645 int n_output = output_state_type.getDimSize(1);
2646
2647 // Build the output shape.
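  // E.g. (illustrative): a non-time_major input of type tensor<1x5x4xf32> with
  // an output state of type tensor<1x8xf32> gives batch = 1, time = 5,
  // n_output = 8, and an inferred result type of tensor<1x5x8xf32>.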
2648 SmallVector<int64_t, 3> output_shape;
2649 if (time_majored) {
2650 output_shape = {time, batch, n_output};
2651 } else {
2652 output_shape = {batch, time, n_output};
2653 }
2654 auto result_type =
2655 mlir::RankedTensorType::get(output_shape, input_type.getElementType());
2656
2657 inferredReturnTypes.assign({result_type});
2658 return success();
2659 }
2660
2661 bool UnidirectionalSequenceLSTMOp::isCompatibleReturnTypes(TypeRange lhs,
2662 TypeRange rhs) {
2663 if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
2664 if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
2665 return true;
2666 }
2667
2668 //===----------------------------------------------------------------------===//
2669 // BidirectionalSequenceLSTMOp
2670 //===----------------------------------------------------------------------===//
2671
2672 mlir::LogicalResult BidirectionalSequenceLSTMOp::verify() {
2673 BidirectionalSequenceLSTMOp op = *this;
2674 auto operands = op.GetStatefulOperands();
2675 if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
2676 operands[2] == 37 && operands[3] == 38) {
2677 return success();
2678 }
2679 return op.emitError(
2680 "BidirectionalSequenceLSTMOp expected to have four stateful operands");
2681 }
2682
2683 //===----------------------------------------------------------------------===//
2684 // UnidirectionalSequenceRNNOp
2685 //===----------------------------------------------------------------------===//
2686
2687 mlir::LogicalResult UnidirectionalSequenceRNNOp::verify() {
2688 UnidirectionalSequenceRNNOp op = *this;
2689 auto operands = op.GetStatefulOperands();
2690 if (operands.size() == 1 && operands[0] == 4) {
2691 return success();
2692 }
2693 return op.emitError(
2694 "UnidirectionalSequenceRNNOp expected to have one stateful operand");
2695 }
2696
2697 //===----------------------------------------------------------------------===//
2698 // SvdfOp
2699 //===----------------------------------------------------------------------===//
2700
2701 mlir::LogicalResult SVDFOp::verify() {
2702 SVDFOp op = *this;
2703 auto operands = op.GetStatefulOperands();
2704 if (operands.size() == 1 && operands[0] == 4) {
2705 return success();
2706 }
2707 return op.emitError("SvdfOp expected to have one stateful operand");
2708 }
2709
2710 //===----------------------------------------------------------------------===//
2711 // AbsOp
2712 //===----------------------------------------------------------------------===//
2713
2714 OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
2715 Type result_type = getType();
2716 // Only constant fold for tensor of f32 is implemented.
2717 if (!IsF32ShapedType(result_type)) return nullptr;
2718
2719 auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); };
2720 return ConstFoldUnaryOp(result_type, operands[0], compute);
2721 }
2722
2723 //===----------------------------------------------------------------------===//
2724 // NegOp
2725 //===----------------------------------------------------------------------===//
2726
2727 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
2728 Type result_type = getType();
2729 // Only constant fold for tensor of f32 is implemented.
2730 if (!IsF32ShapedType(result_type)) return nullptr;
2731
2732 auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
2733 return ConstFoldUnaryOp(result_type, operands[0], compute);
2734 }
2735
2736 //===----------------------------------------------------------------------===//
2737 // SinOp
2738 //===----------------------------------------------------------------------===//
2739
2740 OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
2741 Type result_type = getType();
2742 // Only constant fold for tensor of f32 is implemented.
2743 if (!IsF32ShapedType(result_type)) return nullptr;
2744
2745 auto compute = [](APFloat value) -> APFloat {
2746 float f = value.convertToFloat();
2747 float result = std::sin(f);
2748 return APFloat(result);
2749 };
2750 return ConstFoldUnaryOp(result_type, operands[0], compute);
2751 }
2752
2753 //===----------------------------------------------------------------------===//
2754 // CosOp
2755 //===----------------------------------------------------------------------===//
2756
2757 OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
2758 Type result_type = getType();
2759 // Only constant fold for tensor of f32 is implemented.
2760 if (!IsF32ShapedType(result_type)) return nullptr;
2761
2762 auto compute = [](APFloat value) -> APFloat {
2763 float f = value.convertToFloat();
2764 float result = std::cos(f);
2765 return APFloat(result);
2766 };
2767 return ConstFoldUnaryOp(result_type, operands[0], compute);
2768 }
2769
2770 //===----------------------------------------------------------------------===//
2771 // LogOp
2772 //===----------------------------------------------------------------------===//
2773
2774 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
2775 Type result_type = getType();
2776 // Only constant fold for tensor of f32 is implemented.
2777 if (!IsF32ShapedType(result_type)) return nullptr;
2778
2779 auto compute = [](APFloat value) -> APFloat {
2780 float f = value.convertToFloat();
2781 float result = std::log(f);
2782 return APFloat(result);
2783 };
2784 return ConstFoldUnaryOp(result_type, operands[0], compute);
2785 }
2786
2787 //===----------------------------------------------------------------------===//
2788 // ShapeOp
2789 //===----------------------------------------------------------------------===//
2790
2791 OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
2792 auto input_type = input().getType().cast<ShapedType>();
2793 if (!input_type.hasStaticShape()) return nullptr;
2794
2795 ArrayRef<int64_t> shape = input_type.getShape();
2796 auto result_type = getType().cast<ShapedType>();
2797 if (result_type.getElementType().isInteger(64)) {
2798 return DenseElementsAttr::get<int64_t>(result_type, shape);
2799 } else if (result_type.getElementType().isInteger(32)) {
2800 SmallVector<int32_t, 4> shape_i32;
2801 shape_i32.reserve(shape.size());
2802 for (int64_t dim : shape) {
2803 shape_i32.push_back(dim);
2804 }
2805 return DenseElementsAttr::get<int32_t>(result_type, shape_i32);
2806 }
2807 return nullptr;
2808 }
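
// Illustrative example (not from the original source): for an input of type
// tensor<2x3xf32> and an i32 result element type, the fold above produces the
// constant dense<[2, 3]> : tensor<2xi32>; if any dimension is dynamic, the
// input has no static shape and no folding happens.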
2809
2810 //===----------------------------------------------------------------------===//
2811 // SqrtOp
2812 //===----------------------------------------------------------------------===//
2813
2814 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
2815 Type result_type = getType();
2816 // Only constant fold for tensor of f32 is implemented.
2817 if (!IsF32ShapedType(result_type)) return nullptr;
2818
2819 auto compute = [](APFloat value) -> APFloat {
2820 float f = value.convertToFloat();
2821 float result = std::sqrt(f);
2822 return APFloat(result);
2823 };
2824 return ConstFoldUnaryOp(result_type, operands[0], compute);
2825 }
2826
2827 //===----------------------------------------------------------------------===//
2828 // RsqrtOp
2829 //===----------------------------------------------------------------------===//
2830
2831 OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
2832 Type result_type = getType();
2833 // Only constant fold for tensor of f32/bf16 is implemented.
2834 if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
2835 return nullptr;
2836
2837 auto compute = [](APFloat value) -> APFloat {
2838 bool loseInfo;
2839 const llvm::fltSemantics &original_float_semantics = value.getSemantics();
2840 value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
2841 &loseInfo);
2842 float f = value.convertToFloat();
2843 APFloat result(1.f / std::sqrt(f));
2844 result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
2845 &loseInfo);
2846 return result;
2847 };
2848 return ConstFoldUnaryOp(result_type, operands[0], compute);
2849 }
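
// Illustrative example (not from the original source): an f32 input of 4.0
// folds to 0.5. For bf16 inputs, the lambda above converts the value to IEEE
// single precision, computes 1/sqrt, and converts the result back to bf16.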
2850
2851 //===----------------------------------------------------------------------===//
2852 // SquareOp
2853 //===----------------------------------------------------------------------===//
2854
2855 OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
2856 Type result_type = getType();
2857 // Only constant fold for tensor of f32 is implemented.
2858 if (!IsF32ShapedType(result_type)) return nullptr;
2859
2860 auto compute = [](APFloat value) -> APFloat { return value * value; };
2861 return ConstFoldUnaryOp(result_type, operands[0], compute);
2862 }
2863
2864 //===----------------------------------------------------------------------===//
2865 // RankOp
2866 //===----------------------------------------------------------------------===//
2867
2868 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2869 assert(operands.size() == 1);
2870 auto result_type = getType().cast<ShapedType>();
2871 if (auto elements_attr = operands[0].dyn_cast_or_null<ElementsAttr>()) {
2872 auto rank = static_cast<int32_t>(elements_attr.getType().getRank());
2873 return DenseElementsAttr::get(result_type, {rank});
2874 }
2875
2876 // Also fold if `input` has a known rank.
2877 auto input_type = input().getType().cast<ShapedType>();
2878 // Do not fold if rank is zero because the TFLite converter doesn't
2879 // distinguish between unranked input and scalar input due to b/138865275.
2880 // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
2881 // predicate and fold the op when rank is zero.
2882 if (input_type.hasRank() && input_type.getRank() != 0) {
2883 auto rank = static_cast<int32_t>(input_type.getRank());
2884 return DenseElementsAttr::get(result_type, {rank});
2885 }
2886
2887 return nullptr;
2888 }
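
// Illustrative example (not from the original source): an input of type
// tensor<?x4xf32> has a known rank of 2, so the fold above produces a constant
// holding 2 even though the shape itself is not fully static; a rank-0 input
// is deliberately not folded (see the TODO above).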
2889
2890 //===----------------------------------------------------------------------===//
2891 // ConstOp
2892 //===----------------------------------------------------------------------===//
2893
2894 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
2895 assert(operands.empty() && "constant has no operands");
2896 // Return the held attribute value.
2897 return value();
2898 }
2899
2900 bool ConstOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2901   // Allow the result type to differ from the inferred type: the inferred type
2902   // comes from the element attribute's type, while the op may have been
2903   // constructed from a TF const op or be in a partial state of shape
2904   // refinement, so only require the two to be compatible. The op will be
2905   // refined during shape inference, with casts inserted as needed to satisfy
2906   // the type constraints of consumers.
2907 return succeeded(verifyCompatibleShapes(l, r));
2908 }
2909
2910 namespace {
2911 struct FoldPseudoConstOp : public OpRewritePattern<ConstOp> {
2912 using OpRewritePattern<ConstOp>::OpRewritePattern;
2913
2914   LogicalResult matchAndRewrite(ConstOp const_op,
2915 PatternRewriter &rewriter) const override {
2916 if (arith::ConstantOp::isBuildableWith(const_op.value(),
2917 const_op.getType())) {
2918 rewriter.replaceOpWithNewOp<arith::ConstantOp>(const_op,
2919 const_op.value());
2920 return success();
2921 } else if (TFL::NoValueOp::isBuildableWith(const_op.value(),
2922 const_op.getType())) {
2923 rewriter.replaceOpWithNewOp<NoValueOp>(const_op, rewriter.getNoneType(),
2924 const_op.value().cast<UnitAttr>());
2925 return success();
2926 }
2927 return failure();
2928 }
2929 };
2930
2931 } // namespace
2932
2933 void ConstOp::getCanonicalizationPatterns(RewritePatternSet &results,
2934 MLIRContext *context) {
2935 results.add<FoldPseudoConstOp>(context);
2936 }
2937
2938 //===----------------------------------------------------------------------===//
2939 // CastOp
2940 //===----------------------------------------------------------------------===//
2941
2942 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
2943 assert(operands.size() == 1);
2944 if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
2945 return input();
2946 }
2947
2948 // For now, only supports cast between integer types.
2949 auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2950 if (!elements_attr) {
2951 return nullptr;
2952 }
2953
2954 auto result_element_type =
2955 getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
2956 auto operand_element_type = input()
2957 .getType()
2958 .cast<ShapedType>()
2959 .getElementType()
2960 .dyn_cast<IntegerType>();
2961 // Returns nullptr if either result/operand element type is not integer.
2962 if (!result_element_type || !operand_element_type) {
2963 return nullptr;
2964 }
2965
2966 const bool is_unsigned = operand_element_type.isUnsigned();
2967 const bool involves_bool = operand_element_type.getWidth() == 1 ||
2968 result_element_type.getWidth() == 1;
2969 const int output_bitwidth = result_element_type.getWidth();
2970   // The integer cast op behaves like a C integer cast. Depending on the
2971   // operand type's signedness, we determine whether or not sign extension is
2972   // needed.
2973 auto cast = [&](APInt value) {
2974 if (involves_bool) {
2975       // Handle boolean inputs or outputs explicitly, as they do not behave
2976       // like plain extension or truncation.
2977       // A true input should always be cast to 1, not to -1 as sign extension
2978       // would produce for signed outputs. Similarly, non-zero inputs should be
2979       // cast to true, whereas truncating an even number to one bit yields false.
2980 return APInt(result_element_type.getWidth(), value != 0);
2981 }
2982 return is_unsigned ? value.zextOrTrunc(output_bitwidth)
2983 : value.sextOrTrunc(output_bitwidth);
2984 };
2985
2986 return elements_attr.mapValues(result_element_type, cast);
2987 }
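
// Illustrative examples (not from the original source), assuming constant
// integer operands:
//   * i8 -1 cast to i32 sign-extends to -1; ui8 255 cast to i32 zero-extends
//     to 255.
//   * Any non-zero input cast to i1 becomes true, and a true (i1) input cast
//     to a wider signed type becomes 1 rather than -1, per the special case in
//     the lambda above.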
2988
2989 //===----------------------------------------------------------------------===//
2990 // SelectV2Op
2991 //===----------------------------------------------------------------------===//
2992
2993 static void BuildSelectV2Op(Builder *builder, OperationState &result,
2994 Value cond, Value x, Value y) {
2995 auto operand_type =
2996 OpTrait::util::getBroadcastedType(x.getType(), y.getType());
2997
2998 if (!operand_type)
2999 emitError(result.location) << "non-broadcastable operands: " << x.getType()
3000 << " and " << y.getType();
3001
3002 bool has_static_cond_shape = false;
3003 bool has_static_operand_shape = false;
3004 ArrayRef<int64_t> cond_shape;
3005 ArrayRef<int64_t> operand_shape;
3006
3007 if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
3008 if (shaped_type.hasStaticShape()) {
3009 has_static_cond_shape = true;
3010 cond_shape = shaped_type.getShape();
3011 }
3012 }
3013 if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
3014 if (shaped_type.hasStaticShape()) {
3015 has_static_operand_shape = true;
3016 operand_shape = shaped_type.getShape();
3017 }
3018 }
3019
3020 SmallVector<int64_t, 4> broadcastedShape;
3021 if (has_static_cond_shape && has_static_operand_shape &&
3022 !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
3023 broadcastedShape)) {
3024 emitError(result.location) << "non-broadcastable operands: " << operand_type
3025 << " and " << cond.getType();
3026 }
3027
3028 result.addOperands({cond, x, y});
3029
3030 auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
3031 if (has_static_cond_shape && has_static_operand_shape) {
3032 result.types.push_back(
3033 RankedTensorType::get(broadcastedShape, elementType));
3034 } else {
3035 result.types.push_back(UnrankedTensorType::get(elementType));
3036 }
3037 }
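
// Illustrative example (not from the original source): for a condition of type
// tensor<2x1xi1> and operands of type tensor<2x3xf32>, the shapes broadcast to
// [2, 3], so the builder above infers a tensor<2x3xf32> result; if any of the
// shapes is not static, the result is left unranked.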
3038
3039 //===----------------------------------------------------------------------===//
3040 // RangeOp
3041 //===----------------------------------------------------------------------===//
3042
3043 namespace {
3044
3045 // Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
3046 // Template parameter `FloatOrInt` must be a standard C integer or
3047 // floating-point type.
3048 template <typename FloatOrInt>
3049 int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
3050 // Refer to the implementation in
3051 // tensorflow/lite/kernels/range.cc.
3052 return std::is_integral<FloatOrInt>::value
3053 ? ((std::abs(limit - start) + std::abs(delta) - 1) /
3054 std::abs(delta))
3055 : std::ceil(std::abs((limit - start) / delta));
3056 }
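
// Illustrative example (not from the original source): GetLengthOfRange(0, 10,
// 3) takes the integer branch, (|10 - 0| + |3| - 1) / |3| = 4, and
// GetLengthOfRange(0.0f, 10.0f, 3.0f) evaluates std::ceil(10.0f / 3.0f) = 4.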
3057
3058 // Builds a constant range tensor of `result_elem_type` elements.
3059 // Template parameter `FloatOrIntAttr` must be mlir::IntegerAttr or
3060 // mlir::FloatAttr.
3061 template <typename FloatOrIntAttr>
3062 DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
3063                                         FloatOrIntAttr start_attr,
3064                                         FloatOrIntAttr delta_attr) {
3065   using ValueType = typename FloatOrIntAttr::ValueType;  // APInt or APFloat
3066 ValueType start = start_attr.getValue();
3067 ValueType delta = delta_attr.getValue();
3068
3069 SmallVector<ValueType, 16> new_values;
3070 new_values.reserve(num_elements);
3071 ValueType new_value = start;
3072 for (int i = 0; i < num_elements; ++i) {
3073 new_values.push_back(new_value);
3074 new_value = new_value + delta;
3075 }
3076 // Result is always a 1-D tensor.
3077 auto new_result_type =
3078 RankedTensorType::get({num_elements}, result_elem_type);
3079 return DenseElementsAttr::get(new_result_type, new_values);
3080 }
3081 } // namespace
3082
3083 OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
3084 assert(operands.size() == 3);
3085 auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
3086 auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
3087 auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
3088 if (start_tensor && limit_tensor && delta_tensor) {
3089 // Operands should all be scalars
3090 assert(start_tensor.getType().getRank() == 0 &&
3091 limit_tensor.getType().getRank() == 0 &&
3092 delta_tensor.getType().getRank() == 0);
3093 Type elem_type = getType().cast<ShapedType>().getElementType();
3094 if (elem_type.isSignlessInteger()) {
3095 auto start_attr = start_tensor.getValues<IntegerAttr>()[0];
3096 auto limit_attr = limit_tensor.getValues<IntegerAttr>()[0];
3097 auto delta_attr = delta_tensor.getValues<IntegerAttr>()[0];
3098 const int num_elements = GetLengthOfRange(
3099 start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt());
3100 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
3101 delta_attr);
3102 } else if (elem_type.isa<FloatType>()) {
3103 auto start_attr = start_tensor.getValues<FloatAttr>()[0];
3104 auto limit_attr = limit_tensor.getValues<FloatAttr>()[0];
3105 auto delta_attr = delta_tensor.getValues<FloatAttr>()[0];
3106 const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
3107 limit_attr.getValueAsDouble(),
3108 delta_attr.getValueAsDouble());
3109 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
3110 delta_attr);
3111 }
3112 }
3113
3114 return nullptr;
3115 }
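
// Illustrative example (not from the original source): with constant scalar
// operands start = 0, limit = 7, delta = 2 and an i32 element type, the fold
// above produces dense<[0, 2, 4, 6]> : tensor<4xi32>.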
3116
3117 //===----------------------------------------------------------------------===//
3118 // TransposeConvOp
3119 //===----------------------------------------------------------------------===//
3120
3121 mlir::LogicalResult TransposeConvOp::verify() {
3122 TransposeConvOp op = *this;
3123 ShapedType output_type = op.output().getType().cast<ShapedType>();
3124 ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
3125 if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
3126 if (output_type.getRank() != output_shape_type.getDimSize(0)) {
3127 return op.emitOpError(llvm::formatv(
3128 "expect output type has rank = {0}, got output type {1}",
3129 output_shape_type.getDimSize(0), output_type));
3130 }
3131 }
3132
3133 DenseIntElementsAttr output_shape_elements;
3134 if (!matchPattern(op.output_shape(), m_Constant(&output_shape_elements))) {
3135 return success();
3136 }
3137
3138 llvm::SmallVector<int64_t, 4> output_shape;
3139 output_shape.reserve(output_shape_elements.getNumElements());
3140 for (auto dim : output_shape_elements.getValues<int>()) {
3141 output_shape.push_back(dim);
3142 }
3143
3144 auto expected_output_type =
3145 RankedTensorType::get(output_shape, output_type.getElementType());
3146 if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
3147 return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
3148 expected_output_type, output_type));
3149 }
3150
3151 return success();
3152 }
3153
3154 int64_t TransposeConvOp::GetArithmeticCount(Operation *op) {
3155 int64_t count = -1;
3156 auto transpose_conv = llvm::dyn_cast<TransposeConvOp>(op);
3157 auto input_type = transpose_conv.input()
3158 .getType()
3159 .dyn_cast_or_null<mlir::RankedTensorType>();
3160 auto weight_type = transpose_conv.weights()
3161 .getType()
3162 .dyn_cast_or_null<mlir::RankedTensorType>();
3163 if (input_type && weight_type && input_type.hasStaticShape() &&
3164 weight_type.hasStaticShape()) {
3165 // Compute op count from the seven nested loops of
3166 // tflite::reference_ops::TransposeConv():
3167 count = 2 * input_type.getNumElements() * weight_type.getDimSize(0) *
3168 weight_type.getDimSize(1) * weight_type.getDimSize(2);
3169 }
3170
3171 return count;
3172 }
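
// Illustrative example (not from the original source): for an input of shape
// [1, 8, 8, 16] (1024 elements) and weights of shape [32, 3, 3, 16], the
// estimate above is 2 * 1024 * 32 * 3 * 3 = 589824 ops.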
3173
3174 //===----------------------------------------------------------------------===//
3175 // StridedSliceOp
3176 //===----------------------------------------------------------------------===//
3177
3178 LogicalResult StridedSliceOp::verify() {
3179 StridedSliceOp op = *this;
3180 auto ranked_input_type = op.input().getType().dyn_cast<RankedTensorType>();
3181
3182 // If input is unranked, there is nothing else to be verified.
3183 if (!ranked_input_type) return success();
3184 int num_input_dims = ranked_input_type.getRank();
3185
3186 if (auto begin_type = op.begin().getType().dyn_cast<RankedTensorType>()) {
3187 if (begin_type.getRank() != 1) return failure();
3188 if (begin_type.getDimSize(0) > num_input_dims) return failure();
3189 }
3190
3191 if (auto end_type = op.end().getType().dyn_cast<RankedTensorType>()) {
3192 if (end_type.getRank() != 1) return failure();
3193 if (end_type.getDimSize(0) > num_input_dims) return failure();
3194 }
3195
3196 if (auto strides_type = op.strides().getType().dyn_cast<RankedTensorType>()) {
3197 if (strides_type.getRank() != 1) return failure();
3198 if (strides_type.getDimSize(0) > num_input_dims) return failure();
3199 }
3200
3201   // The kernel will reshape the input tensor with the new axes; it only
3202   // supports the reshaped tensor up to 5 dimensions.
3203 uint32_t ellipsis_mask = op.ellipsis_mask();
3204 uint32_t new_axis_mask = op.new_axis_mask();
3205 int num_added_axis = 0;
3206 for (int i = 0; i < 8; ++i) {
3207 if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) {
3208 num_added_axis++;
3209 }
3210 }
3211 if (num_input_dims + num_added_axis > 5) return failure();
3212 return success();
3213 }
3214
3215 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
3216 // Currently only support all masks being 0.
3217 if (begin_mask() != 0 || end_mask() != 0 || ellipsis_mask() != 0 ||
3218 new_axis_mask() != 0 || shrink_axis_mask() != 0)
3219 return {};
3220
3221 auto input_type = input().getType().dyn_cast_or_null<RankedTensorType>();
3222 if (!input_type || !input_type.hasStaticShape()) return {};
3223
3224 // Begin has to be all 0s.
3225 DenseIntElementsAttr begin_dense_elem_attr;
3226 if (!matchPattern(begin(), m_Constant(&begin_dense_elem_attr))) {
3227 return {};
3228 }
3229 for (auto begin_ele : begin_dense_elem_attr) {
3230 if (begin_ele.getSExtValue() != 0) {
3231 return {};
3232 }
3233 }
3234
3235   // Strides have to be all 1s.
3236 DenseIntElementsAttr strides_dense_elem_attr;
3237 if (!matchPattern(strides(), m_Constant(&strides_dense_elem_attr))) {
3238 return {};
3239 }
3240 for (auto stride_ele : strides_dense_elem_attr) {
3241 if (stride_ele.getSExtValue() != 1) {
3242 return {};
3243 }
3244 }
3245   // End has to match the input shape.
3246 DenseIntElementsAttr end_dense_elem_attr;
3247 if (!matchPattern(end(), m_Constant(&end_dense_elem_attr))) {
3248 return {};
3249 }
3250 int i = 0;
3251 for (auto end_ele : end_dense_elem_attr) {
3252 if (end_ele.getSExtValue() != input_type.getDimSize(i)) {
3253 return {};
3254 }
3255 ++i;
3256 }
3257
3258 return input();
3259 }
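
// Illustrative example (not from the original source): a strided slice over a
// tensor<4x8xf32> with begin = [0, 0], end = [4, 8], strides = [1, 1] and all
// masks equal to 0 is the identity, so the fold above returns the input.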
3260
3261 //===----------------------------------------------------------------------===//
3262 // TransposeOp
3263 //===----------------------------------------------------------------------===//
3264
3265 namespace {
3266
3267 // Computes the permutation of a constant `input_tensor` according to `perm`.
3268 // The function recursively traverses the dimensions of the output tensor in
3269 // a row-major order and writes the value in the output tensor into
3270 // `new_values`.
3271 void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
3272 ArrayRef<int64_t> output_shape, int num_dimensions,
3273 int output_axis, std::vector<uint64_t> *input_indices,
3274 std::vector<Attribute> *new_values) {
3275 // Refer to the implementation of `Transpose` function in
3276 // tensorflow/lite/kernels/internal/reference/reference_ops.h
3277 assert(output_axis < num_dimensions);
3278 const int input_axis = perm[output_axis];
3279 for (int i = 0; i < output_shape[output_axis]; ++i) {
3280 // Update the input indices on `input_axis`.
3281 input_indices->at(input_axis) = i;
3282 // Write the value from `input_tensor` if it is the last axis or
3283 // recurse into the next axis.
3284 const bool is_last_axis = output_axis == num_dimensions - 1;
3285 if (is_last_axis) {
3286 new_values->push_back(
3287 input_tensor.getValues<Attribute>()[*input_indices]);
3288 } else {
3289 ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3290 output_axis + 1, input_indices, new_values);
3291 }
3292 }
3293 }
3294
3295 } // namespace
3296
3297 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
3298 assert(operands.size() == 2);
3299 auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
3300 auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
3301 if (!input_tensor || !perm_tensor) return nullptr;
3302
3303 // Do not try to fold elements attr of a quant type because
3304 // DenseElementsAttr does not support it.
3305 if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
3306 return nullptr;
3307
3308 assert(perm_tensor.getType().getRank() == 1);
3309 const int num_dimensions = input_tensor.getType().getRank();
3310 assert(perm_tensor.getType().getNumElements() == num_dimensions);
3311
3312 ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
3313 auto output_type = getType().cast<ShapedType>();
3314
3315 SmallVector<int32_t, 4> perm;
3316 SmallVector<int64_t, 4> output_shape;
3317 for (int i = 0; i < num_dimensions; ++i) {
3318 perm.push_back(perm_tensor.getValues<IntegerAttr>()[i].getInt());
3319 output_shape.push_back(input_shape[perm[i]]);
3320
3321 // Check that the derived output shape matches the static shape.
3322 assert(!output_type.hasStaticShape() ||
3323 output_type.getShape()[i] == output_shape[i]);
3324 }
3325
3326 std::vector<Attribute> new_values;
3327 new_values.reserve(input_tensor.getType().getNumElements());
3328 std::vector<uint64_t> input_indices(num_dimensions);
3329 ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3330 /*output_axis=*/0, &input_indices, &new_values);
3331 auto result_type =
3332 RankedTensorType::get(output_shape, output_type.getElementType());
3333 return DenseElementsAttr::get(result_type, new_values);
3334 }
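
// Illustrative example (not from the original source): transposing a constant
// tensor<2x3xi32> with perm = [1, 0] folds to a constant of type
// tensor<3x2xi32> whose element at [i, j] equals the input element at [j, i].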
3335
3336 mlir::LogicalResult TransposeOp::verify() {
3337 TransposeOp op = *this;
3338 auto input_type = op.input().getType().cast<ShapedType>();
3339 auto perm_type = op.perm().getType().cast<ShapedType>();
3340 auto output_type = op.output().getType().cast<ShapedType>();
3341 if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
3342 if (perm_type.getNumElements() != input_type.getRank()) {
3343 return op.emitOpError(
3344 "perm tensor elements size is not equal to input tensor rank");
3345 }
3346 }
3347
3348 DenseIntElementsAttr perm;
3349 if (!matchPattern(op.perm(), m_Constant(&perm))) {
3350 return success();
3351 }
3352
3353 int index = 0;
3354 llvm::SmallVector<int64_t, 4> axes;
3355 for (const auto &axis_int : perm.getValues<APInt>()) {
3356 const int64_t axis = axis_int.getSExtValue();
3357 if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
3358 return op.emitOpError(
3359 llvm::formatv("perm[{0}] must be in [0, rank)", index));
3360 }
3361 if (std::count(axes.begin(), axes.end(), axis) > 0) {
3362 return op.emitOpError(
3363 llvm::formatv("perm[{0}] cannot have duplicated axis", index));
3364 }
3365 axes.push_back(axis);
3366 index++;
3367 }
3368
3369 if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
3370 llvm::SmallVector<int64_t, 4> transposed_shape;
3371 for (int64_t axis : axes) {
3372 transposed_shape.push_back(input_type.getDimSize(axis));
3373 }
3374 auto expected_output_type =
3375 RankedTensorType::get(transposed_shape, input_type.getElementType());
3376 if (failed(
3377 mlir::verifyCompatibleShape(output_type, expected_output_type))) {
3378 return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
3379 expected_output_type, output_type));
3380 }
3381 }
3382
3383 // Verify the quantized axis if the type is UniformQuantizedPerAxisType. Other
3384   // verifications that the input and output have the same quantization type,
3385   // scale, and zero point are performed by the SameOperandsAndResultsScale
3386 // trait.
3387 auto in_per_axis_qtype =
3388 QuantizedType::getQuantizedElementType(input_type)
3389 .dyn_cast_or_null<quant::UniformQuantizedPerAxisType>();
3390 auto out_per_axis_qtype =
3391 QuantizedType::getQuantizedElementType(output_type)
3392 .dyn_cast_or_null<quant::UniformQuantizedPerAxisType>();
3393 if (in_per_axis_qtype && out_per_axis_qtype) {
3394 if (out_per_axis_qtype.getQuantizedDimension() < axes.size() &&
3395 axes[out_per_axis_qtype.getQuantizedDimension()] !=
3396 in_per_axis_qtype.getQuantizedDimension()) {
3397 return op.emitOpError(
3398 "has mismatched quantized axes of input and output");
3399 }
3400 }
3401
3402 return success();
3403 }
3404
3405 static void BuildTransposeOp(OpBuilder *builder, OperationState &result,
3406 Value input, Value perm) {
3407 // Output size is only known if input is ranked and perm is a constant.
3408 auto input_type = input.getType().cast<TensorType>();
3409 DenseIntElementsAttr perm_const;
3410 if (!input_type.hasRank() || !matchPattern(perm, m_Constant(&perm_const)) ||
3411 perm_const.empty()) {
3412 TFL::TransposeOp::build(
3413 *builder, result, UnrankedTensorType::get(input_type.getElementType()),
3414 input, perm);
3415 return;
3416 }
3417
3418 const auto perm_value_it = perm_const.value_begin<APInt>();
3419
3420 const ArrayRef<int64_t> input_shape = input_type.getShape();
3421 SmallVector<int64_t, 4> output_shape(input_shape.size());
3422
3423 for (int i = 0; i < output_shape.size(); ++i) {
3424 const APInt perm_val = perm_value_it[i];
3425 output_shape[i] = input_shape[perm_val.getSExtValue()];
3426 }
3427
3428 auto element_type = input_type.getElementType();
3429 // For UniformQuantizedPerAxisType element type, the quantized dimension
3430   // should be remapped according to the transpose.
3431 auto per_axis_qtype =
3432 QuantizedType::getQuantizedElementType(input_type)
3433 .dyn_cast_or_null<quant::UniformQuantizedPerAxisType>();
3434 if (per_axis_qtype) {
3435 int32_t quantized_dimension = per_axis_qtype.getQuantizedDimension();
3436 for (int i = 0; i < output_shape.size(); ++i) {
3437 const APInt perm_val = perm_value_it[i];
3438 if (perm_val.getSExtValue() == quantized_dimension) {
3439 quantized_dimension = i;
3440 break;
3441 }
3442 }
3443 element_type = quant::UniformQuantizedPerAxisType::get(
3444 per_axis_qtype.getFlags(), per_axis_qtype.getStorageType(),
3445 per_axis_qtype.getExpressedType(), per_axis_qtype.getScales(),
3446 per_axis_qtype.getZeroPoints(), quantized_dimension,
3447 per_axis_qtype.getStorageTypeMin(), per_axis_qtype.getStorageTypeMax());
3448 }
3449
3450 TFL::TransposeOp::build(*builder, result,
3451 RankedTensorType::get(output_shape, element_type),
3452 input, perm);
3453 }
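
// Illustrative example (not from the original source): building a transpose of
// an input of type tensor<2x3x4xf32> with a constant perm = [2, 0, 1] infers
// the result type tensor<4x2x3xf32>, since output_shape[i] =
// input_shape[perm[i]]. For per-axis quantized inputs, the quantized dimension
// is remapped in the same way.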
3454
3455 //===----------------------------------------------------------------------===//
3456 // IfOp
3457 //===----------------------------------------------------------------------===//
3458
3459 /// Given the region at `index`, or the parent operation if `index` is None,
3460 /// return the successor regions. These are the regions that may be selected
3461 /// during the flow of control. `operands` is a set of optional attributes that
3462 /// correspond to a constant value for each operand, or null if that operand is
3463 /// not a constant.
3464 void IfOp::getSuccessorRegions(Optional<unsigned> index,
3465 ArrayRef<Attribute> operands,
3466 SmallVectorImpl<RegionSuccessor> ®ions) {
3467 // The `then` and the `else` region branch back to the parent operation.
3468 if (index.has_value()) {
3469 regions.push_back(RegionSuccessor(getResults()));
3470 return;
3471 }
3472
3473 // Don't consider the else region if it is empty.
3474 Region *else_reg = &else_region();
3475 if (else_reg->empty()) else_reg = nullptr;
3476
3477 // Otherwise, the successor is dependent on the condition.
3478 bool condition;
3479 if (auto cond_attr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
3480 condition = cond_attr.getValue().isOneValue();
3481 } else {
3482 // If the condition isn't constant, both regions may be executed.
3483 regions.push_back(RegionSuccessor(&then_region()));
3484 // If the else region does not exist, it is not a viable successor.
3485 if (else_reg) regions.push_back(RegionSuccessor(else_reg));
3486 return;
3487 }
3488
3489 // Add the successor regions using the condition.
3490 regions.push_back(RegionSuccessor(condition ? &then_region() : else_reg));
3491 }
3492
3493 //===----------------------------------------------------------------------===//
3494 // PolyCallOp
3495 //===----------------------------------------------------------------------===//
3496
3497 namespace {
3498 // Canonicalize converted TF ops into PolymorphicCall op so different
3499 // representations are preserved.
3500 struct PolyCallResultOperandsMatchAndImplicitCapture
3501 : public OpRewritePattern<PolyCallOp> {
3502 using OpRewritePattern<PolyCallOp>::OpRewritePattern;
3503
3504   LogicalResult matchAndRewrite(PolyCallOp while_op,
3505 PatternRewriter &rewriter) const override {
3506 // Finish this.
3507 return success();
3508 }
3509 };
3510
3511 } // namespace
3512
3513 void PolyCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
3514 MLIRContext *context) {
3515 results.add<PolyCallResultOperandsMatchAndImplicitCapture>(context);
3516 }
3517
3518 void PolyCallOp::getSuccessorRegions(
3519 Optional<unsigned> index, ArrayRef<Attribute> operands,
3520 SmallVectorImpl<RegionSuccessor> ®ions) {
3521 // Defaults to first region for TFLite execution.
3522 }
3523
3524 //===----------------------------------------------------------------------===//
3525 // WhileOp
3526 //===----------------------------------------------------------------------===//
3527
3528 LogicalResult WhileOp::verify() {
3529 WhileOp op = *this;
3530 if (op.getNumOperands() != op.getNumResults())
3531 return op.emitOpError(llvm::formatv(
3532 "number of operands does not match number of results ({0} != {1})",
3533 op.getNumOperands(), op.getNumResults()));
3534 if (op.cond().front().getNumArguments() !=
3535 op.body().front().getNumArguments())
3536 return op.emitOpError(llvm::formatv(
3537 "number of arguments in condition function does not match number of "
3538 "arguments in body function ({0} != {1})",
3539 op.cond().front().getNumArguments(),
3540 op.body().front().getNumArguments()));
3541 // Verify shapes are compatible.
3542 for (auto it : llvm::zip(op.cond().front().getArgumentTypes(),
3543 op.body().front().getArgumentTypes())) {
3544 if (failed(mlir::verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
3545 return op->emitOpError(llvm::formatv(
3546 "condition function's argument type does not match body "
3547 "function's argument type ({0} != {1})",
3548 std::get<0>(it), std::get<1>(it)));
3549 }
3550
3551 return success();
3552 }
3553
3554 namespace {
3555 // Canonicalize the While op so that results and operands match, and external
3556 // values are captured implicitly rather than passed via block args.
3557 struct WhileResultOperandsMatchAndImplicitCapture
3558 : public OpRewritePattern<WhileOp> {
3559 using OpRewritePattern<WhileOp>::OpRewritePattern;
3560
3561   LogicalResult matchAndRewrite(WhileOp while_op,
3562 PatternRewriter &rewriter) const override {
3563 // Replace values simply passed through the body with extern values
3564 // (in both body and condition regions as well as while result). The
3565 // block arguments of body and while match and so the corresponding cond
3566 // argument can be easily found.
3567 bool unchanged = true;
3568 auto &body_block = while_op.body().front();
3569 auto &cond_block = while_op.cond().front();
3570 auto &yield = *body_block.getTerminator();
3571 for (auto ba : body_block.getArguments()) {
3572 int arg_no = ba.getArgNumber();
3573 // Skip removing resources that are not read-only variables.
3574 if (getElementTypeOrSelf(ba.getType()).isa<TF::ResourceType>()) {
3575 bool has_read_only_variables = true;
3576 for (auto user : ba.getUsers()) {
3577           // Terminator ops, e.g. the tfl.yield op, should be ignored: the
3578           // argument may be used for yielding as the `body` function result,
3579           // which says nothing about whether the given argument is a read-only
3580           // variable.
3581 if (user->hasTrait<OpTrait::IsTerminator>()) continue;
3582 if (!llvm::isa<mlir::TF::ReadVariableOp>(user)) {
3583 has_read_only_variables = false;
3584 break;
3585 }
3586 }
3587 if (!has_read_only_variables) continue;
3588 }
3589 if (ba == yield.getOperand(arg_no)) {
3590 unchanged = false;
3591 auto value = while_op.getOperand(arg_no);
3592 ba.replaceAllUsesWith(value);
3593 cond_block.getArgument(arg_no).replaceAllUsesWith(value);
3594
3595 // This could be relaxed and casts inserted.
3596 if (while_op.getResult(arg_no).getType() == value.getType())
3597 while_op.getResult(arg_no).replaceAllUsesWith(value);
3598 }
3599 }
3600
3601     // The While op's operand and result types need to match.
3602 SmallVector<Value, 4> new_operands;
3603 SmallVector<Value, 4> new_body_yield;
3604 SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
3605 llvm::SmallVector<Type, 4> types;
3606 new_operands.reserve(while_op.getNumOperands());
3607 new_body_yield.reserve(while_op.getNumOperands());
3608 types.reserve(while_op.getNumOperands());
3609
3610 // Remove block arguments not used in either cond or body. This leaves the
3611 // block arguments of body and cond matching still.
3612 int arg_index = 0;
3613 for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
3614 ++while_index) {
3615 auto value = while_op.getOperand(while_index);
3616 if (body_block.getArgument(arg_index).use_empty() &&
3617 cond_block.getArgument(arg_index).use_empty() &&
3618 // Note: since we are not erasing results, need to use while_index
3619 // to check if the corresponding result is unused.
3620 while_op.getResult(while_index).use_empty()) {
3621 unchanged = false;
3622 body_block.eraseArgument(arg_index);
3623 cond_block.eraseArgument(arg_index);
3624
3625 // Mark operand for removal.
3626 removed_operand[while_index] = true;
3627 } else {
3628 new_operands.push_back(value);
3629 new_body_yield.push_back(yield.getOperand(while_index));
3630 auto type = while_op.getResult(while_index).getType();
3631 types.push_back(type);
3632 ++arg_index;
3633 }
3634 }
3635
3636 // Done if no values removed from blocks and operands & results match.
3637 if (unchanged) return failure();
3638
3639 // Replace with new While with matching operands and results.
3640 Operation *op = while_op.getOperation();
3641 Operation *new_op = rewriter.insert(
3642 Operation::create(op->getLoc(), op->getName(), types, new_operands,
3643 op->getAttrs(), {}, /*numRegions=*/2));
3644
3645 for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
3646 int new_index = 0;
3647 for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
3648 if (removed_operand[op_index]) continue;
3649 op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
3650 ++new_index;
3651 }
3652 rewriter.eraseOp(op);
3653
3654 Block &new_body_block = cast<WhileOp>(new_op).body().front();
3655 rewriter.setInsertionPointToEnd(&new_body_block);
3656 rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
3657 new_body_yield);
3658
3659 return success();
3660 }
3661 };
3662
3663 } // namespace
3664
3665 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3666 MLIRContext *context) {
3667 results.add<WhileResultOperandsMatchAndImplicitCapture>(context);
3668 }
3669
3670 Region &WhileOp::getLoopBody() { return body(); }
3671
3672 bool WhileOp::isDefinedOutsideOfLoop(Value value) {
3673   // TODO(jpienaar): This is overly conservative and initially disables anything
3674   // other than constant hoisting.
3675 return false;
3676 }
3677
3678 //===----------------------------------------------------------------------===//
3679 // LogisticOp
3680 //===----------------------------------------------------------------------===//
3681
3682 int64_t LogisticOp::GetArithmeticCount(Operation *op) {
3683 int64_t count;
3684 // As a very rough ballpark, the cost of evaluating a math function
3685 // such as tanh or logistic is about 32 multiplications, and about as
3686 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3687 // from looking at actual implementations that we use in runtime/code).
3688 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3689 return 64 * count;
3690
3691 return -1;
3692 }
3693
3694 //===----------------------------------------------------------------------===//
3695 // LogSoftmaxOp
3696 //===----------------------------------------------------------------------===//
3697
3698 int64_t LogSoftmaxOp::GetArithmeticCount(Operation *op) {
3699 int64_t count;
3700 // As a very rough ballpark, the cost of evaluating a math function
3701 // such as tanh or logistic is about 32 multiplications, and about as
3702 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3703 // from looking at actual implementations that we use in runtime/code).
3704 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3705 return 64 * count;
3706
3707 return -1;
3708 }
3709
3710 //===----------------------------------------------------------------------===//
3711 // SoftmaxOp
3712 //===----------------------------------------------------------------------===//
3713
3714 int64_t SoftmaxOp::GetArithmeticCount(Operation *op) {
3715 int64_t count;
3716 // As a very rough ballpark, the cost of evaluating a math function
3717 // such as tanh or logistic is about 32 multiplications, and about as
3718 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3719 // from looking at actual implementations that we use in runtime/code).
3720 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3721 return 64 * count;
3722
3723 return -1;
3724 }
3725
3726 //===----------------------------------------------------------------------===//
3727 // TanhOp
3728 //===----------------------------------------------------------------------===//
3729
3730 int64_t TanhOp::GetArithmeticCount(Operation *op) {
3731 int64_t count;
3732 // As a very rough ballpark, the cost of evaluating a math function
3733 // such as tanh or logistic is about 32 multiplications, and about as
3734 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3735 // from looking at actual implementations that we use in runtime/code).
3736 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3737 return 64 * count;
3738
3739 return -1;
3740 }
3741
3742 //===----------------------------------------------------------------------===//
3743 // AddNOp
3744 //===----------------------------------------------------------------------===//
3745
3746 int64_t AddNOp::GetArithmeticCount(Operation *op) {
3747 int64_t count;
3748 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3749 // AddN cost is roughly the same cost as N-1 Adds.
3750 const int64_t num_adds = op->getNumOperands() - 1;
3751 return num_adds * count;
3752 }
3753
3754 return -1;
3755 }
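
// Illustrative example (not from the original source): adding three operands
// whose output has 100 elements costs roughly (3 - 1) * 100 = 200 ops under
// the estimate above.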
3756
3757 //===----------------------------------------------------------------------===//
3758 // AveragePool2DOp
3759 //===----------------------------------------------------------------------===//
3760
3761 int64_t AveragePool2DOp::GetArithmeticCount(Operation *op) {
3762 int64_t count;
3763 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3764 auto avg_pool = llvm::dyn_cast<AveragePool2DOp>(op);
3765 return avg_pool.filter_height() * avg_pool.filter_width() * count;
3766 }
3767
3768 return -1;
3769 }
3770
3771 //===----------------------------------------------------------------------===//
3772 // MaxPool2DOp
3773 //===----------------------------------------------------------------------===//
3774
3775 int64_t MaxPool2DOp::GetArithmeticCount(Operation *op) {
3776 int64_t count;
3777 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3778 auto max_pool = llvm::dyn_cast<MaxPool2DOp>(op);
3779 return max_pool.filter_height() * max_pool.filter_width() * count;
3780 }
3781
3782 return -1;
3783 }
3784
3785 //===----------------------------------------------------------------------===//
3786 // L2NormalizationOp
3787 //===----------------------------------------------------------------------===//
3788
3789 int64_t L2NormalizationOp::GetArithmeticCount(Operation *op) {
3790 int64_t count;
3791   // Computing the squared L2 norm takes N multiply-adds, i.e. 2N ops; the
3792   // single inverse-sqrt is negligible; then each value is multiplied by the
3793   // resulting multiplier, an extra N ops. In total, about 3N ops.
3794 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3795 return 3 * count;
3796 }
3797
3798 return -1;
3799 }
3800
3801 //===----------------------------------------------------------------------===//
3802 // PadOp
3803 //===----------------------------------------------------------------------===//
3804
3805 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
3806 if (InputOutputHasSameShape(input().getType(), output().getType()))
3807 return input();
3808
3809 return {};
3810 }
3811
3812 //===----------------------------------------------------------------------===//
3813 // PadV2Op
3814 //===----------------------------------------------------------------------===//
3815
3816 OpFoldResult PadV2Op::fold(ArrayRef<Attribute> operands) {
3817 if (InputOutputHasSameShape(input().getType(), output().getType()))
3818 return input();
3819
3820 return {};
3821 }
3822
3823 //===----------------------------------------------------------------------===//
3824 // NoValueOp
3825 //===----------------------------------------------------------------------===//
3826
3827 OpFoldResult NoValueOp::fold(ArrayRef<Attribute> operands) {
3828 return valueAttr();
3829 }
3830
3831 bool NoValueOp::isBuildableWith(Attribute value, Type type) {
3832 return value.isa<UnitAttr>() && type.isa<NoneType>();
3833 }
3834
3835 YieldOp ControlNodeOp::GetYield() {
3836 return llvm::cast<YieldOp>(GetBody().back());
3837 }
3838
3839 // Checks if a TFL.control_node wraps a single operation and that operation's
3840 // results are perfectly forwarded to the wrapper's yield.
3841 bool ControlNodeOp::WrapsSinglePerfectlyForwardedOp() {
3842 auto body = GetBody().without_terminator();
3843 if (!hasSingleElement(body)) return false;
3844
3845 Operation &controlled_op = *body.begin();
3846 YieldOp yield = GetYield();
3847 return controlled_op.getNumResults() == yield.getNumOperands() &&
3848 std::equal(controlled_op.getResults().begin(),
3849 controlled_op.getResults().end(),
3850 yield.getOperands().begin());
3851 }
3852
3853 mlir::LogicalResult ControlNodeOp::verify() {
3854 ControlNodeOp control_node = *this;
3855 if (!control_node.GetBody().args_empty())
3856 return control_node.emitOpError() << "expects body without any arguments";
3857
3858 Operation &yield = control_node.GetBody().back();
3859 if (!isa<YieldOp>(yield))
3860 return yield.emitOpError()
3861 << "invalid TFL.control_node terminator, yield expected";
3862
3863 // Ensure that the terminator's operands and the control_node results match in
3864 // types.
3865 const int result_count =
3866 control_node.getNumResults() - 1; // 1 for control token
3867 const int num_operands = yield.getNumOperands();
3868 if (num_operands != result_count)
3869 return yield.emitOpError()
3870 << "has " << yield.getNumOperands()
3871            << " operands, but control_node returns " << result_count;
3872 for (const int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
3873 if (control_node.getResult(operand_idx).getType() !=
3874 yield.getOperand(operand_idx).getType())
3875 return yield.emitOpError() << "operand #" << operand_idx
3876                                  << " type does not match control_node result";
3877 }
3878 return success();
3879 }
3880
3881 void ControlNodeOp::print(OpAsmPrinter &p) {
3882 if (getNumOperands()) {
3883     // These are always control operands; no explicit type is needed.
3884 p << '(';
3885 p.printOperands(getOperands());
3886 p << ')';
3887 }
3888 // Check if we can print the short "controls" form: that is if the
3889 // control_node contains a single operation and the results of this operation
3890 // are perfectly forwarded to the yield.
3891 if (getOperation()->getAttrs().empty() && WrapsSinglePerfectlyForwardedOp()) {
3892 Operation &controlled_op = GetBody().front();
3893 // The "controls" syntax only encodes a single location.
3894 YieldOp yield_op = GetYield();
3895 // In order to correctly round-trip, we can only use this syntax when all
3896 // the locations are identical.
3897 if (controlled_op.getLoc() == getLoc() && yield_op.getLoc() == getLoc()) {
3898 p << " controls ";
3899 p.printGenericOp(&controlled_op);
3900 return;
3901 }
3902 }
3903 p << ' ';
3904 p.printRegion(getOperation()->getRegion(0));
3905 p.printOptionalAttrDict(getOperation()->getAttrs());
3906 }
3907
3908 ParseResult ControlNodeOp::parse(OpAsmParser &parser, OperationState &result) {
3909 // Parse the body region.
3910 llvm::SMLoc loc = parser.getCurrentLocation();
3911 Type control_type = ControlType::get(parser.getBuilder().getContext());
3912
3913 // Parse optional argument list (control dependencies only).
3914 SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos;
3915 if (parser.parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen))
3916 return failure();
3917 if (!op_infos.empty()) {
3918 SmallVector<Type, 2> types(op_infos.size(), control_type);
3919 if (parser.resolveOperands(op_infos, types, loc, result.operands))
3920 return failure();
3921 }
3922
3923 Region &body = *result.addRegion();
3924
3925 if (succeeded(parser.parseOptionalKeyword("controls"))) {
3926 // If we parse the short version of the control node, we have an operation
3927 // in the generic form that follows the "controls" keyword. Parse it inside
3928 // the region and forward all of its results as-is to the yield operation.
3929 body.push_back(new Block);
3930 Block &block = body.back();
3931 Operation *controlled_op =
3932 parser.parseGenericOperation(&block, block.begin());
3933 if (!controlled_op) return failure();
3934 OpBuilder builder(parser.getBuilder().getContext());
3935 builder.setInsertionPointToEnd(&block);
3936 builder.create<YieldOp>(controlled_op->getLoc(),
3937 controlled_op->getResults());
3938 result.location = controlled_op->getLoc();
3939 } else if (parser.parseRegion(body)) {
3940 return failure();
3941 }
3942
3943 ControlNodeOp::ensureTerminator(body, parser.getBuilder(), result.location);
3944
3945 // Get the results type for the control node from the terminator operands.
3946 Operation &yield = body.back().back();
3947 result.types.reserve(yield.getNumOperands() + 1);
3948 result.types.append(yield.operand_type_begin(), yield.operand_type_end());
3949 result.types.push_back(control_type);
3950
3951 // Parse the optional attribute list.
3952 if (parser.parseOptionalAttrDict(result.attributes)) return failure();
3953 return success();
3954 }
3955
3956 //===----------------------------------------------------------------------===//
3957 // ConstBytesAttr
3958 //===----------------------------------------------------------------------===//
3959
3960 Attribute ConstBytesAttr::parse(AsmParser &parser, Type type) {
3961 if (parser.parseColon()) {
3962 return nullptr;
3963 }
3964
3965 std::string data;
3966 if (parser.parseString(&data)) {
3967 return nullptr;
3968 }
3969 if (data.size() < 2 || data.substr(0, 2) != "0x") {
3970 parser.emitError(parser.getNameLoc(), "Hex string doesn't start with `0x`");
3971 return nullptr;
3972 }
3973
3974 std::string bytes_data = absl::HexStringToBytes(data.substr(2));
3975 return ConstBytesAttr::get(parser.getBuilder().getContext(), bytes_data);
3976 }
3977
3978 void ConstBytesAttr::print(mlir::AsmPrinter &printer) const {
3979 StringRef bytes_str = getValue();
3980 printer << " : \"0x" << llvm::toHex(bytes_str) << "\"";
3981 }
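
// Illustrative round trip (not from the original source): a ConstBytesAttr
// holding the two bytes {0x01, 0xAB} prints as ` : "0x01AB"` after the
// attribute mnemonic, and parse() above converts such a hex string back into
// the raw bytes.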
3982
3983 //===----------------------------------------------------------------------===//
3984 // TableGen'd op method definitions
3985 //===----------------------------------------------------------------------===//
3986
3987 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
3988
3989 static FailureOr<SmallVector<int32_t>> parseI32Array(AsmParser &parser) {
3990 SmallVector<int32_t> elements;
3991 auto elementParser = [&]() {
3992 int32_t element;
3993 if (failed(parser.parseInteger(element))) return failure();
3994 elements.push_back(element);
3995 return success();
3996 };
3997 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
3998 elementParser))
3999 return failure();
4000 return elements;
4001 }
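
// Illustrative example (not from the original source): parseI32Array accepts a
// square-bracketed list such as `[1, 2, 3]` and returns
// SmallVector<int32_t>{1, 2, 3}; any non-integer element makes parsing fail.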
4002
4003 } // namespace TFL
4004 } // namespace mlir
4005
4006 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_dialect.cc.inc"
4007 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_enums.cc.inc"
4008 #define GET_ATTRDEF_CLASSES
4009 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_attrdefs.cc.inc"
4010 #define GET_OP_CLASSES
4011 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
4012
4013 namespace mlir {
4014 namespace TFL {
4015
4016 #include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
4017
4018 Operation *TFLDialect::materializeConstant(OpBuilder &builder, Attribute value,
4019 Type type, Location loc) {
4020 // If this is a constant bytes attribute or the result type doesn't match the
4021 // attribute type, then generate a tfl.pseudo_const.
4022 if (value.isa<ConstBytesAttr>() ||
4023 (value.isa<ElementsAttr>() &&
4024 value.cast<ElementsAttr>().getType() != type))
4025 return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
4026 if (arith::ConstantOp::isBuildableWith(value, type))
4027 return builder.create<arith::ConstantOp>(loc, type, value);
4028 if (NoValueOp::isBuildableWith(value, type))
4029 return builder.create<NoValueOp>(loc, type, value.cast<UnitAttr>());
4030 return nullptr;
4031 }
4032
4033 } // namespace TFL
4034 } // namespace mlir
4035