/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<AddToAddV2>(context);
}

//===----------------------------------------------------------------------===//
// AddNOp
//===----------------------------------------------------------------------===//

OpFoldResult AddNOp::fold(ArrayRef<Attribute> operands) {
  if (operands.size() == 1) return *inputs().begin();

  // Fold if there is only a single non-zero operand, or if all operands are
  // zero.
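  // For example (illustrative), with a splat zero constant %zero and %x of
  // the same static type tensor<2x2xf32>:
  //   "tf.AddN"(%zero, %x, %zero)  // folds to %x
  //   "tf.AddN"(%zero, %zero)      // folds to the splat zero constant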
  int non_zero_index = -1;
  auto IsKnownZero = [](Attribute attr) {
    if (!attr) return false;
    auto splat = attr.dyn_cast<SplatElementsAttr>();
    if (!splat) return false;
    Type element_ty = splat.getType().getElementType();
    if (element_ty.isa<FloatType>())
      return splat.getSplatValue<llvm::APFloat>().isZero();
    if (element_ty.isa<IntegerType>())
      return splat.getSplatValue<llvm::APInt>().getSExtValue() == 0;
    return false;
  };

  for (auto it : llvm::enumerate(operands)) {
    if (IsKnownZero(it.value())) continue;
    if (non_zero_index != -1) {
      // Don't fold if we find more than 1 non-zero operand.
      return {};
    }
    non_zero_index = it.index();
  }

  // Only fold when the result shape is fully static.
  auto result_ty = getType().dyn_cast<ShapedType>();
  if (!result_ty || !result_ty.hasStaticShape()) return {};

  if (non_zero_index == -1) {
    return SplatElementsAttr::get(
        result_ty,
        operands.begin()->cast<DenseElementsAttr>().getSplatValue<Attribute>());
  }

  // Check that the non-zero operand's shape matches the result shape.
  if (result_ty == inputs()[non_zero_index].getType())
    return inputs()[non_zero_index];
  return {};
}

//===----------------------------------------------------------------------===//
// AddV2Op
//===----------------------------------------------------------------------===//

void AddV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<AddV2OfNegLeft, AddV2OfNegRight>(context);
}

OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<AddV2Op>(*this, operands);
}

//===----------------------------------------------------------------------===//
// AllOp
//===----------------------------------------------------------------------===//

LogicalResult AllOp::verify() {
  AllOp op = *this;
  return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
                                     op.getLoc());
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

LogicalResult AnyOp::verify() {
  AnyOp op = *this;
  return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
                                     op.getLoc());
}

//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//

namespace {

// Removes Assert with constant true predicate.
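// For example (illustrative):
//   %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
//   "tf.Assert"(%true, %msg) : (tensor<i1>, tensor<!tf_type.string>) -> ()
// The assertion above can never fire, so the op is erased.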
struct AssertWithTrue : public OpRewritePattern<AssertOp> {
  using OpRewritePattern<AssertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssertOp op,
                                PatternRewriter &rewriter) const override {
    ElementsAttr cst;
    if (matchPattern(op.condition(), m_Constant(&cst))) {
      if (cst.getValues<bool>()[0]) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};
}  // namespace

void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<AssertWithTrue>(context);
}

//===----------------------------------------------------------------------===//
// BatchMatMulV2Op & BatchMatMulOp
//===----------------------------------------------------------------------===//

template <typename OpT,
          typename std::enable_if<llvm::is_one_of<
              OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
static LogicalResult Verify(OpT op) {
  if (!HasRankAtLeast(op.x(), 2)) {
    return op.emitOpError("requires lhs operand to have rank at least two");
  }
  if (!HasRankAtLeast(op.y(), 2)) {
    return op.emitOpError("requires rhs operand to have rank at least two");
  }

  RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x());
  RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y());

  if (!x_ty || !y_ty) return success();

  ArrayRef<int64_t> x_shape = x_ty.getShape();
  ArrayRef<int64_t> y_shape = y_ty.getShape();

  llvm::SmallVector<int64_t, 4> result_batch_shape;
  llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
  llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);

  // Check compatibility of batch dimensions if both input shapes are known.
  // BatchMatMul should have exactly the same batch dimensions and
  // BatchMatMulV2 should have broadcastable batch dimensions.
  //
  // The last two dimensions are non-batch dimensions that don't need to
  // participate in batch dimension compatibility check.
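  //
  // For example (illustrative): for BatchMatMulV2, x: tensor<4x1x2x3xf32>
  // (batch dims [4, 1]) and y: tensor<1x5x3x7xf32> (batch dims [1, 5])
  // broadcast to batch shape [4, 5], whereas BatchMatMul would reject the
  // mismatching batch dimensions.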
  if (std::is_same<OpT, BatchMatMulOp>()) {
    for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
      int64_t x_dim = std::get<0>(dim_pairs);
      int64_t y_dim = std::get<1>(dim_pairs);
      if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
          x_dim != y_dim) {
        return op.emitOpError()
               << "found mismatching batch dimensions for lhs shape " << x_ty
               << " and rhs shape " << y_ty;
      }
    }
  } else {
    if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
                                            result_batch_shape))
      return op.emitOpError()
             << "found incompatible broadcast batch dimensions for lhs shape "
             << x_ty << " and rhs shape " << y_ty;
  }

  RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
  if (!output_ty) return success();

  int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank());
  if (output_ty.getRank() != expected_output_rank)
    return op.emitOpError()
           << "found invalid output rank, expected " << expected_output_rank
           << " but got " << output_ty.getRank();

  // Check output batch dim with potential broadcasting.
  ArrayRef<int64_t> output_shape = output_ty.getShape();
  for (int i = 0; i < result_batch_shape.size(); ++i) {
    if (output_shape[i] != ShapedType::kDynamicSize &&
        result_batch_shape[i] != ShapedType::kDynamicSize &&
        output_shape[i] != result_batch_shape[i])
      return op.emitOpError()
             << "has mismatching input batch dimension "
             << result_batch_shape[i] << " and output batch dimension "
             << output_shape[i];
  }

  // Check output shape for non-batch dimension, following documentation below.
  // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
  int64_t x_row_dim = x_shape[x_shape.size() - 2];
  int64_t x_col_dim = x_shape[x_shape.size() - 1];
  int64_t y_row_dim = y_shape[y_shape.size() - 2];
  int64_t y_col_dim = y_shape[y_shape.size() - 1];
  int64_t out_row_dim = output_shape[output_shape.size() - 2];
  int64_t out_col_dim = output_shape[output_shape.size() - 1];

  int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim;
  int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim;

  if (expected_out_row_dim != ShapedType::kDynamicSize &&
      out_row_dim != ShapedType::kDynamicSize &&
      out_row_dim != expected_out_row_dim)
    return op.emitOpError()
           << "found invalid output dimension on row, expected "
           << expected_out_row_dim << " but got " << out_row_dim;
  if (expected_out_col_dim != ShapedType::kDynamicSize &&
      out_col_dim != ShapedType::kDynamicSize &&
      out_col_dim != expected_out_col_dim)
    return op.emitOpError()
           << "found invalid output dimension on col, expected "
           << expected_out_col_dim << " but got " << out_col_dim;

  return success();
}
LogicalResult BatchMatMulOp::verify() { return Verify(*this); }
LogicalResult BatchMatMulV2Op::verify() { return Verify(*this); }

void BatchMatMulOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<BatchMatMulToV2>(context);
}

void BatchMatMulV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<BatchMatMulV2ToMatMul>(context);
}

//===----------------------------------------------------------------------===//
// BatchToSpaceOp
//===----------------------------------------------------------------------===//

LogicalResult BatchToSpaceOp::verify() {
  BatchToSpaceOp op = *this;
  // Op already has a constraint that block_size >= 2.
  int64_t block_size = op.block_size();

  llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
  auto input_type = op.input().getType().cast<TensorType>();
  if (input_type.hasRank()) {
    if (input_type.getRank() != 4)
      return op.emitOpError()
             << "requires input to be a 4D tensor, but got " << input_type;

    int64_t input_batch = input_type.getDimSize(0);
    if (input_batch != ShapedType::kDynamicSize &&
        input_batch % (block_size * block_size) != 0) {
      return op.emitOpError()
             << "requires input batch (dimension 0) to be evenly divisible "
                "by (block_size * block_size), but got input batch "
             << input_batch << " and block_size " << block_size;
    }

    input_shape.assign(input_type.getShape().begin(),
                       input_type.getShape().end());
  }

  auto crops_type = op.crops().getType().cast<TensorType>();
  if (crops_type.hasRank()) {
    if (crops_type.getRank() != 2)
      return op.emitOpError()
             << "requires crops to be a 2D tensor, but got " << crops_type;

    auto dim_of_size = [&](int64_t dim, int64_t size) {
      if (crops_type.isDynamicDim(dim)) return true;
      return crops_type.getDimSize(dim) == size;
    };
    if (!dim_of_size(0, 2) || !dim_of_size(1, 2))
      return op.emitOpError()
             << "requires crops to be a tensor<2x2>, but got " << crops_type;
  }

  DenseIntElementsAttr crops_attr;
  // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]],
  // and flattened as [crop_top, crop_bottom, crop_left, crop_right].
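  // For example (illustrative): with block_size = 2 and constant
  // crops = [[1, 0], [0, 1]], the spatial-dimension checks below expect
  //   output_height = input_height * 2 - 1 - 0
  //   output_width  = input_width  * 2 - 0 - 1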
  llvm::SmallVector<int64_t, 4> crops_values;
  if (matchPattern(op.crops(), m_Constant(&crops_attr))) {
    assert(crops_attr.getNumElements() == 4 &&
           "tf.BatchToSpace crops must have 4 elements");

    auto crops_range = crops_attr.getValues<APInt>();
    for (const auto &crops_value : crops_range) {
      int64_t crops_value_int = crops_value.getSExtValue();
      if (crops_value_int < 0)
        return op.emitOpError()
               << "requires all crop values to be nonnegative, but got "
               << crops_attr;

      crops_values.push_back(crops_value_int);
    }
  }

  auto output_type = op.output().getType().cast<TensorType>();
  if (output_type.hasRank()) {
    if (output_type.getRank() != 4)
      return op.emitOpError()
             << "requires output to be a 4D tensor, but got " << output_type;

    auto static_dims = [](int64_t dim_a, int64_t dim_b) {
      return dim_a != ShapedType::kDynamicSize &&
             dim_b != ShapedType::kDynamicSize;
    };

    auto output_shape = output_type.getShape();

    // output batch = input batch / (block_size * block_size).
    int64_t input_batch = input_shape[0];
    int64_t output_batch = output_shape[0];
    if (static_dims(input_batch, output_batch) &&
        (output_batch * block_size * block_size) != input_batch)
      return op.emitOpError()
             << "requires output batch (dimension 0) to be equal to input "
                "batch (dimension 0) / (block_size * block_size), but got "
                "output batch "
             << output_batch << ", input batch " << input_batch
             << ", and block_size " << block_size;

    auto check_spatial_dim = [&](int64_t spatial_dim_index,
                                 llvm::StringRef dim_name,
                                 llvm::StringRef crop_a_name,
                                 llvm::StringRef crop_b_name) -> LogicalResult {
      int64_t input_dim = input_shape[spatial_dim_index];
      int64_t output_dim = output_shape[spatial_dim_index];
      if (!static_dims(input_dim, output_dim)) return success();

      int64_t input_dim_pad = input_dim * block_size;
      // If crops are unknown, the maximum output spatial dim size is the input
      // spatial dim size * block_size, since each crop is at least 0.
      if (crops_values.empty() && output_dim > input_dim * block_size)
        return op.emitOpError()
               << "requires output " << dim_name << " (dimension "
               << spatial_dim_index << ") to be less than or equal to input "
               << dim_name << " (dimension " << spatial_dim_index
               << ") * block_size, but got output " << dim_name << " "
               << output_dim << ", input " << dim_name << " " << input_dim
               << ", and block_size " << block_size;

      if (!crops_values.empty()) {
        // output spatial dim = input spatial dim * block_size - crops.
        int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)];
        int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1];
        if (output_dim != input_dim_pad - crop_a - crop_b)
          return op.emitOpError()
                 << "requires output " << dim_name << " (dimension "
                 << spatial_dim_index << ") to be equal to input " << dim_name
                 << " (dimension " << spatial_dim_index << ") * block_size - "
                 << crop_a_name << " - " << crop_b_name << ", but got output "
                 << dim_name << " " << output_dim << ", input " << dim_name
                 << " " << input_dim << ", " << crop_a_name << " " << crop_a
                 << ", " << crop_b_name << " " << crop_b << ", and block_size "
                 << block_size;
      }

      return success();
    };

    if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) ||
        failed(check_spatial_dim(2, "width", "crop_left", "crop_right")))
      return failure();

    int64_t input_depth = input_shape[3];
    int64_t output_depth = output_shape[3];
    if (static_dims(input_depth, output_depth) && output_depth != input_depth)
      return op.emitOpError()
             << "requires output depth (dimension 3) to be equal to input "
                "depth (dimension 3), but got output depth "
             << output_depth << " and input depth " << input_depth;
  }

  return success();
}

void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<BatchToSpaceToBatchToSpaceND>(context);
}

//===----------------------------------------------------------------------===//
// BatchToSpaceNDOp
//===----------------------------------------------------------------------===//

LogicalResult BatchToSpaceNDOp::verify() {
  BatchToSpaceNDOp op = *this;
  auto block_shape_ty = op.block_shape().getType().cast<ShapedType>();
  auto crops_ty = op.crops().getType().cast<ShapedType>();

  if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) {
    const int block_rank = block_shape_ty.getShape().front();
    if (crops_ty.getRank() != 2 || crops_ty.getShape().front() != block_rank ||
        crops_ty.getShape()[1] != 2) {
      op.emitOpError() << "crops should have shape [" << block_rank
                       << ", 2] instead of " << crops_ty.getShape();
      return failure();
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BiasAddOp
//===----------------------------------------------------------------------===//

// Verifies that,
// * the value and bias operands have valid ranks or are unranked, and
// * the channel dimension of the value operand and the length of the bias
//   match if they are not unknown.
//
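// For example (illustrative): a value operand of type tensor<1x32x32x16xf32>
// with `NHWC` data format requires a bias operand of type tensor<16xf32>.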
LogicalResult BiasAddOp::verify() {
  BiasAddOp op = *this;
  absl::string_view data_format(op.data_format().data(),
                                op.data_format().size());
  tensorflow::TensorFormat format;
  bool is_valid = FormatFromString(data_format, &format);
  DCHECK(is_valid) << data_format;
  if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
    if (!HasRankAtLeast(op.value(), 2))
      return op.emitOpError(
          "requires value operand to have rank at least two with `NHWC` data "
          "format");
  } else {
    // Op definition requires data_format to be either NHWC or NCHW.
    DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
    if (!HasRankAtLeast(op.value(), 3))
      return op.emitOpError(
          "requires value operand to have rank at least three with `NCHW` data "
          "format");
  }

  if (!IsOfRankOrUnranked(op.bias(), 1))
    return op.emitOpError("requires bias operand to have rank exactly one");

  RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>();
  RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
  if (!bias_ty || !value_ty) return success();

  int64_t feature_dim_idx =
      tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format);
  int64_t feature_dim = value_ty.getDimSize(feature_dim_idx);
  int64_t bias_len = bias_ty.getDimSize(0);
  if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) {
    return op.emitOpError()
           << "requires channel dimension and feature dimension to match; "
              "found "
           << feature_dim << " and " << bias_len << ", respectively";
  }
  return success();
}

LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
  return ::mlir::TF::UpdateDataFormat(data_format, this);
}

StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using a GPU for this operation.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
    return data_format();

  // Prefer NHWC for GPU devices.
  return "NHWC";
}

//===----------------------------------------------------------------------===//
// BiasAddGradOp
//===----------------------------------------------------------------------===//

// Verifies that,
// * the out_backprop operand has a valid rank or is unranked.
//
LogicalResult BiasAddGradOp::verify() {
  BiasAddGradOp op = *this;
  absl::string_view data_format(op.data_format().data(),
                                op.data_format().size());
  tensorflow::TensorFormat format;
  bool is_valid = FormatFromString(data_format, &format);
  DCHECK(is_valid) << data_format;
  if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
    if (!HasRankAtLeast(op.out_backprop(), 2))
      return op.emitOpError(
          "requires out_backprop operand to have rank at least two with `NHWC` "
          "data format");
  } else {
    // Op definition requires data_format to be either NHWC or NCHW.
    DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
    if (!HasRankAtLeast(op.out_backprop(), 3))
      return op.emitOpError(
          "requires out_backprop operand to have rank at least three with "
          "`NCHW` data format");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BiasAddV1Op
//===----------------------------------------------------------------------===//

void BiasAddV1Op::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BiasAddV1ToBiasAdd>(context);
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<BitcastSameType, BitcastNested>(context);
}

//===----------------------------------------------------------------------===//
// BroadcastToOp
//===----------------------------------------------------------------------===//

LogicalResult BroadcastToOp::verify() {
  // TODO(antiagainst): check that
  // * The 'shape' input is a 1-D int tensor.
  // * Each dimension pair of the source and target shapes is either equal,
  //   or one of them is one.
  return success();
}

OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
  Value input = this->input();

  // Fold broadcast if operand and result types are the same and all dimensions
  // are statically known (no-op broadcast).
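  // For example (illustrative): a tf.BroadcastTo whose input and result types
  // are both tensor<2x3xf32> folds to its input, and a broadcast of a splat
  // constant folds to a splat constant of the static result shape.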
  auto result_ty = getType().dyn_cast<ShapedType>();
  if (!result_ty || !result_ty.hasStaticShape()) return {};

  if (result_ty == input.getType()) return input;

  DenseIntElementsAttr cst_attr;
  if (!matchPattern(input, m_Constant(&cst_attr))) return {};
  if (!cst_attr.isSplat()) return {};

  return DenseElementsAttr::get(result_ty, cst_attr.getSplatValue<Attribute>());
}

//===----------------------------------------------------------------------===//
// BroadcastGradientArgsOp
//===----------------------------------------------------------------------===//

namespace {
// Returns `true` if both s0 & s1 are defined via constant ops, and fills
// s0_shape & s1_shape.
bool ExtractInputConstShape(BroadcastGradientArgsOp op,
                            DenseIntElementsAttr &s0, DenseIntElementsAttr &s1,
                            SmallVectorImpl<int64_t> &s0_shape,
                            SmallVectorImpl<int64_t> &s1_shape) {
  if (!matchPattern(op.s0(), m_Constant(&s0))) return false;
  if (!matchPattern(op.s1(), m_Constant(&s1))) return false;

  for (auto s : s0.getValues<APInt>()) s0_shape.push_back(s.getSExtValue());
  for (auto s : s1.getValues<APInt>()) s1_shape.push_back(s.getSExtValue());

  return true;
}

// Calculates the r0 & r1 outputs based on the inputs and the calculated
// broadcasted shape.
//
// For the given bcasted_shape, s0_shape, and s1_shape, the broadcasted
// dimensions are calculated and pushed back to the corresponding result, r0 or
// r1. For example, for s0_shape [1,4] and s1_shape [4,4], bcasted_shape is
// computed to be [4,4] - this leads to r0 being [0], as the first dimension of
// s0 is broadcasted, and r1 being <>, as no broadcasting happens for s1.
void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
                                            ArrayRef<int64_t> s0_shape,
                                            ArrayRef<int64_t> s1_shape,
                                            SmallVectorImpl<int64_t> &r0,
                                            SmallVectorImpl<int64_t> &r1) {
  r0.clear();
  r1.clear();

  // No broadcasting is required if both the shapes are equal.
  if (s0_shape == s1_shape) return;

  for (int i = bcasted_shape.size(); i > 0; --i) {
    int idx = bcasted_shape.size() - i;
    int s0_idx = i > s0_shape.size() ? -1 : s0_shape.size() - i;
    int s1_idx = i > s1_shape.size() ? -1 : s1_shape.size() - i;
    if (s0_idx == -1) {
      r0.push_back(idx);
      if (s1_shape[s1_idx] == 1) r1.push_back(idx);
    } else if (s1_idx == -1) {
      r1.push_back(idx);
      if (s0_shape[s0_idx] == 1) r0.push_back(idx);
    } else if (s0_shape[s0_idx] != s1_shape[s1_idx]) {
      if (s0_shape[s0_idx] != bcasted_shape[idx])
        r0.push_back(idx);
      else
        r1.push_back(idx);
    } else if (s0_shape[s0_idx] == 1) {
      // This op is used to compute the gradient dimensions requiring reduction
      // to match the input dimensions. In case both the dimensions are one,
      // reducing the dimension has no effect. We choose to reduce such
      // dimensions to match the TensorFlow kernel behavior. However, note that
      // the TF behavior in this case is inconsistent with the case with the
      // same shapes.
      r0.push_back(idx);
      r1.push_back(idx);
    }
  }
}
}  // namespace

// Verifies that,
// * the input shapes are broadcast compatible, and
// * the output shape dimensions match the expected dimension sizes for the
//   input shapes.
LogicalResult BroadcastGradientArgsOp::verify() {
  BroadcastGradientArgsOp op = *this;
  SmallVector<int64_t, 4> s0_shape, s1_shape;
  DenseIntElementsAttr s0, s1;
  if (!ExtractInputConstShape(op, s0, s1, s0_shape, s1_shape)) return success();

  // If both shapes are known constants, try to validate them as well.
  SmallVector<int64_t, 4> bcasted_shape;
  if (!OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape))
    return op.emitOpError() << "requires broadcast compatible shape tensors "
                               "for 's0' and 's1', but got "
                            << s0 << " and " << s1;

  SmallVector<int64_t, 4> r0, r1;
  GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
                                         r1);

  // Verify that the output types are of rank one and match the computed result
  // shape.
  auto r0_ty = op.r0().getType().dyn_cast<RankedTensorType>();
  auto r1_ty = op.r1().getType().dyn_cast<RankedTensorType>();
  if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size())
    return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
                            << r0.size() << " but got " << r0_ty.getShape()[0];
  if (r1_ty && r1_ty.hasStaticShape() && r1_ty.getDimSize(0) != r1.size())
    return op.emitOpError() << "requires dimension 0 size of 'r1' to be "
                            << r1.size() << " but got " << r1_ty.getShape()[0];

  return success();
}

LogicalResult BroadcastGradientArgsOp::fold(
    ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
  SmallVector<int64_t, 4> s0_shape, s1_shape;
  DenseIntElementsAttr s0, s1;
  if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape))
    return failure();

  // Fold BroadcastGradientArgs into two constants if both of the inputs have
  // known shape.
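  // For example (illustrative): s0 = [1, 4] and s1 = [4, 4] broadcast to
  // [4, 4], so the fold produces the constants r0 = [0] (dimension 0 of s0 is
  // broadcast) and r1 = [] (no reduction is needed for s1).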
  SmallVector<int64_t, 4> bcasted_shape;
  // Verifier should already ensure the broadcast compatibility.
  bool bcast_compatible =
      OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape);
  assert(bcast_compatible);
  (void)bcast_compatible;

  SmallVector<int64_t, 4> r0, r1;
  GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
                                         r1);

  auto build_out_dense_element = [](SmallVectorImpl<int64_t> &shape,
                                    Type input_type) {
    Type element_type = input_type.cast<mlir::TensorType>().getElementType();
    RankedTensorType type = RankedTensorType::get(
        {static_cast<int64_t>(shape.size())}, element_type);
    // Input could only be i32 or i64. For i32, downcast to int32_t array.
    if (element_type.isInteger(32)) {
      SmallVector<int32_t, 4> i32_shape;
      for (auto s : shape) i32_shape.push_back(static_cast<int32_t>(s));
      return DenseIntElementsAttr::get(type, i32_shape);
    } else {
      assert(element_type.isInteger(64));
      return DenseIntElementsAttr::get(type, shape);
    }
  };

  results.push_back(build_out_dense_element(r0, this->s0().getType()));
  results.push_back(build_out_dense_element(r1, this->s1().getType()));

  return success();
}

//===----------------------------------------------------------------------===//
// CaseOp
//===----------------------------------------------------------------------===//

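// Folds tf.Case with a constant branch_index into a call to the selected
// branch. For example (illustrative), with branch_index = dense<1>:
//   %r = "tf.Case"(%index, %arg) {branches = [@a, @b], ...}
// is rewritten to a tf.PartitionedCall of @b on %arg. An out-of-range index
// selects the last branch, matching the TensorFlow kernel behavior.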
class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> {
 public:
  explicit FoldConstantCaseOp(MLIRContext *context)
      : OpRewritePattern<TF::CaseOp>(context) {}
  LogicalResult matchAndRewrite(TF::CaseOp op,
                                PatternRewriter &rewriter) const override;
};

LogicalResult FoldConstantCaseOp::matchAndRewrite(
    TF::CaseOp op, PatternRewriter &rewriter) const {
  // Extract the constant branch_index value.
  DenseIntElementsAttr branch;
  if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();

  int index = *branch.getValues<int>().begin();
  if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;

  auto func = op.branches()[index].cast<SymbolRefAttr>();
  auto empty = rewriter.getStringAttr("");
  ReplaceTfOpWithNewOp<PartitionedCallOp>(
      rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func,
      /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
  return success();
}

void CaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConstantCaseOp, DropAttributes<CaseOp>>(context);
}

static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) {
  if (!IsOfRankOrUnranked(branch_index, 0))
    return op->emitOpError()
           << "expects 'branch_index' to be a scalar, but got "
           << branch_index.getType();
  return success();
}

static LogicalResult VerifyCaseOrIfOpBranchFunctions(
    SymbolTableCollection &symbol_table, Operation *op,
    ArrayRef<Attribute> branches,
    llvm::function_ref<std::string(unsigned branch_index)> branch_name) {
  SmallVector<FunctionType, 2> branch_types;
  branch_types.reserve(branches.size());

  if (llvm::any_of(op->getOperands(),
                   [](Value value) { return value == nullptr; }))
    return op->emitOpError("operation has null operand");

  // Functions have one less operand compared to the op, as the first operand
  // is elided (`cond` of `tf.If` and `branch_index` of `tf.Case`).
  TypeRangeWithDesc input{op->getOperands().drop_front().getTypes(), "input"};
  TypeRangeWithDesc result{op->getResultTypes(), "result"};

  for (auto branch : llvm::enumerate(branches)) {
    auto branch_func = symbol_table.lookupNearestSymbolFrom<func::FuncOp>(
        op, branch.value().cast<SymbolRefAttr>());
    if (!branch_func)
      return op->emitOpError()
             << "expects " << branch_name(branch.index()) << " ("
             << branch.value() << ") to point to a defined function";

    FunctionType branch_type = branch_func.getFunctionType();
    std::string desc = branch_name(branch.index()) + " input";
    TypeRangeWithDesc branch_input{branch_type.getInputs(), desc};
    if (failed(VerifyTypeRangesAreCompatible(op, branch_input, input)))
      return failure();

    desc = branch_name(branch.index()) + " result";
    TypeRangeWithDesc branch_result{branch_type.getResults(), desc};
    if (failed(VerifyTypeRangesAreCompatible(op, branch_result, result)))
      return failure();

    branch_types.push_back(branch_type);
  }

  // If branches have incompatible input types, then no tensor can serve as
  // input to all the functions; hence, the op is invalid.
  int expected_num_inputs = op->getNumOperands() - 1;
  for (int i = 0; i < expected_num_inputs; ++i) {
    SmallVector<Type, 2> branch_input_i_types;
    branch_input_i_types.reserve(branches.size());
    llvm::transform(
        branch_types, std::back_inserter(branch_input_i_types),
        [i](FunctionType &branch_type) { return branch_type.getInput(i); });
    if (!AreCastCompatible(branch_input_i_types)) {
      std::string input_types_str;
      llvm::raw_string_ostream os(input_types_str);
      llvm::interleaveComma(branch_input_i_types, os);
      return op->emitOpError()
             << "expects all branch input type(s) (" << os.str()
             << ") at index " << i << " to be cast compatible";
    }
  }

  return success();
}

LogicalResult CaseOp::verify() {
  CaseOp op = *this;
  return VerifyCaseOpBase(op, op.branch_index());
}

LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
  auto branch_name = [](unsigned index) {
    return llvm::formatv("branch #{0}", index).str();
  };
  return VerifyCaseOrIfOpBranchFunctions(symbol_table, *this,
                                         branches().getValue(), branch_name);
}

//===----------------------------------------------------------------------===//
// CaseRegionOp
//===----------------------------------------------------------------------===//

LogicalResult CaseRegionOp::verify() {
  CaseRegionOp op = *this;
  if (op.branches().empty())
    return op.emitOpError() << "expects to have at least 1 region";

  if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();

  TypeRangeWithDesc results{op.getResultTypes(), "result"};

  for (auto region_and_idx : llvm::enumerate(op.branches())) {
    std::string description =
        llvm::formatv("branch #{0} result", region_and_idx.index()).str();
    Operation *yield = region_and_idx.value().front().getTerminator();
    TypeRangeWithDesc branch_results{yield->getOperandTypes(), description};
    if (failed(VerifyTypeRangesAreCompatible(op, branch_results, results)))
      return failure();
  }

  return success();
}

namespace {
// Eliminate values that pass through the CaseRegionOp or IfRegionOp branches.
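//
// For example (illustrative): if every branch terminator yields the same
// value %v defined above the op,
//   %r = "tf.IfRegion"(%cond) ({"tf.Yield"(%v)}, {"tf.Yield"(%v)})
// then uses of %r are replaced with %v and the result is removed from the op.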
template <class CaseOrIfRegionOp>
class CaseOrIfRegionEliminatePassThrough
    : public OpRewritePattern<CaseOrIfRegionOp> {
  using OpRewritePattern<CaseOrIfRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CaseOrIfRegionOp op,
                                PatternRewriter &rewriter) const override {
    RegionRange branches = op.getRegions();
    SmallVector<Type, 4> new_result_types;
    // Maps pass through results to extern values.
    llvm::SmallDenseMap<Value, Value, 4> result_to_extern_value;

    for (auto result : op.getResults()) {
      unsigned index = result.getResultNumber();
      Region *first_branch = *branches.begin();
      Operation *first_terminator = first_branch->front().getTerminator();
      Value returned_val = first_terminator->getOperand(index);

      // Pass through values are defined outside the branch region. Keep the
      // types of non pass through results to create a new op later, if
      // required.
      if (returned_val.getParentBlock() == &first_branch->front()) {
        new_result_types.push_back(result.getType());
        continue;
      }
      // Check if the same extern value is returned in each branch.
      for (Region *region : branches.drop_front()) {
        Operation *terminator = region->front().getTerminator();
        if (terminator->getOperand(index) != returned_val) return failure();
      }
      result_to_extern_value[result] = returned_val;
    }

    // If no pass through values are found, no change is required.
    if (result_to_extern_value.empty()) return failure();

    // Create new case/if region op.
    auto new_op = rewriter.create<CaseOrIfRegionOp>(
        op.getLoc(), new_result_types, op.getOperand(), op->getAttrs(),
        op.getNumRegions());

    int next_index = 0;
    for (auto result : op.getResults()) {
      if (!result_to_extern_value.count(result)) {
        result.replaceAllUsesWith(new_op.getResult(next_index++));
        continue;
      }
      result.replaceAllUsesWith(result_to_extern_value[result]);
      for (Region *branch : branches)
        branch->front().getTerminator()->eraseOperand(next_index);
    }

    // Move region bodies to the new op.
    for (auto region_index : llvm::seq<int>(0, branches.size()))
      new_op.getRegion(region_index).takeBody(op.getRegion(region_index));

    op.erase();
    return success();
  }
};
}  // namespace

void CaseRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CaseOrIfRegionEliminatePassThrough<TF::CaseRegionOp>>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  // Cast with the same type is a no-op.
  Value operand = getOperand();
  if (getType() == operand.getType()) return operand;
  return {};
}

//===----------------------------------------------------------------------===//
// ConcatOp and ConcatV2Op
//===----------------------------------------------------------------------===//

template <typename OpT,
          typename std::enable_if<llvm::is_one_of<
              OpT, ConcatOp, ConcatV2Op>::value>::type * = nullptr>
static LogicalResult Verify(OpT op) {
  // TODO(hinsu): Convert variadic length attributes to derived attributes.
  Operation::operand_range values = op.values();

  int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
  Value axis = *op.getODSOperands(axis_idx).begin();
  if (!HasRankAtMost(axis, 1)) {
    return op.emitOpError(
        "requires axis to be of scalar type (or vector type for older "
        "versions)");
  }

  return VerifyTypesCompatibility(values,
                                  /*mask_one_dim=*/true, op.getOperation());
}

LogicalResult ConcatOp::verify() { return Verify(*this); }
LogicalResult ConcatV2Op::verify() { return Verify(*this); }

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConvertToConcatV2>(context);
}

namespace {

// Hoist coefficient-wise unary operation out of the Concat op:
//
//   %0 = "tf.Log1p"(%arg_0)
//   %1 = "tf.Log1p"(%arg_1)
//   ...
//   %n = "tf.Log1p"(%arg_n)
//   %m = "tf.ConcatV2"(%0, %1, ..., %n, %axis)
//
// Rewrite it to:
//
//   %0 = "tf.ConcatV2"(%arg_0, %arg_1, ..., %arg_n, %axis)
//   %1 = "tf.Log1p"(%0)
class HoistCwiseUnaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
 public:
  explicit HoistCwiseUnaryOutOfConcat(MLIRContext *context)
      : OpRewritePattern<TF::ConcatV2Op>(context) {}
  LogicalResult matchAndRewrite(TF::ConcatV2Op op,
                                PatternRewriter &rewriter) const override;
};

LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
    TF::ConcatV2Op op, PatternRewriter &rewriter) const {
  auto loc = op.getLoc();

  // All concat operands must be defined by ops.
  Operation *first_arg_op = op.values().front().getDefiningOp();
  if (first_arg_op == nullptr) return failure();

  // All concat operands must be produced by the coeff-wise unary operation.
  if (!first_arg_op->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();

  // All concat operands must be defined by ops of the same kind.
  bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
    Operation *arg_op = arg.getDefiningOp();
    return arg_op && arg_op->getName() == first_arg_op->getName();
  });
  if (!args_same_op) return failure();

  // Collect unary operations operands.
  auto unary_operands = llvm::map_range(op.values(), [](Value arg) -> Value {
    return arg.getDefiningOp()->getOperand(0);
  });
  SmallVector<Value, 8> unary_ops_args(unary_operands);

  // Concatenate unary ops operands.
  auto concat_unary_operands =
      rewriter.create<ConcatV2Op>(loc, op.getType(), unary_ops_args, op.axis());

  // Replace original concat with a unary op.
  OperationState new_unary_op_state(loc, first_arg_op->getName().getStringRef(),
                                    concat_unary_operands.getResult(),
                                    op.getResult().getType(),
                                    ArrayRef<NamedAttribute>());
  Operation *new_unary_op = rewriter.create(new_unary_op_state);

  rewriter.replaceOp(op, new_unary_op->getResults());

  return success();
}

// Hoist coefficient-wise binary operation out of the Concat op:
//
//   %0 = tf.Mul(%lhs_0, %rhs_0)
//   %1 = tf.Mul(%lhs_1, %rhs_1)
//   ...
//   %n = tf.Mul(%lhs_n, %rhs_n)
//   %m = tf.ConcatV2(%0, %1, ..., %n, %axis)
//
// Rewrite it to:
//
//   %0 = tf.ConcatV2(%lhs0, %lhs1, ..., %lhs_n, %lhs_concat_axis)
//   %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
//   %2 = tf.Mul(%0, %1)
//
// If a minor fraction of the Concat inputs are not of the same binary op kind
// (tf.Mul in the above example), we will synthesize the binary ops for those
// inputs. e.g. if we instead have %1 = %lhs_1, then we would synthesize a
// tf.Mul op over it and a scalar const tensor 1.0. For now this only applies to
// float32 tensors.
// TODO(hongm): Implement this op synthesis optimization for other dtypes if
// needed.
//
// Because coefficient-wise binary operations support implicit broadcasting, we
// should be very careful with this optimization so as not to accidentally
// produce incorrect concat operations.
class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
 public:
  explicit HoistCwiseBinaryOutOfConcat(MLIRContext *context)
      : OpRewritePattern<TF::ConcatV2Op>(context) {}
  LogicalResult matchAndRewrite(TF::ConcatV2Op op,
                                PatternRewriter &rewriter) const override;

 private:
  struct HoistParams {
    SmallVector<Value, 8> lhs_args;
    SmallVector<Value, 8> rhs_args;
    int64_t lhs_axis;
    int64_t rhs_axis;
    Type lhs_concat_type;
    Type rhs_concat_type;
    int scalar_operand_idx;  // can be 0 or 1 for the binary op's operands.
  };

  // Returns parameters of a binary op hoisting out of concatenation if all of
  // the operands are in one of the compatible configurations.
  // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
  // except for the ones in `exceptions`. In that case, we can synthesize that
  // binary op kind for the values in `exceptions`.
  Optional<HoistParams> GetHoistParams(
      TF::ConcatV2Op op, int64_t axis,
      const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
};

LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
    TF::ConcatV2Op op, PatternRewriter &rewriter) const {
  auto loc = op.getLoc();

  // Axis must be a constant scalar value.
  DenseIntElementsAttr axis_attr;
  if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure();
  if (axis_attr.getNumElements() != 1) return failure();
  int64_t axis =
      axis_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
  // TODO(ezhulenev): Compute axis from rank. e.g. It might be common to concat
  // on the channels dim for NCHW layout as axis=-2.
  if (axis < 0) return failure();

  // All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
  // or some other ops that we might convert to using the same op kind above
  // (e.g. converting op A to tf.Mul(A, 1.0))
  // TODO(hongm): generalize the code here to support cases where the first arg
  // has no defining op (e.g. might be a block arg).
  Operation *first_arg_op = op.values().front().getDefiningOp();
  if (first_arg_op == nullptr) return failure();

  // All concat operands must be produced by the coeff-wise binary operation.
  if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();

  // All concat operands must be defined by ops of the same kind, except for a
  // minor portion, which we track in `exceptions`.
  // Map from the operands to operand indices.
  llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
  unsigned operand_idx = 0;
  for (Value arg : op.values()) {
    Operation *arg_op = arg.getDefiningOp();
    if (arg_op && arg_op->getName() == first_arg_op->getName()) {
      ++operand_idx;
      continue;
    }
    exceptions[arg] = operand_idx++;
  }
  // Recall that those inputs to the concat op that are not produced by a
  // binary op of the `first_arg_op` kind (e.g. tf.Mul) are stored in
  // `exceptions`. If there are too many exceptions, it might not be cost
  // effective to apply the concat hoisting optimization here.
  // We set the threshold to 50% as a simple cost model heuristic, e.g. if 1
  // out of 2 concat inputs is an exception, we don't apply the hoist; if it's
  // 1 out of 3, we do.
  const float exception_pct_threshold = 0.5;
  if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
      exceptions.size())
    return failure();

  // Compute binary operands hoist parameters.
  auto hoist_params = GetHoistParams(op, axis, exceptions);
  if (!hoist_params.has_value()) return failure();

  // Process `exceptions`: For each value there, synthesize a binary op of the
  // above kind, so that the concat hoisting optimization can still apply.
  if (!exceptions.empty()) {
    int identity_val;
    if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op))
      identity_val = 0;
    else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) ||
             isa<RealDivOp>(first_arg_op))
      identity_val = 1;
    else
      return failure();
    DenseElementsAttr const_attr;
    auto scalar_tensor_type =
        first_arg_op->getOperand(hoist_params->scalar_operand_idx)
            .getType()
            .dyn_cast<ShapedType>();
    Type scalar_dtype = scalar_tensor_type.getElementType();
    if (scalar_dtype.isa<FloatType>())
      const_attr = DenseElementsAttr::get(scalar_tensor_type,
                                          static_cast<float>(identity_val));
    else
      return failure();

1218     // All checks have passed; now prepare for the rewrite.
1219     auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
1220     for (const auto &kv : exceptions) {
1221       assert(!hoist_params->lhs_args[kv.second]);
1222       assert(!hoist_params->rhs_args[kv.second]);
1223 
1224       if (hoist_params->scalar_operand_idx == 1) {
1225         hoist_params->lhs_args[kv.second] = kv.first;
1226         hoist_params->rhs_args[kv.second] = identity_const;
1227       } else {
1228         assert(hoist_params->scalar_operand_idx == 0);
1229         hoist_params->lhs_args[kv.second] = identity_const;
1230         hoist_params->rhs_args[kv.second] = kv.first;
1231       }
1232     }
1233   }
1234 
1235   // Concatenates `args` along `axis`.
1236   auto pack_or_concat = [&](bool is_scalar, Type result_type, ValueRange args,
1237                             int64_t axis) {
1238     // Use `PackOp` for scalar concatenation because `ConcatV2Op` doesn't
1239     // support scalar concatenation.
1240     if (is_scalar) {
1241       auto pack = rewriter.create<PackOp>(loc, result_type, args,
1242                                           rewriter.getI64IntegerAttr(axis));
1243       return pack.getResult();
1244     }
1245 
1246     // New concatenation axis.
1247     auto axis_type = RankedTensorType::get({}, getElementTypeOrSelf(axis_attr));
1248     DenseIntElementsAttr attr;
1249     if (axis_type.getElementType().isInteger(32)) {
1250       attr = DenseIntElementsAttr::get(axis_type, static_cast<int32_t>(axis));
1251     } else {
1252       assert(axis_type.getElementType().isInteger(64));
1253       attr = DenseIntElementsAttr::get(axis_type, axis);
1254     }
1255     auto axis_const = rewriter.create<TF::ConstOp>(loc, attr);
1256 
1257     auto concat =
1258         rewriter.create<ConcatV2Op>(loc, result_type, args, axis_const);
1259     return concat.getResult();
1260   };
1261 
1262   // Concatenate binary ops operands on the new axis.
1263   Value lhs_concat = pack_or_concat(
1264       hoist_params->scalar_operand_idx == 0, hoist_params->lhs_concat_type,
1265       hoist_params->lhs_args, hoist_params->lhs_axis);
1266   Value rhs_concat = pack_or_concat(
1267       hoist_params->scalar_operand_idx == 1, hoist_params->rhs_concat_type,
1268       hoist_params->rhs_args, hoist_params->rhs_axis);
1269 
1270   // Replace original concat with a binary op.
1271   OperationState new_binary_op_state(
1272       loc, first_arg_op->getName().getStringRef(), {lhs_concat, rhs_concat},
1273       op.getResult().getType(), ArrayRef<NamedAttribute>());
1274   Operation *new_binary_op = rewriter.create(new_binary_op_state);
1275 
1276   rewriter.replaceOp(op, new_binary_op->getResults());
1277 
1278   return success();
1279 }
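
// A sketch of the rewrite above on a hypothetical example (schematic IR, not
// verified syntax): given concat operands that are all produced by tf.Mul
// against scalars,
//
//   %m0 = "tf.Mul"(%a, %s0) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
//   %m1 = "tf.Mul"(%b, %s1) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
//   %r  = "tf.ConcatV2"(%m0, %m1, %axis)
//
// the pattern produces roughly
//
//   %lhs = "tf.ConcatV2"(%a, %b, %axis)
//   %rhs = "tf.Pack"(%s0, %s1)  // scalars are packed, not concatenated
//   %r   = "tf.Mul"(%lhs, %rhs)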
1280 
1281 Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
1282 HoistCwiseBinaryOutOfConcat::GetHoistParams(
1283     TF::ConcatV2Op op, int64_t axis,
1284     const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
1285   assert(axis >= 0);
1286   // Collects lhs or rhs arguments of concat op operands.
1287   auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
1288     auto range = llvm::map_range(op.values(), [&](Value arg) {
1289       if (exceptions.count(arg)) return Value();
1290       return arg.getDefiningOp()->getOperand(operand_idx);
1291     });
1292     return {range.begin(), range.end()};
1293   };
1294 
1295   // Returns true if every binary op operand at index `operand_idx` is a tensor
1296   // of rank `axis + 1` whose dimension at `axis` has size `1`.
1297   auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
1298     return llvm::all_of(op.values(), [&](Value arg) -> bool {
1299       mlir::Value operand;
1300       if (exceptions.count(arg)) {
1301         // For exceptions, since we are going to synthesize a binary op that
1302         // produces the identity value, it is also required to be a ranked
1303         // tensor of rank `axis + 1` whose dimension at `axis` has size `1`.
1304         operand = arg;
1305       } else {
1306         operand = arg.getDefiningOp()->getOperand(operand_idx);
1307       }
1308       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1309       return ranked && ranked.getRank() == (axis + 1) &&
1310              ranked.getShape()[axis] == 1;
1311     });
1312   };
1313 
1314   // Returns true if every binary op operand at index `operand_idx` is a scalar.
1315   auto is_all_scalars = [&](int operand_idx) -> bool {
1316     return llvm::all_of(op.values(), [&](Value arg) -> bool {
1317       if (exceptions.count(arg)) return true;
1318       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
1319       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1320       return ranked && ranked.getRank() == 0;
1321     });
1322   };
1323 
1324   // Concat result type must be a ranked tensor.
1325   auto ranked = op.getType().dyn_cast<RankedTensorType>();
1326   if (!ranked) return None;
1327 
1328   // TODO(ezhulenev): Add support for more valid concat patterns.
1329 
1330   // Tensor + Scalar: [..., 1] + []  <- scalar
1331   //                        ^
1332   //                        \- axis is the innermost dimension.
1333   //
1334   // Concatenate tensor arguments on the same axis as the original operation,
1335   // and pack the scalars into a vector.
1336   if (is_all_tensors(0, axis) && is_all_scalars(1)) {
1337     std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
1338     auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
1339     return HoistParams{args(0),
1340                        args(1),
1341                        axis,
1342                        0,
1343                        op.getType(),
1344                        rhs_type,
1345                        /*scalar_operand_idx=*/1};
1346   } else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
1347     std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
1348     auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
1349     return HoistParams{args(0),
1350                        args(1),
1351                        0,
1352                        axis,
1353                        lhs_type,
1354                        op.getType(),
1355                        /*scalar_operand_idx=*/0};
1356   }
1357   return None;
1358 }
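
// Example of the supported pattern (illustrative): concatenating three
// tensor<?x1xf32> values whose defining ops each take a scalar rhs, with
// `axis` as the innermost dimension, yields lhs_concat_type equal to the
// original result type and rhs_concat_type = tensor<3xf32>, i.e. one packed
// scalar per concat operand.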
1359 
1360 }  // namespace
1361 
1362 void ConcatV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
1363                                              MLIRContext *context) {
1364   results.add<HoistCwiseBinaryOutOfConcat, HoistCwiseUnaryOutOfConcat>(context);
1365 }
1366 
1367 //===----------------------------------------------------------------------===//
1368 // CumsumOp and CumprodOp
1369 //===----------------------------------------------------------------------===//
1370 
1371 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1372                             OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
1373 static LogicalResult Verify(OpT op) {
1374   if (!IsOfRankOrUnranked(op.axis(), 0))
1375     return op.emitOpError("requires scalar axis operand");
1376 
1377   DenseIntElementsAttr axis_attr;
1378   if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
1379     auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
1380     if (input_ty) {
1381       int64_t rank = input_ty.getRank();
1382       assert(axis_attr.getNumElements() == 1 &&
1383              "scalar attribute should have exactly one element");
1384       int64_t axis = (*axis_attr.begin()).getSExtValue();
1385       if (axis < -rank || axis >= rank) {
1386         return op.emitError()
1387                << "axis operand should be within range [" << -rank << ", "
1388                << rank << "); actual value: " << axis;
1389       }
1390     }
1391   }
1392 
1393   return success();
1394 }
1395 LogicalResult CumprodOp::verify() { return Verify(*this); }
1396 LogicalResult CumsumOp::verify() { return Verify(*this); }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // ConcatOffsetOp
1400 //===----------------------------------------------------------------------===//
1401 
1402 LogicalResult ConcatOffsetOp::verify() {
1403   ConcatOffsetOp op = *this;
1404   if (op.N() < 2)
1405     return op.emitOpError() << "requires N to be at least 2, got " << op.N();
1406 
1407   if (op.shape().size() != op.offset().size())
1408     return op.emitOpError()
1409            << "requires sizes of shapes and offsets to be the same, got sizes "
1410            << op.shape().size() << " and " << op.offset().size();
1411 
1412   auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>();
1413   if (ranked_dim && ranked_dim.getRank() != 0)
1414     return op.emitOpError()
1415            << "requires concat_dim to be a scalar, got tensor of rank "
1416            << ranked_dim.getRank();
1417 
1418   int64_t num_dims = -1;
1419   for (auto shape_offset_idx :
1420        llvm::enumerate(llvm::zip(op.shape(), op.offset()))) {
1421     Value shape = std::get<0>(shape_offset_idx.value());
1422     Value offset = std::get<1>(shape_offset_idx.value());
1423     const size_t idx = shape_offset_idx.index();
1424 
1425     if (failed(verifyCompatibleShape(shape.getType(), offset.getType())))
1426       return op.emitOpError() << "requires operand and result " << idx
1427                               << " to have compatible shapes";
1428 
1429     auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
1430     if (!ranked_shape) continue;
1431 
1432     if (ranked_shape.getRank() != 1)
1433       return op.emitOpError() << "requires shape tensor operand " << idx
1434                               << " to be of rank 1, got tensor of rank "
1435                               << ranked_shape.getRank();
1436 
1437     if (!ranked_shape.hasStaticShape()) continue;
1438 
1439     int64_t ranked_shape_dim = ranked_shape.getDimSize(0);
1440     if (num_dims == -1)
1441       num_dims = ranked_shape_dim;
1442     else if (ranked_shape_dim != num_dims)
1443       return op.emitOpError()
1444              << "requires shape tensor (rank 1) operand " << idx
1445              << " to be of length " << num_dims
1446              << ", got tensor (rank 1) of length " << ranked_shape_dim;
1447   }
1448 
1449   return success();
1450 }
1451 
1452 LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
1453                                    SmallVectorImpl<OpFoldResult> &results) {
1454   // ConcatOffset must have its first operand be concat_dim and at least two
1455   // shape tensors in the variadic shapes operand.
1456   if (operands.size() < 3) return failure();
1457 
1458   // Check concat_dim is a scalar.
1459   auto concat_dim_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1460   if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0)
1461     return failure();
1462 
1463   llvm::SmallVector<DenseIntElementsAttr, 4> shapes;
1464   shapes.reserve(operands.size() - 1);
1465   for (Attribute shape : llvm::drop_begin(operands, 1))
1466     if (auto shape_attr = shape.dyn_cast_or_null<DenseIntElementsAttr>())
1467       shapes.push_back(shape_attr);
1468     else
1469       return failure();
1470 
1471   // Check all shapes are vectors of the same length.
1472   if (shapes.front().getType().getRank() != 1) return failure();
1473   const int64_t num_dims = shapes.front().getNumElements();
1474   for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1))
1475     if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims)
1476       return failure();
1477 
1478   // Check concat_dim is within [-num_dims, num_dims).
1479   int32_t concat_dim = (*concat_dim_attr.getValues<int32_t>().begin());
1480   if (concat_dim < 0) concat_dim += num_dims;
1481   if (concat_dim >= num_dims || concat_dim < 0) return failure();
1482 
1483   // Check all elements except the one at concat_dim match across all shapes.
1484   SmallVector<int32_t, 4> shape0;
1485   shape0.reserve(num_dims);
1486   for (int32_t dim : shapes.front().getValues<int32_t>()) shape0.push_back(dim);
1487 
1488   for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) {
1489     for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) {
1490       if (dims_and_idx.index() == concat_dim) continue;
1491 
1492       if (std::get<0>(dims_and_idx.value()) !=
1493           std::get<1>(dims_and_idx.value()).getSExtValue())
1494         return failure();
1495     }
1496   }
1497 
1498   // Compute an exclusive cumulative sum of elements at concat_dim.
1499   results.reserve(shapes.size());
1500   SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
1501   RankedTensorType offset_type =
1502       RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
1503   for (DenseIntElementsAttr shape : shapes) {
1504     results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
1505     cumulative_sum[concat_dim] += shape.getValues<int32_t>()[concat_dim];
1506   }
1507 
1508   return success();
1509 }
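
// Worked example (illustrative): for concat_dim = 1 and constant shapes
// [2, 3, 5] and [2, 7, 5], all dims except dim 1 match, so the fold emits the
// exclusive cumulative sum along dim 1:
//   offsets[0] = [0, 0, 0]
//   offsets[1] = [0, 3, 0]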
1510 
1511 //===----------------------------------------------------------------------===//
1512 // ConstOp
1513 //===----------------------------------------------------------------------===//
1514 
1515 void ConstOp::getAsmResultNames(
1516     function_ref<void(Value, StringRef)> setNameFn) {
1517   setNameFn(getResult(), "cst");
1518 }
1519 
1520 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
1521   assert(operands.empty() && "constant has no operands");
1522 
1523   // Return the held attribute value.
1524   return value();
1525 }
1526 
1527 // Builds a constant op with the specified attribute `value`. The result
1528 // op's type is deduced from `value`; if `value` is of scalar type, it is
1529 // wrapped in a tensor type of empty shape.
1530 // TODO(jpienaar): This one differs from the autogenerated one as it takes an
1531 // attribute but always creates an ElementsAttr internally.
1532 void ConstOp::build(OpBuilder &builder, OperationState &result,
1533                     Attribute value) {
1534   ShapedType type;
1535   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1536     return ConstOp::build(builder, result, elem_attr);
1537   } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
1538     // All TensorFlow types must be tensor types. In the build() method,
1539     // we want to provide more flexibility by allowing attributes of scalar
1540     // types. But we need to wrap it up with ElementsAttr to construct
1541     // valid TensorFlow constants.
1542     auto typed_attr = value.cast<TypedAttr>();
1543     type = RankedTensorType::get(/*shape=*/{}, typed_attr.getType());
1544     return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
1545   }
1546   // TODO(jpienaar): support other TensorFlow specific types.
1547   llvm_unreachable("unsupported attribute type for building tf.Const");
1548 }
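
// Usage sketch (hypothetical call site): building from a scalar attribute,
//
//   builder.create<TF::ConstOp>(loc, builder.getF32FloatAttr(1.0f));
//
// wraps the scalar in a rank-0 DenseElementsAttr, so the result type is
// tensor<f32> rather than a bare f32.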
1549 
1550 void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
1551                     Attribute value) {
1552   // Handle the case where the type and value are already tensors.
1553   if (type.isa<TensorType>() && value.isa<ElementsAttr>()) {
1554     result.addTypes(type);
1555     result.addAttribute("value", value);
1556     return;
1557   }
1558 
1559   // Otherwise, default to the attribute builder.
1560   ConstOp::build(builder, result, value);
1561   assert(type == result.types[0] && "type mismatch in construction");
1562 }
1563 
1564 LogicalResult ConstOp::inferReturnTypes(
1565     MLIRContext *context, Optional<Location> location, ValueRange operands,
1566     DictionaryAttr attributes, RegionRange regions,
1567     SmallVectorImpl<Type> &inferredReturnTypes) {
1568   auto value = attributes.get("value");
1569   if (!value) return emitOptionalError(location, "missing attribute 'value'");
1570   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1571     inferredReturnTypes.assign({elem_attr.getType()});
1572     return success();
1573   }
1574   return emitOptionalError(location,
1575                            "attribute 'value' failed to satisfy constraint: "
1576                            "constant vector/tensor");
1577 }
1578 
1579 //===----------------------------------------------------------------------===//
1580 // Conv2DOp and Conv3DOp
1581 //===----------------------------------------------------------------------===//
1582 
1583 static LogicalResult VerifyConvOpAttributes(
1584     int num_dims, ArrayRef<Attribute> strides, ArrayRef<Attribute> dilations,
1585     llvm::Optional<mlir::Location> location) {
1586   int64_t strides_size = strides.size();
1587   if (strides_size != num_dims)
1588     return emitOptionalError(
1589         location, "requires strides attribute length to be ", num_dims);
1590   auto is_not_positive = [](Attribute val) {
1591     return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0;
1592   };
1593   if (llvm::any_of(strides, is_not_positive))
1594     return emitOptionalError(location, "requires positive strides");
1595 
1596   int64_t dilations_size = dilations.size();
1597   if (dilations_size != num_dims)
1598     return emitOptionalError(
1599         location, "requires dilations attribute length to be ", num_dims);
1600   if (llvm::any_of(dilations, is_not_positive))
1601     return emitOptionalError(location, "requires positive dilations");
1602 
1603   return success();
1604 }
1605 
1606 // Verifies that,
1607 // * Number of input channels is divisible by the number of filter input
1608 //   channels
1609 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1610                             OpT, Conv2DOp, Conv3DOp>::value>::type * = nullptr>
1611 static LogicalResult Verify(OpT op) {
1612   int num_spatial_dims = std::is_same<OpT, Conv2DOp>() ? 2 : 3;
1613   int num_dims = 2 + num_spatial_dims;
1614 
1615   StringRef data_format = op.data_format();
1616   tensorflow::TensorFormat format;
1617   auto data_format_is_valid = FormatFromString(data_format.str(), &format);
1618   if (!data_format_is_valid) {
1619     return emitOptionalError(op.getLoc(), "Invalid data format provided");
1620   }
1621 
1622   const StringRef paddings = op.padding();
1623   tensorflow::Padding padding;
1624   auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1625   if (!padding_is_valid.ok()) {
1626     return emitOptionalError(op.getLoc(), "Invalid padding format provided");
1627   }
1628 
1629   // Verifies that,
1630   // * Ranks of operands and result are valid
1631   // * Length of explicit_paddings attribute is valid and has non-negative
1632   //   elements
1633   // * strides and dilations attributes have positive elements
1634   if (!IsOfRankOrUnranked(op.input(), num_dims) ||
1635       !IsOfRankOrUnranked(op.filter(), num_dims))
1636     return emitOptionalError(op.getLoc(), "requires operands to be ", num_dims,
1637                              "D tensor");
1638 
1639   if (padding == tensorflow::Padding::EXPLICIT) {
1640     ArrayRef<Attribute> explicit_padding;
1641     ArrayAttr explicit_pad =
1642         op->getAttr("explicit_paddings")
1643             .template dyn_cast_or_null<::mlir::ArrayAttr>();
1644     if (!explicit_pad) {
1645       explicit_pad = ::mlir::Builder(op->getContext()).getI64ArrayAttr({});
1646     }
1647     explicit_padding = explicit_pad.getValue();
1648 
1649     if (explicit_padding.empty()) {
1650       return emitOptionalError(op.getLoc(),
1651                                "requires attribute 'explicit_paddings' with "
1652                                "'EXPLICIT' padding mode");
1653     }
1654     if (explicit_padding.size() != num_dims * 2) {
1655       return emitOptionalError(
1656           op.getLoc(), "requires explicit_paddings attribute length to be ",
1657           num_dims * 2);
1658     }
1659     auto is_negative = [](Attribute val) {
1660       return val.cast<IntegerAttr>().getValue().getSExtValue() < 0;
1661     };
1662     if (llvm::any_of(explicit_padding, is_negative))
1663       return emitOptionalError(op.getLoc(),
1664                                "requires non-negative explicit paddings");
1665   }
1666 
1667   ArrayRef<Attribute> strides = op.strides().getValue();
1668   ArrayRef<Attribute> dilations = op.dilations().getValue();
1669   if (failed(
1670           VerifyConvOpAttributes(num_dims, strides, dilations, op.getLoc()))) {
1671     return failure();
1672   }
1673 
1674   int64_t input_channels = -1;
1675   if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
1676     absl::string_view data_format(op.data_format().data(),
1677                                   op.data_format().size());
1678     tensorflow::TensorFormat format;
1679     auto is_valid = FormatFromString(data_format, &format);
1680     DCHECK(is_valid) << data_format;
1681     int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format);
1682     input_channels = ty.getDimSize(idx);
1683   }
1684 
1685   int64_t filter_channels = -1;
1686   if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) {
1687     int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
1688         num_dims, tensorflow::FORMAT_HWIO);
1689     filter_channels = ty.getDimSize(idx);
1690   }
1691 
1692   if (input_channels != -1 && filter_channels != -1 &&
1693       input_channels % filter_channels != 0)
1694     return op.emitOpError()
1695            << "requires the number of input channels to be divisible by the "
1696               "number of filter input channels; found "
1697            << input_channels << " and " << filter_channels << ", respectively";
1698 
1699   return success();
1700 }
1701 
1702 LogicalResult Conv2DOp::verify() { return Verify(*this); }
1703 LogicalResult Conv3DOp::verify() { return Verify(*this); }
1704 
1705 LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
1706   auto perm = GetDataFormatPermutation(this->data_format(), data_format);
1707   if (perm.empty()) return failure();
1708 
1709   // Update data_format attribute and result types.
1710   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1711 
1712   // Update convolution attributes.
1713   (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1714   (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1715   (*this)->setAttr("explicit_paddings",
1716                    ShuffleArrayAttr(explicit_paddings(), perm, 2));
1717 
1718   return success();
1719 }
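
// Example (illustrative): updating "NHWC" -> "NCHW" uses the permutation
// [0, 3, 1, 2], so strides [1, 2, 2, 1] become [1, 1, 2, 2]; explicit_paddings
// is shuffled the same way, two entries per dimension.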
1720 
1721 // Verifies the inferred return type of the given operation.
1722 template <typename OpT,
1723           typename std::enable_if<llvm::is_one_of<
1724               OpT, Conv2DOpAdaptor, Conv3DOpAdaptor>::value>::type * = nullptr>
1725 static LogicalResult inferConvReturnTypeComponents(
1726     llvm::Optional<mlir::Location> location, OpT op,
1727     ArrayRef<Attribute> explicit_padding,
1728     llvm::SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1729   const int64_t num_spatial_dims = std::is_same<OpT, Conv2DOpAdaptor>() ? 2 : 3;
1730   const int64_t num_dims = 2 + num_spatial_dims;
1731   const Value input = op.input();
1732   const Value filter = op.filter();
1733   const TensorType input_ty = input.getType().template cast<TensorType>();
1734   const TensorType filter_ty = filter.getType().template cast<TensorType>();
1735 
1736   ArrayRef<Attribute> strides = op.strides().getValue();
1737   StringRef data_format = op.data_format();
1738   ArrayRef<Attribute> dilations = op.dilations().getValue();
1739 
1740   tensorflow::TensorFormat format;
1741   auto data_format_is_valid = FormatFromString(data_format.str(), &format);
1742   assert(data_format_is_valid);
1743   (void)data_format_is_valid;
1744 
1745   tensorflow::Padding padding;
1746   const StringRef paddings = op.padding();
1747   auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1748   assert(padding_is_valid.ok());
1749   (void)padding_is_valid;
1750 
1751   auto get_int = [](Attribute attr) {
1752     return attr.template cast<IntegerAttr>().getInt();
1753   };
1754 
1755   // The output always has rank `num_dims`. All dimensions are initialized to
1756   // dynamic size and can be partially inferred.
1757   SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize);
1758   // The output batch and channel dimensions can be obtained using utilities
1759   // from tensorflow/core/util/tensor_format.h.
1760   if (input_ty.hasRank()) {
1761     return_shape[GetTensorBatchDimIndex(num_dims, format)] =
1762         input_ty.getDimSize(GetTensorBatchDimIndex(num_dims, format));
1763   }
1764   if (filter_ty.hasRank()) {
1765     return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
1766         filter_ty.getDimSize(GetFilterTensorOutputChannelsDimIndex(
1767             num_dims, tensorflow::FORMAT_HWIO));
1768   }
1769   // Spatial dimensions can be inferred only when both the input and the
1770   // filter are ranked, because we need the static sizes of both.
1771   if (input_ty.hasRank() && filter_ty.hasRank()) {
1772     // Checks the size of each of the output spatial dimensions.
1773     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1774       const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
1775       int64_t stride = get_int(strides[dim]);
1776       int64_t expected_output_size;
1777       int64_t pad_low;
1778       int64_t pad_high;
1779       // Retrieve padding, if defined explicitly.
1780       if (padding == tensorflow::Padding::EXPLICIT) {
1781         pad_low = get_int(explicit_padding[2 * dim]);
1782         pad_high = get_int(explicit_padding[2 * dim + 1]);
1783       }
1784       // Skip if input or filter size is dynamic.
1785       if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue;
1786       // Calculate the expected_output_size.
1787       tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1788           input_ty.getDimSize(dim), filter_ty.getDimSize(i),
1789           get_int(dilations[dim]), stride, padding, &expected_output_size,
1790           &pad_low, &pad_high);
1791       // Return failure if expected_output_size could not be calculated.
1792       if (!status.ok()) return failure();
1793       return_shape[dim] = expected_output_size;
1794     }
1795   }
1796 
1797   inferredReturnShapes.emplace_back(return_shape, input_ty.getElementType());
1798   return success();
1799 }
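
// Worked example (illustrative): for a spatial dimension of size 224 with
// filter size 7, stride 2, and dilation 1, the computation above yields
//   SAME  padding: ceil(224 / 2)            = 112
//   VALID padding: floor((224 - 7) / 2) + 1 = 109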
1800 
1801 LogicalResult Conv2DOp::inferReturnTypeComponents(
1802     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1803     DictionaryAttr attributes, RegionRange regions,
1804     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1805   Conv2DOpAdaptor op(operands.getValues(), attributes);
1806   ArrayRef<Attribute> explicit_padding;
1807   ArrayAttr explicit_pad =
1808       attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
1809   if (!explicit_pad) {
1810     explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
1811   }
1812   explicit_padding = explicit_pad.getValue();
1813 
1814   return inferConvReturnTypeComponents(location, op, explicit_padding,
1815                                        inferredReturnShapes);
1816 }
1817 
1818 StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
1819   // Keep the current data format if no GPUs are available or if explicit
1820   // placement does not allow using a GPU for this operation.
1821   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1822     return data_format();
1823 
1824   // Input must be a tensor.
1825   auto input_ty = input().getType().dyn_cast<TensorType>();
1826   if (!input_ty) return data_format();
1827 
1828   // For the f16 data type, on devices with Tensor Core support, the NHWC data
1829   // format is up to ~2x faster.
1830   const bool is_f16 = input_ty.getElementType().isF16();
1831   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1832 
1833   // For the f32/f16 data types the decision depends on the filter size in the
1834   // spatial dimensions; for other data types we keep the current data format.
1835   if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16())
1836     return data_format();
1837 
1838   // Keep current data format if filter rank is unknown or not equal to 4.
1839   auto filter_ty = filter().getType().dyn_cast<RankedTensorType>();
1840   if (!filter_ty || filter_ty.getRank() != 4) return data_format();
1841 
1842   const int64_t d0 = filter_ty.getDimSize(0);
1843   const int64_t d1 = filter_ty.getDimSize(1);
1844 
1845   auto all_ones = [](ArrayAttr arr) -> bool {
1846     return llvm::all_of(arr, [](Attribute attr) -> bool {
1847       return attr.cast<IntegerAttr>().getInt() == 1;
1848     });
1849   };
1850 
1851   // Convolutions with a 1x1 filter and all-ones strides and dilations can be
1852   // computed as a GEMM in the NHWC data format, and can be up to ~2x faster
1853   // than a convolution in NCHW.
1854   const bool one_by_one = d0 == 1 && d1 == 1;
1855   const bool trivial_strides = all_ones(strides());
1856   const bool trivial_dilations = all_ones(dilations());
1857 
1858   // TODO(ezhulenev): This might lead to excessive transposes in the final IR,
1859   // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1.
1860   // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all
1861   // users can efficiently use NHWC data format?
1862   if (one_by_one && trivial_strides && trivial_dilations) {
1863     return "NHWC";
1864   }
1865 
1866   // If the filter spatial dimensions are unknown or not 1x1, we prefer NCHW,
1867   // because it's the fastest option on NVIDIA GPUs with cuDNN library support.
1868   return "NCHW";
1869 }
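
// Summary of the decision above (illustrative, assuming a usable GPU):
//   f16 input and Tensor Cores available             -> "NHWC"
//   f32/f16, 1x1 filter, unit strides and dilations  -> "NHWC" (GEMM-friendly)
//   anything else (e.g. f32 with a 3x3 filter)       -> "NCHW"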
1870 
1871 //===----------------------------------------------------------------------===//
1872 // Conv2dBackpropFilterOp
1873 //===----------------------------------------------------------------------===//
1874 
1875 LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) {
1876   StringRef src_data_format = this->data_format();
1877 
1878   auto perm = GetDataFormatPermutation(src_data_format, data_format);
1879   if (perm.empty()) return failure();
1880 
1881   // Update data_format attribute and result types.
1882   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1883 
1884   // Update convolution attributes.
1885   (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1886   (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1887   (*this)->setAttr("explicit_paddings",
1888                    ShuffleArrayAttr(explicit_paddings(), perm, 2));
1889 
1890   // Permute filter sizes operand.
1891   OpBuilder builder(getOperation());
1892   auto filter_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1893       getLoc(), filter_sizes(), StringAttr::get(getContext(), src_data_format),
1894       StringAttr::get(getContext(), data_format));
1895   setOperand(1, filter_sizes_permuted);
1896 
1897   return success();
1898 }
1899 
1900 StringRef Conv2DBackpropFilterOp::GetOptimalLayout(
1901     const RuntimeDevices &devices) {
1902   // Keep the current data format if no GPUs are available or if explicit
1903   // placement does not allow using a GPU for this operation.
1904   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1905     return data_format();
1906 
1907   // Input must be a tensor.
1908   auto input_ty = input().getType().dyn_cast<TensorType>();
1909   if (!input_ty) return data_format();
1910 
1911   // For the f16 data type, on devices with Tensor Core support, the NHWC data
1912   // format is up to ~2x faster.
1913   const bool is_f16 = input_ty.getElementType().isF16();
1914   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1915 
1916   // Otherwise always use "NCHW".
1917   return "NCHW";
1918 }
1919 
1920 //===----------------------------------------------------------------------===//
1921 // Conv2DBackpropInputOp
1922 //===----------------------------------------------------------------------===//
1923 
1924 LogicalResult Conv2DBackpropInputOp::verify() {
1925   Conv2DBackpropInputOp op = *this;
1926   int num_spatial_dims = 2;
1927   int num_dims = 2 + num_spatial_dims;
1928 
1929   if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) ||
1930       !IsOfRankOrUnranked(op.filter(), num_dims))
1931     return op.emitOpError()
1932            << "requires operands to be " << num_dims << "D tensor";
1933   if (!IsOfRankOrUnranked(op.getResult(), num_dims))
1934     return op.emitOpError()
1935            << "requires result to be " << num_dims << "D tensor";
1936 
1937   llvm::Optional<mlir::Location> location = op.getLoc();
1938   ArrayRef<Attribute> strides = op.strides().getValue();
1939   ArrayRef<Attribute> dilations = op.dilations().getValue();
1940   LogicalResult verify_result =
1941       VerifyConvOpAttributes(num_dims, strides, dilations, location);
1942   if (failed(verify_result)) {
1943     return verify_result;
1944   }
1945 
1946   return success();
1947 }
1948 
1949 LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) {
1950   StringRef src_data_format = this->data_format();
1951 
1952   auto perm = GetDataFormatPermutation(src_data_format, data_format);
1953   if (perm.empty()) return failure();
1954 
1955   // Update data_format attribute and result types.
1956   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1957 
1958   // Update convolution attributes.
1959   (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1960   (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1961   (*this)->setAttr("explicit_paddings",
1962                    ShuffleArrayAttr(explicit_paddings(), perm, 2));
1963 
1964   // Permute input sizes operand.
1965   OpBuilder builder(getOperation());
1966   auto input_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1967       getLoc(), input_sizes(), StringAttr::get(getContext(), src_data_format),
1968       StringAttr::get(getContext(), data_format));
1969   setOperand(0, input_sizes_permuted);
1970 
1971   return success();
1972 }
1973 
1974 StringRef Conv2DBackpropInputOp::GetOptimalLayout(
1975     const RuntimeDevices &devices) {
1976   // Keep the current data format if no GPUs are available or if explicit
1977   // placement does not allow using a GPU for this operation.
1978   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1979     return data_format();
1980 
1981   // Filter must be a tensor.
1982   auto filter_ty = filter().getType().dyn_cast<TensorType>();
1983   if (!filter_ty) return data_format();
1984 
1985   // For the f16 data type, on devices with Tensor Core support, the NHWC data
1986   // format is up to ~2x faster.
1987   const bool is_f16 = filter_ty.getElementType().isF16();
1988   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1989 
1990   // Otherwise always use "NCHW".
1991   return "NCHW";
1992 }
1993 
1994 //===----------------------------------------------------------------------===//
1995 // Conv3DOp
1996 //===----------------------------------------------------------------------===//
1997 
1998 LogicalResult Conv3DOp::inferReturnTypeComponents(
1999     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
2000     DictionaryAttr attributes, RegionRange regions,
2001     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2002   Conv3DOpAdaptor op(operands.getValues(), attributes);
2003   ArrayRef<Attribute> explicit_padding;
2004   ArrayAttr explicit_pad =
2005       attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
2006   if (!explicit_pad) {
2007     explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
2008   }
2009   explicit_padding = explicit_pad.getValue();
2010 
2011   return inferConvReturnTypeComponents(location, op, explicit_padding,
2012                                        inferredReturnShapes);
2013 }
2014 
2015 //===----------------------------------------------------------------------===//
2016 // DataFormatVecPermuteOp
2017 //===----------------------------------------------------------------------===//
2018 
2019 LogicalResult DataFormatVecPermuteOp::verify() {
2020   DataFormatVecPermuteOp op = *this;
2021   auto input_ty = op.x().getType().dyn_cast<RankedTensorType>();
2022   if (!input_ty) return success();
2023 
2024   int rank = input_ty.getRank();
2025   if (rank != 1 && rank != 2)
2026     return op.emitOpError("requires input of rank 1 or 2");
2027 
2028   if (rank == 1) {
2029     int64_t dim0 = input_ty.getDimSize(0);
2030     if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
2031       return op.emitOpError("requires 1D input of size 4 or size 2");
2032   }
2033 
2034   if (rank == 2) {
2035     int64_t dim0 = input_ty.getDimSize(0);
2036     if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
2037       return op.emitOpError(
2038           "requires the first dimension of a 2D input to be of size 4");
2039 
2040     int64_t dim1 = input_ty.getDimSize(1);
2041     if (dim1 != ShapedType::kDynamicSize && dim1 != 2)
2042       return op.emitOpError(
2043           "requires the second dimension of a 2D input to be of size 2");
2044   }
2045 
2046   return success();
2047 }
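
// Example (illustrative): a 1D input of size 4 holds one value per dimension
// (e.g. an [N, H, W, C] vector), a 1D input of size 2 holds spatial-only
// [H, W] values, and a 4x2 input holds a (low, high) pair per dimension.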
2048 
2049 //===----------------------------------------------------------------------===//
2050 // DivNoNanOp
2051 //===----------------------------------------------------------------------===//
2052 
2053 namespace {
2054 
2055 /// Canonicalization template for tf.DivNoNan and tf.MulNoNan:
2056 /// If the op is tf.DivNoNan and the divisor is a constant tensor (with
2057 /// elements of an allowed type: float or complex), rewrite the op to the
2058 /// divisor if all the elements of the divisor are zero, and to tf.Div if all
2059 /// the elements of the divisor are non-zero.
2060 
2061 /// Similarly, if the op is tf.MulNoNan and the multiplier is a constant tensor
2062 /// (with all the elements of any allowed type: float or complex), rewrite the
2063 /// op to the multiplier if all the elements of the multiplier are zero and to
2064 /// tf.Mul if all the elements of the multiplier are non-zero.
2065 
2066 /// Replace the given op with an op of type `RetT`. Upon calling
2067 /// DivNoNanOrMulNoNanConstantY for canonicalizing tf.DivNoNan, tf.DivOp is
2068 /// passed as the second argument and for canonicalizing tf.MulNoNan, tf.MulOp
2069 /// is passed as the second argument.
2070 template <typename OpT, typename RetT>
2071 class DivNoNanOrMulNoNanConstantY : public OpRewritePattern<OpT> {
2072   using OpRewritePattern<OpT>::OpRewritePattern;
2073 
2074   LogicalResult matchAndRewrite(OpT op,
2075                                 PatternRewriter &rewriter) const override {
2076     static_assert(
2077         llvm::is_one_of<OpT, DivNoNanOp, MulNoNanOp>::value,
2078         "only canonicalization of tf.DivNoNan and tf.MulNoNan is supported");
2079 
2080     // Returns true iff `val` (a complex constant with float real and imaginary
2081     // parts) is zero.
2082     auto complexIsZero = [](const std::complex<APFloat> val) {
2083       // Note that when `val` is of complex type, it is zero iff both
2084       // its real and imaginary parts are zero.
2085       if (val.real().isZero() && val.imag().isZero())
2086         return true;
2087       else
2088         return false;
2089     };
2090 
2091     // Returns true iff `attr` contains both zero and non-zero elements
2092     // (of float or complex type).
2093     auto hasBothZeroAndNonzeroElements =
2094         [&complexIsZero](ElementsAttr attr, bool hasComplexElements) {
2095           bool foundZero = false, foundNonzero = false;
2096           if (!hasComplexElements) {
2097             for (const auto val : attr.getValues<APFloat>()) {
2098               if (val.isZero())
2099                 foundZero = true;
2100               else
2101                 foundNonzero = true;
2102               if (foundZero && foundNonzero) return true;
2103             }
2104           } else {
2105             for (const auto val : attr.getValues<std::complex<APFloat>>()) {
2106               if (complexIsZero(val))
2107                 foundZero = true;
2108               else
2109                 foundNonzero = true;
2110               if (foundZero && foundNonzero) return true;
2111             }
2112           }
2113           return false;
2114         };
2115 
2116     // Note that `y` is the divisor if the op is tf.DivNoNan and it is the
2117     // multiplier if the op is tf.MulNoNan.
2118     Value y = op.y();
2119     // The below if condition is true iff `y.getDefiningOp()` is of the type
2120     // TF::ConstOp, i.e., if `y` is defined by an op and it is the tf.Const op.
2121     // In that case, `yDefOp` stores this tf.Const op.
2122     // Note that if `y` is a block argument, `y.getDefiningOp()` will return
2123     // null, which will get propagated by dyn_cast_or_null to `yDefOp`.
2124     // Further, if `y` is defined by an op other than tf.Const,
2125     // `y.getDefiningOp()` will not return null but dyn_cast_or_null will.
2126     if (auto yDefOp = dyn_cast_or_null<TF::ConstOp>(y.getDefiningOp())) {
2127       Type typeOfElementsInY = getElementTypeOrSelf(y.getType());
2128       ElementsAttr attr = yDefOp.value();
2129       bool yHasComplexElements = typeOfElementsInY.isa<ComplexType>();
2130 
2131       // If `y` is a splat constant, then the op will definitely get replaced.
2132       // We check for a splat constant first because that check is O(1), which
2133       // keeps this canonicalization cheap.
2134       if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
2135         bool splatAttrIsZero = false;
2136         if (!yHasComplexElements) {
2137           if (splatAttr.getSplatValue<APFloat>().isZero())
2138             splatAttrIsZero = true;
2139         } else {
2140           if (complexIsZero(splatAttr.getSplatValue<std::complex<APFloat>>()))
2141             splatAttrIsZero = true;
2142         }
2143         if (splatAttrIsZero) {
2144           // When `y` is a zero splat constant (i.e., all the elements in `y`
2145           // are zero), replace the op (tf.DivNoNan or tf.MulNoNan) with `y`.
2146           rewriter.replaceOp(op, y);
2147         } else {
2148           // When `y` is a non-zero splat constant, replace tf.DivNoNan with
2149           // tf.Div and tf.MulNoNan with tf.Mul.
2150           rewriter.replaceOpWithNewOp<RetT>(op, op->getResult(0).getType(),
2151                                             op->getOperand(0),
2152                                             op->getOperand(1));
2153         }
2154         return success();
2155       }
2156 
2157       // If `y` has both zero and non-zero elements, do nothing.
2158       if (hasBothZeroAndNonzeroElements(attr, yHasComplexElements)) {
2159         return failure();
2160       } else {
2161         // When `y` is non-splat and all of its elements are non-zero, replace
2162         // tf.DivNoNan with tf.Div and tf.MulNoNan with tf.Mul.
2163         rewriter.replaceOpWithNewOp<RetT>(op, op->getResult(0).getType(),
2164                                           op->getOperand(0), op->getOperand(1));
2165         return success();
2166       }
2167     }
2168     return failure();
2169   }
2170 };
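
// Sketch of the canonicalization (schematic IR, not verified syntax): with a
// zero splat divisor,
//
//   %y = "tf.Const"() {value = dense<0.0> : tensor<2xf32>}
//   %r = "tf.DivNoNan"(%x, %y)
//
// the op canonicalizes to %y itself, since every result element is zero; with
// an all-non-zero constant divisor it becomes a plain "tf.Div".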
2171 }  // namespace
2172 
2173 void DivNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results,
2174                                              MLIRContext *context) {
2175   results.add<DivNoNanOrMulNoNanConstantY<TF::DivNoNanOp, TF::DivOp>>(context);
2176 }
2177 
2178 //===----------------------------------------------------------------------===//
2179 // DivOp
2180 //===----------------------------------------------------------------------===//
2181 
2182 void DivOp::getCanonicalizationPatterns(RewritePatternSet &results,
2183                                         MLIRContext *context) {
2184   results.add<DivWithSqrtDivisor>(context);
2185 }
2186 
2187 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
2188   return IdentityArithmeticOpFolder<DivOp>(*this, operands);
2189 }
2190 
2191 //===----------------------------------------------------------------------===//
2192 // DynamicStitchOp
2193 //===----------------------------------------------------------------------===//
2194 
2195 LogicalResult DynamicStitchOp::verify() {
2196   DynamicStitchOp op = *this;
2197   if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1");
2198 
2199   if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) {
2200     if (out_ty.getRank() == 0) {
2201       return op.emitOpError("requires non scalar output");
2202     }
2203   }
2204 
2205   llvm::SmallDenseSet<int64_t, 8> index_values;
2206   bool all_indices_const = true;
2207   int32_t max_index = -1;
2208   llvm::Optional<SmallVector<int64_t, 4>> inferred_item_shape;
2209   for (auto it : llvm::zip(op.indices(), op.data())) {
2210     Value index = std::get<0>(it);
2211 
2212     DenseIntElementsAttr index_attr;
2213     if (matchPattern(index, m_Constant(&index_attr))) {
2214       for (int32_t index : index_attr.getValues<int32_t>()) {
2215         if (index < 0)
2216           return op.emitOpError()
2217                  << "requires non-negative index values; found " << index;
2218         max_index = std::max(index, max_index);
2219         index_values.insert(index);
2220       }
2221     } else {
2222       all_indices_const = false;
2223     }
2224 
2225     Value data = std::get<1>(it);
2226     RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>();
2227     RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
2228     if (!index_ty || !data_ty) continue;
2229 
2230     int64_t index_rank = index_ty.getRank();
2231     ArrayRef<int64_t> data_shape = data_ty.getShape();
2232     ArrayRef<int64_t> index_shape = index_ty.getShape();
2233     if (failed(mlir::verifyCompatibleShape(index_shape,
2234                                            data_shape.take_front(index_rank))))
2235       return op.emitOpError() << "requires shape of data with type " << data_ty
2236                               << " to have prefix matching with shape of the "
2237                                  "corresponding index type "
2238                               << index_ty;
2239 
2240     ArrayRef<int64_t> item_shape = data_shape.drop_front(index_rank);
2241     if (!inferred_item_shape) {
2242       inferred_item_shape = llvm::to_vector<4>(item_shape);
2243       continue;
2244     }
2245 
2246     if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape)))
2247       return op.emitOpError() << "has inconsistent shaped data and index "
2248                                  "pairs; inferred item shapes ["
2249                               << llvm::makeArrayRef(*inferred_item_shape)
2250                               << "] and [" << item_shape << "] don't match";
2251     for (int i = 0, e = item_shape.size(); i < e; ++i) {
2252       int64_t &inferred_dim = (*inferred_item_shape)[i];
2253       int64_t dim = item_shape[i];
2254       if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim;
2255     }
2256   }
2257 
2258   // If all indices are constants, then verify that they cover all indices in
2259   // the range [0, max_index] and the output type is legal.
2260   if (all_indices_const) {
2261     for (int32_t i = 0; i <= max_index; i++) {
2262       if (!index_values.count(i))
2263         return op.emitOpError() << "missing index " << i;
2264     }
2265 
2266     if (inferred_item_shape) {
2267       SmallVector<int64_t, 4> expected_shape;
2268       expected_shape.push_back(max_index + 1);
2269       expected_shape.append(inferred_item_shape->begin(),
2270                             inferred_item_shape->end());
2271 
2272       auto out_ty = op.getType().cast<TensorType>();
2273       auto expected_out_ty =
2274           RankedTensorType::get(expected_shape, out_ty.getElementType());
2275 
2276       if (!AreCastCompatible({out_ty, expected_out_ty})) {
2277         return op.emitOpError() << "has invalid output type; should be "
2278                                    "compatible with inferred type "
2279                                 << expected_out_ty;
2280       }
2281     }
2282   }
2283 
2284   return success();
2285 }
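
// Worked example (illustrative): indices [0, 2] and [1] with data of shapes
// [2, 5] and [1, 5] cover the full range [0, 2], so the output must be
// compatible with tensor<3x5>. A missing index, or data whose shape prefix
// does not match its index shape, is rejected above.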
2286 
2287 //===----------------------------------------------------------------------===//
2288 // EinsumOp
2289 //===----------------------------------------------------------------------===//
2290 
2291 // Verifies that,
2292 // * Arity of the op is at most two.
2293 //
2294 // TODO(hinsu): Verify einsum equation attribute.
2295 LogicalResult EinsumOp::verify() {
2296   EinsumOp op = *this;
2297   if (op.N() > 2) {
2298     return op.emitOpError("supports at most two operands");
2299   }
2300   return success();
2301 }
2302 
2303 //===----------------------------------------------------------------------===//
2304 // EmptyOp
2305 //===----------------------------------------------------------------------===//
2306 
2307 OpFoldResult EmptyOp::fold(ArrayRef<Attribute> operands) {
2308   assert(operands.size() == 1 && "empty op has one operand");
2309 
2310   Attribute attr = operands.front();
2311   if (!attr) return {};
2312 
2313   auto int_attr = attr.cast<DenseIntElementsAttr>();
2314   SmallVector<int64_t, 6> out_shape;
2315   for (const auto val : int_attr.getValues<int32_t>()) {
2316     out_shape.push_back(val);
2317   }
2318 
2319   auto type = getResult().getType().cast<ShapedType>();
2320   auto etype = type.getElementType();
2321 
2322   // We cannot fold if the result shape is not static.
2323   if (!type.hasStaticShape()) return {};
2324 
2325   if (auto float_type = etype.dyn_cast<FloatType>()) {
2326     auto out_type = RankedTensorType::get(out_shape, float_type);
2327     return DenseElementsAttr::get(out_type,
2328                                   {APFloat(float_type.getFloatSemantics())});
2329   }
2330 
2331   if (auto int_type = etype.dyn_cast<IntegerType>()) {
2332     auto out_type = RankedTensorType::get(out_shape, etype);
2333     APInt val(int_type.getWidth(), 0, int_type.getSignedness());
2334     return DenseElementsAttr::get(out_type, val);
2335   }
2336 
2337   return {};
2338 }
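
// Example (illustrative): a tf.Empty whose shape operand is the constant
// [2, 3] and whose result type is a static tensor of f32 folds to
// dense<0.0> : tensor<2x3xf32>; an integer element type folds to a zero
// splat instead.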
2339 
2340 //===----------------------------------------------------------------------===//
2341 // EmptyTensorListOp
2342 //===----------------------------------------------------------------------===//
2343 
2344 LogicalResult EmptyTensorListOp::verify() {
2345   EmptyTensorListOp op = *this;
2346   // This is required to populate derived attributes during export in a
2347   // meaningful way. Otherwise, the element_type() query during export to
2348   // GraphDef will result in an out-of-bounds access/assert.
2349   if (handle_dtype().getSubtypes().size() != 1) {
2350     return emitOpError(
2351         "must have exactly one subtype in the result variant type");
2352   }
2353 
2354   if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2355       !IsOfRankOrUnranked(op.element_shape(), 1)) {
2356     return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2357   }
2358 
2359   if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) {
2360     return op.emitOpError("requires max_num_elements operand to be 0D tensor");
2361   }
2362   return success();
2363 }
2364 
2365 //===----------------------------------------------------------------------===//
2366 // EnqueueTPUEmbedding ops
2367 //===----------------------------------------------------------------------===//
2368 
2369 // For EnqueueTPUEmbedding ops the device ordinal corresponds to the resource
2370 // instance.
2371 
2372 std::string
2373 EnqueueTPUEmbeddingArbitraryTensorBatchOp::GetResourceInstanceStr() {
2374   return std::to_string(device_ordinal());
2375 }
2376 
2377 std::string EnqueueTPUEmbeddingBatchOp::GetResourceInstanceStr() {
2378   return std::to_string(device_ordinal());
2379 }
2380 
2381 std::string EnqueueTPUEmbeddingIntegerBatchOp::GetResourceInstanceStr() {
2382   return std::to_string(device_ordinal());
2383 }
2384 
2385 std::string EnqueueTPUEmbeddingRaggedTensorBatchOp::GetResourceInstanceStr() {
2386   return std::to_string(device_ordinal());
2387 }
2388 
2389 std::string EnqueueTPUEmbeddingSparseBatchOp::GetResourceInstanceStr() {
2390   return std::to_string(device_ordinal());
2391 }
2392 
2393 std::string EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceStr() {
2394   return std::to_string(device_ordinal());
2395 }
2396 
2397 //===----------------------------------------------------------------------===//
2398 // EnsureShapeOp
2399 //===----------------------------------------------------------------------===//
2400 
2401 OpFoldResult EnsureShapeOp::fold(llvm::ArrayRef<mlir::Attribute>) {
2402   ShapedType type = input().getType().dyn_cast<ShapedType>();
2403   if (!type || !type.hasRank()) return {};
2404   // If the shape attribute equals the input type's shape, fold to the input.
2405   llvm::Optional<llvm::ArrayRef<int64_t>> shape_constraint = shape();
2406   if (type.getShape() == shape_constraint) return input();
2407 
2408   // If the input type's shape always satisfies the shape attribute, fold to
2409   // the input.
2410   if (shape_constraint.has_value() &&
2411       shape_constraint->size() == type.getShape().size()) {
2412     for (int i = 0; i < shape_constraint->size(); ++i) {
2413       if (!ShapedType::isDynamic(shape_constraint.getValue()[i]) &&
2414           type.getDimSize(i) != shape_constraint.getValue()[i]) {
2415         return {};
2416       }
2417     }
2418     return input();
2419   }
2420   // Otherwise, retain the op so the constraint can still fail dynamically.
2421   return {};
2422 }
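
// Example (illustrative): an input of type tensor<2x3xf32> with shape
// constraint [2, -1] folds to the input, since a dynamic constraint entry is
// always satisfied; an input of type tensor<2x?xf32> with constraint [2, 3]
// is retained, because the check can only succeed or fail at runtime.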
2423 
2424 //===----------------------------------------------------------------------===//
2425 // EqualOp/NotEqualOp
2426 //===----------------------------------------------------------------------===//
2427 
2428 LogicalResult EqualOp::verify() {
2429   EqualOp op = *this;
2430   // If inputs are allowed to have incompatible types, there is nothing to do.
2431   if (!op.incompatible_shape_error()) return success();
2432 
2433   // Otherwise, check inputs are broadcastable.
2434   return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
2435       op.getOperation());
2436 }
2437 
2438 void EqualOp::build(OpBuilder &builder, OperationState &result, Value x,
2439                     Value y, BoolAttr incompatible_shape_error) {
2440   auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
2441                                           incompatible_shape_error);
2442   return build(builder, result, result_type, x, y, incompatible_shape_error);
2443 }
2444 
2445 namespace {
2446 
2447 // Flips the incompatible_shape_error attribute to true if the shapes are known
2448 // to be compatible.
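// For example (illustrative), a tf.Equal on two tensor<4xf32> operands with a
// ranked, non-scalar result type is statically known to be shape compatible,
// so the pattern rewrites it with incompatible_shape_error = true.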
2449 template <typename Ty>
2450 static LogicalResult flipCompatibleShapeError(Ty op, PatternRewriter &rewriter) {
2451   if (op.incompatible_shape_error()) {
2452     return rewriter.notifyMatchFailure(op, "the attribute is already true");
2453   }
2454 
2455   // incompatible_shape_error=false implies that the op will either return a
2456   // valid result or a scalar boolean indicating the error. For unranked outputs
2457   // we don't know which one it is. TF shape inference turns unranked outputs
2458   // into ranked ones if it can statically evaluate the broadcast, see the shape
2459   // function of tf.Equal.
2460   auto ty = op.getType().template dyn_cast<RankedTensorType>();
2461   if (!ty) {
2462     return rewriter.notifyMatchFailure(op, "requires a ranked output shape");
2463   }
2464 
2465   // Unless this is a scalar compare, a scalar output indicates that this will
2466   // always fail.
2467   auto x_ty = op.x().getType().template dyn_cast<RankedTensorType>();
2468   auto y_ty = op.y().getType().template dyn_cast<RankedTensorType>();
2469   if (ty.getRank() == 0 &&
2470       (!x_ty || x_ty.getRank() != 0 || !y_ty || y_ty.getRank() != 0)) {
2471     return rewriter.notifyMatchFailure(op, "output rank must match input rank");
2472   }
2473 
2474   // Shapes are known to be compatible.
2475   rewriter.template replaceOpWithNewOp<Ty>(op, op.x(), op.y(),
2476                                            rewriter.getBoolAttr(true));
2477   return success();
2478 }
2479 }  // namespace
2480 
2481 void EqualOp::getCanonicalizationPatterns(RewritePatternSet &results,
2482                                           MLIRContext *context) {
2483   results.add(flipCompatibleShapeError<EqualOp>);
2484 }
2485 
2486 void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet &results,
2487                                              MLIRContext *context) {
2488   results.add(flipCompatibleShapeError<NotEqualOp>);
2489 }
2490 
2491 //===----------------------------------------------------------------------===//
2492 // ExpandDimsOp
2493 //===----------------------------------------------------------------------===//
2494 
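// Infers the tf.ExpandDims result type from the input type and a constant dim
// operand. For example (illustrative), input tensor<2x3xf32> with dim = -1
// yields tensor<2x3x1xf32>: the negative dim is first normalized by adding
// input_rank + 1, giving insertion index 2.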
2495 Type InferExpandDimsOpType(Value input, Value dim) {
2496   Type element_ty = input.getType().cast<TensorType>().getElementType();
2497   auto unranked_ty = UnrankedTensorType::get(element_ty);
2498 
2499   auto input_ty = input.getType().dyn_cast<RankedTensorType>();
2500   if (!input_ty) return unranked_ty;
2501 
2502   DenseIntElementsAttr dim_attr;
2503   if (!matchPattern(dim, m_Constant(&dim_attr)) ||
2504       dim_attr.getNumElements() != 1)
2505     return unranked_ty;
2506   int64_t dim_val = (*dim_attr.begin()).getSExtValue();
2507   int64_t input_rank = input_ty.getRank();
2508 
2509   if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty;
2510   if (dim_val < 0) dim_val += input_rank + 1;
2511 
2512   SmallVector<int64_t, 4> shape = llvm::to_vector<4>(input_ty.getShape());
2513   shape.insert(shape.begin() + dim_val, 1);
2514   return RankedTensorType::get(shape, element_ty);
2515 }
2516 
2517 void ExpandDimsOp::build(OpBuilder &builder, OperationState &result,
2518                          Value input, Value dim) {
2519   return build(builder, result, InferExpandDimsOpType(input, dim), input, dim);
2520 }
2521 
2522 //===----------------------------------------------------------------------===//
2523 // Expm1Op
2524 //===----------------------------------------------------------------------===//
2525 
2526 LogicalResult Expm1Op::inferReturnTypeComponents(
2527     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
2528     DictionaryAttr attributes, RegionRange regions,
2529     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2530   ShapeAdaptor adaptor = operands.getShape(0);
2531   ShapedTypeComponents component(adaptor.getElementType());
2532   if (adaptor.hasRank()) adaptor.getDims(component);
2533   inferredReturnShapes.push_back(component);
2534   return success();
2535 }
2536 
2537 //===----------------------------------------------------------------------===//
2538 // FakeQuantWithMinMaxArgsOp
2539 //===----------------------------------------------------------------------===//
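// For example (illustrative), min = -6.0, max = 6.0, num_bits = 8 passes
// verification, while min = max = 1.0 is rejected because the range
// [1.000000,1.000000] is invalid (rmin must be strictly less than rmax).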
2540 LogicalResult FakeQuantWithMinMaxArgsOp::verify() {
2541   FakeQuantWithMinMaxArgsOp op = *this;
2542   // TODO(fengliuai): move the following to a utility method.
2543   const llvm::fltSemantics &semantics = op.min().getSemantics();
2544   float rmin, rmax;
2545   if (&semantics == &APFloat::IEEEsingle()) {
2546     rmin = op.min().convertToFloat();
2547     rmax = op.max().convertToFloat();
2548   } else {
2549     rmin = op.min().convertToDouble();
2550     rmax = op.max().convertToDouble();
2551   }
2552   // Range boundaries must be valid.
2553   if (rmin >= rmax) {
2554     return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
2555                           "," + Twine(std::to_string(rmax)) + "]");
2556   }
2557   int64_t num_bits = op.num_bits();
2558   if (num_bits < 2 || num_bits > 16) {
2559     return op.emitOpError(
2560         "requires num_bits to be between 2 and 16, inclusive");
2561   }
2562   return success();
2563 }
2564 
2565 //===----------------------------------------------------------------------===//
2566 // FakeQuantWithMinMaxVarsOp
2567 //===----------------------------------------------------------------------===//
2568 LogicalResult FakeQuantWithMinMaxVarsOp::verify() {
2569   FakeQuantWithMinMaxVarsOp op = *this;
2570   auto min = GetRankedTensorTypeForOperand(op.min());
2571   if (min && !IsOfRankedFloatTensorType(min, 0))
2572     return op.emitOpError("requires min to be a 0d float tensor");
2573 
2574   auto max = GetRankedTensorTypeForOperand(op.max());
2575   if (max && !IsOfRankedFloatTensorType(max, 0))
2576     return op.emitOpError("requires max to be a 0d float tensor");
2577 
2578   int64_t num_bits = op.num_bits();
2579   if (num_bits < 2 || num_bits > 16) {
2580     return op.emitOpError(
2581         "requires num_bits to be between 2 and 16, inclusive");
2582   }
2583   return success();
2584 }
2585 
2586 //===----------------------------------------------------------------------===//
2587 // FakeQuantWithMinMaxVarsPerChannelOp
2588 //===----------------------------------------------------------------------===//
2589 LogicalResult FakeQuantWithMinMaxVarsPerChannelOp::verify() {
2590   FakeQuantWithMinMaxVarsPerChannelOp op = *this;
2591   auto min = GetRankedTensorTypeForOperand(op.min());
2592   if (min && !IsOfRankedFloatTensorType(min, 1))
2593     return op.emitOpError("requires min to be a 1d float tensor");
2594 
2595   auto max = GetRankedTensorTypeForOperand(op.max());
2596   if (max && !IsOfRankedFloatTensorType(max, 1))
2597     return op.emitOpError("requires max to be a 1d float tensor");
2598 
2599   Value inputs = op.inputs();
2600   if (!HasRankAtLeast(inputs, 1))
2601     return op.emitOpError("requires inputs to be at least 1d float tensor");
2602 
2603   int64_t num_bits = op.num_bits();
2604   if (num_bits < 2 || num_bits > 16) {
2605     return op.emitOpError(
2606         "requires num_bits to be between 2 and 16, inclusive");
2607   }
2608 
2609   auto inputs_type = inputs.getType().dyn_cast<RankedTensorType>();
2610   if (!inputs_type) return success();
2611   int64_t depth = inputs_type.getDimSize(inputs_type.getRank() - 1);
2612   if ((min && min.getDimSize(0) != depth) ||
2613       (max && max.getDimSize(0) != depth)) {
2614     return op.emitOpError(
2615         "requires min and max to have same size as last dimension of inputs");
2616   }
2617 
2618   return success();
2619 }
2620 
2621 //===----------------------------------------------------------------------===//
2622 // FillOp
2623 //===----------------------------------------------------------------------===//
2624 
2625 LogicalResult FillOp::verify() {
2626   FillOp op = *this;
2627   if (!IsOfRankOrUnranked(op.dims(), 1))
2628     return op.emitOpError() << "requires dims to be a 1D tensor";
2629   if (!IsOfRankOrUnranked(op.value(), 0))
2630     return op.emitOpError() << "requires value to be a scalar";
2631 
2632   return success();
2633 }
2634 
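// Infers the tf.Fill result type. For example (illustrative): if dims is a
// constant [2, 3] and value has element type f32, the inferred type is
// tensor<2x3xf32>; if dims is produced by tf.Shape of a ranked tensor, that
// tensor's type is reused; otherwise the result is unranked.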
2635 static ShapedType InferFillOpType(Value dims, Value value) {
2636   Type etype = value.getType().cast<ShapedType>().getElementType();
2637 
2638   DenseIntElementsAttr dims_attr;
2639   if (matchPattern(dims, m_Constant(&dims_attr))) {
2640     llvm::SmallVector<int64_t, 4> shape;
2641     shape.reserve(dims_attr.getNumElements());
2642     for (const APInt dim : dims_attr.getValues<APInt>()) {
2643       shape.push_back(dim.getSExtValue());
2644     }
2645     return RankedTensorType::get(shape, etype);
2646   }
2647 
2648   if (auto shape_op = dims.getDefiningOp<ShapeOp>()) {
2649     if (auto t = shape_op.input().getType().dyn_cast<ShapedType>()) {
2650       return t;
2651     }
2652   }
2653 
2654   return UnrankedTensorType::get(etype);
2655 }
2656 
2657 void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
2658                    Value value) {
2659   FillOp::build(builder, result, InferFillOpType(dims, value), dims, value);
2660 }
2661 
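// Constant-folds tf.Fill when the fill value is constant. For example
// (illustrative), dims = [2, 3] with value = 7 : i32 folds to
// dense<7> : tensor<2x3xi32>.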
2662 OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
2663   assert(operands.size() == 2 && "fill op has two operands");
2664 
2665   auto type = getType().cast<ShapedType>();
2666   // The DenseElementsAttr used in this folder only supports int and float
2667   // types.
2668   // TODO(hinsu): Handle complex types once there is an attribute kind for
2669   // complex.
2670   if (!type.getElementType().isIntOrFloat()) return {};
2671 
2672   auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
2673   if (!value) return {};
2674 
2675   if (type.hasStaticShape())
2676     return DenseElementsAttr::get(type, value.getValues<Attribute>()[0]);
2677 
2678   auto dims = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2679   if (!dims) return {};
2680 
2681   llvm::SmallVector<int64_t, 4> shape;
2682   shape.reserve(dims.getNumElements());
2683   for (const APInt dim : dims.getValues<APInt>()) {
2684     shape.push_back(dim.getSExtValue());
2685   }
2686   type = RankedTensorType::get(shape, type.getElementType());
2687 
2688   return DenseElementsAttr::get(type, value.getValues<Attribute>()[0]);
2689 }
2690 
2691 //===----------------------------------------------------------------------===//
2692 // FusedBatchNormGradOp
2693 //===----------------------------------------------------------------------===//
2694 
2695 // TODO(b/150954845): Add benchmarks to verify that layout preference didn't
2696 // change in the latest GPU generations.
2697 
2698 LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) {
2699   return ::mlir::TF::UpdateDataFormat(data_format, this);
2700 }
2701 
2702 StringRef FusedBatchNormGradV3Op::GetOptimalLayout(
2703     const RuntimeDevices &devices) {
2704   // Keep the current data format if no GPUs are available or if explicit
2705   // placement does not allow using a GPU for this operation.
2706   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
2707     return data_format();
2708 
2709   // For the f16 data type, the NHWC data format is up to ~2x faster on
2710   // devices with Tensor Core support.
2711   auto x_ty = x().getType().cast<TensorType>();
2712   const bool is_f16 = x_ty.getElementType().isF16();
2713   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
2714 
2715   // For all other data types prefer NCHW.
2716   return "NCHW";
2717 }
2718 
2719 //===----------------------------------------------------------------------===//
2720 // FusedBatchNormOp
2721 //===----------------------------------------------------------------------===//
2722 
2723 LogicalResult FusedBatchNormOp::verify() {
2724   FusedBatchNormOp op = *this;
2725   auto x = GetRankedTensorTypeForOperand(op.x());
2726   if (x && !IsOfRankedFloatTensorType(x, 4))
2727     return op.emitOpError("requires x to be a 4D float tensor");
2728 
2729   auto scale = GetRankedTensorTypeForOperand(op.scale());
2730   if (scale && !IsOfRankedFloatTensorType(scale, 1))
2731     return op.emitOpError("requires scale to be a 1D float tensor");
2732 
2733   auto offset = GetRankedTensorTypeForOperand(op.offset());
2734   if (offset && !IsOfRankedFloatTensorType(offset, 1))
2735     return op.emitOpError("requires offset to be a 1D float tensor");
2736 
2737   auto mean = GetRankedTensorTypeForOperand(op.mean());
2738   if (mean && !IsOfRankedFloatTensorType(mean, 1))
2739     return op.emitOpError("requires mean to be a 1D float tensor");
2740 
2741   auto variance = GetRankedTensorTypeForOperand(op.variance());
2742   if (variance && !IsOfRankedFloatTensorType(variance, 1))
2743     return op.emitOpError("requires variance to be a 1D float tensor");
2744 
2745   // TODO(antiagainst): check attributes
2746 
2747   return success();
2748 }
2749 
2750 //===----------------------------------------------------------------------===//
2751 // FusedBatchNormV2Op / FusedBatchNormV3Op
2752 //===----------------------------------------------------------------------===//
2753 
2754 template <class Op>
2755 static LogicalResult InferenceFoldOperandsPermutation(
2756     ArrayRef<int64_t> permutation, Op *op) {
2757   // FusedBatchNorm in training mode is a layout sensitive operation, and
2758   // should already have been assigned an optimal data format.
2759   if (op->is_training()) return failure();
2760   return ::mlir::TF::FoldOperandsPermutation(permutation, op);
2761 }
2762 
2763 template <class Op>
2764 static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) {
2765   // In inference mode FusedBatchNorm is not sensitive to data layout.
2766   if (!op->is_training()) return op->data_format();
2767 
2768   // Keep the current data format if no GPUs are available or if explicit
2769   // placement does not allow using a GPU for this operation.
2770   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation()))
2771     return op->data_format();
2772 
2773   // For the f16 data type, the NHWC data format is up to ~2x faster on
2774   // devices with Tensor Core support.
2775   auto x_ty = op->x().getType().template cast<TensorType>();
2776   const bool is_f16 = x_ty.getElementType().isF16();
2777   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
2778 
2779   // For all other data types prefer NCHW.
2780   return "NCHW";
2781 }
2782 
2783 LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation(
2784     ArrayRef<int64_t> permutation) {
2785   return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
2786 }
2787 
2788 LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) {
2789   return ::mlir::TF::UpdateDataFormat(data_format, this);
2790 }
2791 
2792 StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) {
2793   return ::mlir::TF::GetOptimalLayout(devices, this);
2794 }
2795 
2796 LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
2797     ArrayRef<int64_t> permutation) {
2798   return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
2799 }
2800 
2801 LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
2802   return ::mlir::TF::UpdateDataFormat(data_format, this);
2803 }
2804 
2805 StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
2806   return ::mlir::TF::GetOptimalLayout(devices, this);
2807 }
2808 
2809 //===----------------------------------------------------------------------===//
2810 // GatherV2Op
2811 //===----------------------------------------------------------------------===//
2812 
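// Illustrative constraints checked below: for params of rank 4 and indices of
// rank 2, batch_dims must lie in [-2, 2] and a constant axis in [-4, 4), and
// after normalization to non-negative values axis must be >= batch_dims.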
2813 LogicalResult GatherV2Op::verify() {
2814   GatherV2Op op = *this;
2815   int64_t batch_dims = op.batch_dims();
2816   if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) {
2817     int64_t rank = ty.getRank();
2818     if (batch_dims > rank || batch_dims < -rank)
2819       return op.emitOpError()
2820              << "batch_dims (" << batch_dims << ") must be in range [" << -rank
2821              << ", " << rank + 1 << ")";
2822     if (batch_dims < 0) batch_dims += rank;
2823   }
2824 
2825   if (!HasRankAtMost(op.axis(), 1))
2826     return op.emitOpError("requires axis to have rank at most 1");
2827 
2828   DenseIntElementsAttr axis_attr;
2829   if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
2830     int64_t axis = (*axis_attr.begin()).getSExtValue();
2831     if (auto ty = op.params().getType().dyn_cast<RankedTensorType>()) {
2832       int64_t rank = ty.getRank();
2833       if (axis >= rank || axis < -rank)
2834         return op.emitOpError() << "axis (" << axis << ") must be in range ["
2835                                 << -rank << ", " << rank << ")";
2836       if (axis < 0) axis += rank;
2837     }
2838 
2839     if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) {
2840       return op.emitOpError() << "requires axis (" << axis
2841                               << ") to be greater than or equal to batch_dims ("
2842                               << batch_dims << ")";
2843     }
2844   }
2845   return success();
2846 }
2847 
2848 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
2849                                            MLIRContext *context) {
2850   results.add<GatherToV2>(context);
2851 }
2852 
2853 //===----------------------------------------------------------------------===//
2854 // IfOp
2855 //===----------------------------------------------------------------------===//
2856 
2857 LogicalResult IfOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
2858   auto branch_name = [](unsigned index) -> std::string {
2859     return index == 0 ? "'then_branch'" : "'else_branch'";
2860   };
2861   return VerifyCaseOrIfOpBranchFunctions(
2862       symbol_table, *this, {then_branchAttr(), else_branchAttr()}, branch_name);
2863 }
2864 
2865 //===----------------------------------------------------------------------===//
2866 // IfOp canonicalization.
2867 //===----------------------------------------------------------------------===//
2868 
2869 namespace {
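// Folds a tf.If whose condition is a constant scalar into a direct call of
// the selected branch. For example (illustrative), cond = true with branches
// @then_fn / @else_fn becomes a PartitionedCall (or, if the op is stateful,
// a StatefulPartitionedCall) of @then_fn.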
2870 class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
2871  public:
2872   explicit FoldConstantIfOp(MLIRContext *context)
2873       : OpRewritePattern<TF::IfOp>(context) {}
2874   LogicalResult matchAndRewrite(TF::IfOp op,
2875                                 PatternRewriter &rewriter) const override;
2876 
2877  private:
2878   template <typename T>
2879   struct CallOpType {
2880     using CallOp = T;
2881   };
2882 };
2883 
2884 LogicalResult FoldConstantIfOp::matchAndRewrite(
2885     TF::IfOp op, PatternRewriter &rewriter) const {
2886   // Extract the constant cond value.
2887   DenseIntElementsAttr cond_attr;
2888   if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
2889 
2890   // Cond value must be a scalar.
2891   if (cond_attr.getNumElements() != 1) return failure();
2892 
2893   // Select a branch function.
2894   bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
2895   FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();
2896 
2897   // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
2898   auto rewrite = [&](auto op_type) {
2899     auto empty = rewriter.getStringAttr("");
2900     ReplaceTfOpWithNewOp<typename decltype(op_type)::CallOp>(
2901         rewriter, op, op.getResultTypes(), op.input(), func,
2902         /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
2903   };
2904 
2905   if (op.is_stateless())
2906     rewrite(CallOpType<PartitionedCallOp>{});
2907   else
2908     rewrite(CallOpType<StatefulPartitionedCallOp>{});
2909 
2910   return success();
2911 }
2912 }  // anonymous namespace
2913 
2914 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2915                                        MLIRContext *context) {
2916   results.add<FoldConstantIfOp, DropAttributes<IfOp>>(context);
2917 }
2918 
2919 //===----------------------------------------------------------------------===//
2920 // IfRegionOp
2921 //===----------------------------------------------------------------------===//
2922 
2923 LogicalResult IfRegionOp::verifyRegions() {
2924   IfRegionOp op = *this;
2925   TypeRange then_types =
2926       op.then_branch().front().getTerminator()->getOperandTypes();
2927   TypeRange else_types =
2928       op.else_branch().front().getTerminator()->getOperandTypes();
2929 
2930   TypeRangeWithDesc results{op.getResultTypes(), "result"};
2931   TypeRangeWithDesc then_results{then_types, "then result"};
2932   TypeRangeWithDesc else_results{else_types, "else result"};
2933 
2934   if (failed(VerifyTypeRangesAreCompatible(op, then_results, results)))
2935     return failure();
2936   if (failed(VerifyTypeRangesAreCompatible(op, else_results, results)))
2937     return failure();
2938   return success();
2939 }
2940 
2941 namespace {
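// Inlines the selected region of a tf.IfRegion whose condition is a constant
// scalar. For example (illustrative), with cond = false the body of the else
// region replaces the op, with tf.Cast ops inserted wherever the yielded
// types differ from the op's result types.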
2942 class FoldConstantIfRegionOp : public OpRewritePattern<TF::IfRegionOp> {
2943  public:
2944   explicit FoldConstantIfRegionOp(MLIRContext *context)
2945       : OpRewritePattern<TF::IfRegionOp>(context) {}
2946   LogicalResult matchAndRewrite(TF::IfRegionOp op,
2947                                 PatternRewriter &rewriter) const override;
2948 };
2949 
2950 LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
2951     TF::IfRegionOp op, PatternRewriter &rewriter) const {
2952   // Extract the constant cond value.
2953   DenseIntElementsAttr cond_attr;
2954   if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
2955 
2956   // IfRegion condition should always be a scalar. Select the region to fold to.
2957   bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
2958   Region &region = cond ? op.then_branch() : op.else_branch();
2959 
2960   // If the IfRegion is stateless but the region being inlined itself is not
2961   // stateless, then inlining the region could cause a loss of information.
2962   // However, it's probably better to fold the IfRegion instead of having the
2963   // dead branch stay.
2964 
2965   // Inline the region in place of the IfRegion op, and forward the yield
2966   // inputs to the IfRegion op results. This is possible only if the yield
2967   // types match the result types.
2968   auto yield = cast<YieldOp>(region.front().getTerminator());
2969   auto updated_results = llvm::to_vector<4>(yield.getOperands());
2970 
2971   // If the yield types do not match the IfRegion result types, add appropriate
2972   // casts.
2973   rewriter.setInsertionPoint(yield);
2974   for (auto it : llvm::zip(op.getResultTypes(), updated_results)) {
2975     auto &updated_result = std::get<1>(it);
2976     Type result_type = std::get<0>(it);
2977     if (result_type != updated_result.getType()) {
2978       updated_result =
2979           rewriter.create<TF::CastOp>(op.getLoc(), result_type, updated_result,
2980                                       /*Truncate=*/rewriter.getBoolAttr(false));
2981     }
2982   }
2983   // Inline the region into the block containing the IfRegion.
2984   rewriter.mergeBlockBefore(&region.front(), op);
2985   rewriter.eraseOp(yield);
2986   rewriter.replaceOp(op, updated_results);
2987   return success();
2988 }
2989 }  // anonymous namespace
2990 
2991 void IfRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
2992                                              MLIRContext *context) {
2993   results.add<FoldConstantIfRegionOp,
2994               CaseOrIfRegionEliminatePassThrough<TF::IfRegionOp>>(context);
2995 }
2996 
2997 //===----------------------------------------------------------------------===//
2998 // InvertPermutationOp
2999 //===----------------------------------------------------------------------===//
3000 
3001 // Verifies that the input is 1D.
3002 LogicalResult InvertPermutationOp::verify() {
3003   InvertPermutationOp op = *this;
3004   auto x_type = op.x().getType().cast<TensorType>();
3005   if (!x_type.hasRank()) return success();
3006   if (x_type.getShape().size() != 1)
3007     return op.emitOpError() << "requires input x to be 1-dimensional";
3008 
3009   return success();
3010 }
3011 
3012 //===----------------------------------------------------------------------===//
3013 // LeakyReluOp
3014 //===----------------------------------------------------------------------===//
3015 
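// Constant-folds tf.LeakyRelu. For example (illustrative), with alpha = 0.2 a
// splat constant dense<-5.0> folds to dense<-1.0>, and when alpha == 1.0 any
// input folds to itself.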
3016 OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
3017   assert(operands.size() == 1 && "leaky relu has one operand");
3018 
3019   // leaky_relu(x, alpha: 1) -> x
3020   if (alpha().convertToFloat() == 1.0f) return getOperand();
3021 
3022   auto calculate = [&](FloatAttr arg) {
3023     APFloat val = arg.getValue();
3024     if (val.isNegative()) val = alpha() * val;
3025     return FloatAttr::get(arg.getType(), val);
3026   };
3027 
3028   if (auto arg = operands[0].dyn_cast_or_null<FloatAttr>()) {
3029     return calculate(arg);
3030   } else if (auto arg = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
3031     if (auto elementAttr = arg.getSplatValue<Attribute>().dyn_cast<FloatAttr>())
3032       return DenseElementsAttr::get(arg.getType(), calculate(elementAttr));
3033   }
3034   return {};
3035 }
3036 
3037 //===----------------------------------------------------------------------===//
3038 // LogOp
3039 //===----------------------------------------------------------------------===//
3040 
3041 void LogOp::getCanonicalizationPatterns(RewritePatternSet &results,
3042                                         MLIRContext *context) {
3043   results.add<LogOfSoftmax, LogToLog1p>(context);
3044 }
3045 
3046 //===----------------------------------------------------------------------===//
3047 // LogicalNotOp
3048 //===----------------------------------------------------------------------===//
3049 
3050 void LogicalNotOp::getCanonicalizationPatterns(RewritePatternSet &results,
3051                                                MLIRContext *context) {
3052   results
3053       .add<LogicalNotOfEqual, LogicalNotOfNotEqual, LogicalNotOfGreater,
3054            LogicalNotOfGreaterEqual, LogicalNotOfLess, LogicalNotOfLessEqual>(
3055           context);
3056 }
3057 
3058 //===----------------------------------------------------------------------===//
3059 // MatrixBandPartOp
3060 //===----------------------------------------------------------------------===//
3061 
3062 LogicalResult MatrixBandPartOp::verify() {
3063   MatrixBandPartOp op = *this;
3064   if (!HasRankAtLeast(op.input(), 2)) {
3065     return op.emitOpError()
3066            << "requires `input` to have rank of at least 2, but found "
3067            << op.input().getType();
3068   }
3069   if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
3070     return op.emitOpError()
3071            << "requires `num_lower` to have 0 dimensions, but found "
3072            << op.num_lower().getType();
3073   }
3074   if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
3075     return op.emitOpError()
3076            << "requires `num_upper` to have 0 dimensions, but found "
3077            << op.num_upper().getType();
3078   }
3079   return success();
3080 }
3081 
3082 //===----------------------------------------------------------------------===//
3083 // MatrixDiag Ops
3084 //===----------------------------------------------------------------------===//
3085 
3086 void MatrixDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
3087                                                MLIRContext *context) {
3088   results.add<MatrixDiagToV3>(context);
3089 }
3090 
3091 //===----------------------------------------------------------------------===//
3092 // MatrixSetDiagOp
3093 //===----------------------------------------------------------------------===//
3094 
3095 void MatrixSetDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
3096                                                   MLIRContext *context) {
3097   results.add<MatrixSetDiagToV3>(context);
3098 }
3099 
3100 //===----------------------------------------------------------------------===//
3101 // MatrixSetDiagV2Op
3102 //===----------------------------------------------------------------------===//
3103 
3104 void MatrixSetDiagV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
3105                                                     MLIRContext *context) {
3106   results.add<MatrixSetDiagV2ToV3>(context);
3107 }
3108 
3109 //===----------------------------------------------------------------------===//
3110 // MaxOp
3111 //===----------------------------------------------------------------------===//
3112 
3113 void MaxOp::build(OpBuilder &builder, OperationState &result, Value input,
3114                   Value reduction_indices, BoolAttr keep_dims) {
3115   Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims);
3116   build(builder, result, out_ty, input, reduction_indices, keep_dims);
3117 }
3118 
3119 //===----------------------------------------------------------------------===//
3120 // MaximumOp
3121 //===----------------------------------------------------------------------===//
3122 
3123 void MaximumOp::getCanonicalizationPatterns(RewritePatternSet &results,
3124                                             MLIRContext *context) {
3125   results.add<MaximumOfZeroToRelu>(context);
3126 }
3127 
3128 //===----------------------------------------------------------------------===//
3129 // MaxPoolOp
3130 //===----------------------------------------------------------------------===//
3131 
3132 LogicalResult MaxPoolOp::FoldOperandsPermutation(
3133     ArrayRef<int64_t> permutation) {
3134   return ::mlir::TF::FoldOperandsPermutation(
3135       permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
3136 }
3137 
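// Updates the data format, e.g. (illustrative) from "NHWC" to "NCHW": besides
// rewriting the data_format attribute and result type, the strides,
// explicit_paddings, and ksize attributes are shuffled with the same
// permutation so each entry still refers to its original logical dimension.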
3138 LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
3139   StringRef src_data_format = data_format();
3140 
3141   auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
3142   if (perm.empty()) return failure();
3143 
3144   // Update data_format attribute and result types.
3145   if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
3146     return failure();
3147 
3148   stridesAttr(ShuffleArrayAttr(strides(), perm));
3149   explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
3150   ksizeAttr(ShuffleArrayAttr(ksize(), perm));
3151 
3152   return success();
3153 }
3154 
3155 StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
3156   // Keep the current data format if no GPUs are available or if explicit
3157   // placement does not allow using a GPU for this operation.
3158   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
3159     return data_format();
3160 
3161   // Defaults to NCHW.
3162   return "NCHW";
3163 }
3164 
3165 //===----------------------------------------------------------------------===//
3166 // MaxPoolGradOp
3167 //===----------------------------------------------------------------------===//
3168 
3169 LogicalResult MaxPoolGradOp::verify() {
3170   MaxPoolGradOp op = *this;
3171   if (!IsOfRankOrUnranked(op.orig_input(), 4)) {
3172     return op.emitOpError() << "requires orig_input to be rank 4";
3173   }
3174   if (!IsOfRankOrUnranked(op.orig_output(), 4)) {
3175     return op.emitOpError() << "requires orig_output to be rank 4";
3176   }
3177   if (!IsOfRankOrUnranked(op.grad(), 4)) {
3178     return op.emitOpError() << "requires grad to be rank 4";
3179   }
3180   return success();
3181 }
3182 
3183 //===----------------------------------------------------------------------===//
3184 // MeanOp
3185 //===----------------------------------------------------------------------===//
3186 
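// Rewrites constant reduction indices under an operand permutation: each
// index idx is remapped to permutation[idx]. For example (illustrative), with
// permutation [0, 2, 3, 1], reduction indices [1, 2] become [2, 3].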
3187 LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
3188   // Reduction indices must be defined by a constant operation.
3189   auto reduction_op =
3190       dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
3191   if (!reduction_op) return failure();
3192 
3193   auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
3194   if (!reductions_value) return failure();
3195 
3196   // Prepare new reduction indices according to operand permutation.
3197   SmallVector<int32_t, 4> shuffled_reduction;
3198   llvm::transform(reductions_value.getValues<APInt>(),
3199                   std::back_inserter(shuffled_reduction),
3200                   [&](APInt idx) { return permutation[idx.getSExtValue()]; });
3201 
3202   // Add constant operation with a new reduction indices.
3203   OpBuilder builder(getOperation());
3204   auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
3205                                           builder.getIntegerType(32));
3206   auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
3207   auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
3208 
3209   // Use new reduction indices.
3210   setOperand(1, shuffled_reduction_op);
3211 
3212   return success();
3213 }
3214 
3215 //===----------------------------------------------------------------------===//
3216 // MulNoNanOp
3217 //===----------------------------------------------------------------------===//
3218 
3219 void MulNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results,
3220                                              MLIRContext *context) {
3221   results.add<DivNoNanOrMulNoNanConstantY<TF::MulNoNanOp, TF::MulOp>>(context);
3222 }
3223 
3224 //===----------------------------------------------------------------------===//
3225 // MulOp
3226 //===----------------------------------------------------------------------===//
3227 
3228 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
3229   return IdentityArithmeticOpFolder<MulOp>(*this, operands);
3230 }
3231 
3232 //===----------------------------------------------------------------------===//
3233 // HashTableOp
3234 //===----------------------------------------------------------------------===//
3235 void HashTableOp::getCanonicalizationPatterns(RewritePatternSet &results,
3236                                               MLIRContext *context) {
3237   results.add<HashTableAndInitializeTableToV2>(context);
3238   results.add<HashTableAndLookupTableSizeToV2>(context);
3239   results.add<HashTableAndLookupTableFindToV2>(context);
3240 }
3241 
3242 //===----------------------------------------------------------------------===//
3243 // BitcastOp
3244 //===----------------------------------------------------------------------===//
3245 
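// Illustrative shape examples for tf.Bitcast (assuming static shapes):
//   tensor<3x2xi16> -> tensor<3xi32>   (widening: input last dim == 32/16)
//   tensor<3xi32>   -> tensor<3x2xi16> (narrowing: output last dim == 32/16)
//   tensor<3xi32>   -> tensor<3xf32>   (equal bitwidth: shapes must match)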
3246 LogicalResult BitcastOp::verify() {
3247   BitcastOp op = *this;
3248   auto input_type = op.input().getType().cast<ShapedType>();
3249   auto output_type = op.output().getType().cast<ShapedType>();
3250   auto input_element_type = input_type.getElementType();
3251   auto output_element_type = output_type.getElementType();
3252 
3253   // We currently only handle float and int element types in the verifier.
3254   // TODO(hanxiongwang): Handle more element types besides int and float in
3255   // the verifier.
3256   if (input_type.hasStaticShape() && output_type.hasStaticShape() &&
3257       input_element_type.isIntOrFloat() && output_element_type.isIntOrFloat()) {
3258     const auto input_element_type_bitwidth =
3259         input_element_type.getIntOrFloatBitWidth();
3260     const auto output_element_type_bitwidth =
3261         output_element_type.getIntOrFloatBitWidth();
3262 
3263     auto is_output_shape_valid_with_small_input_element_type_bitwidth = [&]() {
3264       if (output_element_type_bitwidth % input_element_type_bitwidth != 0) {
3265         op.emitOpError() << "output element bitwidth is not multiple "
3266                          << "of input element bitwidth";
3267         return failure();
3268       }
3269       if (input_type.getShape().size() != output_type.getShape().size() + 1) {
3270         op.emitOpError() << "rank of input tensor is "
3271                          << input_type.getShape().size()
3272                          << ". rank of output tensor is expected to be "
3273                          << input_type.getShape().size() - 1 << ", instead of "
3274                          << output_type.getShape().size() << ".";
3275         return failure();
3276       }
3277       const auto rightmost_dim_size_divisor =
3278           output_element_type_bitwidth / input_element_type_bitwidth;
3279       if (input_type.getShape().empty() ||
3280           input_type.getShape().back() != rightmost_dim_size_divisor) {
3281         op.emitOpError()
3282             << "input rightmost dimension size is not equal to the divisor. "
3283             << "the last dimension of input is expected to be "
3284             << rightmost_dim_size_divisor;
3285         return failure();
3286       }
3287       for (auto idx = 0; idx < output_type.getShape().size(); idx++) {
3288         if (input_type.getShape()[idx] != output_type.getShape()[idx]) {
3289           op.emitOpError()
3290               << "the " << idx << "th dim of output tensor is "
3291               << output_type.getShape()[idx]
3292               << ". It is not equal to the one in input tensor, which is "
3293               << input_type.getShape()[idx];
3294           return failure();
3295         }
3296       }
3297       return success();
3298     };
3299 
3300     auto is_output_shape_valid_with_small_output_element_type_bitwidth = [&]() {
3301       if (input_element_type_bitwidth % output_element_type_bitwidth != 0) {
3302         op.emitOpError() << "input element bitwidth is not multiple "
3303                          << "of output element bitwidth";
3304         return failure();
3305       }
3306       if (input_type.getShape().size() + 1 != output_type.getShape().size()) {
3307         op.emitOpError() << "rank of input tensor is "
3308                          << input_type.getShape().size()
3309                          << ". rank of output tensor is expected to be "
3310                          << input_type.getShape().size() + 1 << ", instead of "
3311                          << output_type.getShape().size() << ".";
3312         return failure();
3313       }
3314       const auto rightmost_dim_size_divisor =
3315           input_element_type_bitwidth / output_element_type_bitwidth;
3316       if (output_type.getShape().back() != rightmost_dim_size_divisor) {
3317         op.emitOpError()
3318             << "output rightmost dimension size is not equal to the divisor. "
3319             << "the last dimension of output is expected to be "
3320             << rightmost_dim_size_divisor;
3321         return failure();
3322       }
3323       for (auto idx = 0; idx < input_type.getShape().size(); idx++) {
3324         if (input_type.getShape()[idx] != output_type.getShape()[idx]) {
3325           op.emitOpError()
3326               << "the " << idx << "th dim of output tensor is "
3327               << output_type.getShape()[idx]
3328               << ". It is not equal to the one in input tensor, which is "
3329               << input_type.getShape()[idx];
3330           return failure();
3331         }
3332       }
3333       return success();
3334     };
3335 
3336     auto is_output_shape_valid_with_equal_bitwidth = [&]() {
3337       if (input_type.getShape().equals(output_type.getShape())) {
3338         return success();
3339       }
3340       op.emitOpError()
3341           << "output tensor shape shall be equal to input tensor shape";
3342       return failure();
3343     };
3344 
3345     if (input_element_type_bitwidth < output_element_type_bitwidth) {
3346       return is_output_shape_valid_with_small_input_element_type_bitwidth();
3347     } else if (input_element_type_bitwidth > output_element_type_bitwidth) {
3348       return is_output_shape_valid_with_small_output_element_type_bitwidth();
3349     } else {
3350       return is_output_shape_valid_with_equal_bitwidth();
3351     }
3352   }
3353   return success();
3354 }
3355 
3356 }  // namespace TF
3357 }  // namespace mlir
3358 
3359 //===----------------------------------------------------------------------===//
3360 // TableGen'd op method definitions
3361 //===----------------------------------------------------------------------===//
3362 
3363 #define GET_OP_CLASSES
3364 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc"
3365