1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <iterator>
22 #include <limits>
23 #include <numeric>
24 #include <string>
25 #include <tuple>
26 #include <type_traits>
27
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/Sequence.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/ADT/StringSwitch.h"
38 #include "llvm/ADT/iterator_range.h"
39 #include "llvm/Support/Casting.h"
40 #include "llvm/Support/FormatVariadic.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
43 #include "mlir/Dialect/Traits.h" // from @llvm-project
44 #include "mlir/IR/Attributes.h" // from @llvm-project
45 #include "mlir/IR/Builders.h" // from @llvm-project
46 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
49 #include "mlir/IR/Diagnostics.h" // from @llvm-project
50 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
51 #include "mlir/IR/Location.h" // from @llvm-project
52 #include "mlir/IR/MLIRContext.h" // from @llvm-project
53 #include "mlir/IR/Matchers.h" // from @llvm-project
54 #include "mlir/IR/OpDefinition.h" // from @llvm-project
55 #include "mlir/IR/OpImplementation.h" // from @llvm-project
56 #include "mlir/IR/PatternMatch.h" // from @llvm-project
57 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
58 #include "mlir/IR/Types.h" // from @llvm-project
59 #include "mlir/IR/Value.h" // from @llvm-project
60 #include "mlir/Parser/Parser.h" // from @llvm-project
61 #include "mlir/Support/LLVM.h" // from @llvm-project
62 #include "mlir/Support/LogicalResult.h" // from @llvm-project
63 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
71 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
72 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
73 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
74 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
75 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
76 #include "tensorflow/core/framework/kernel_shape_util.h"
77 #include "tensorflow/core/platform/logging.h"
78 #include "tensorflow/core/util/padding.h"
79 #include "tensorflow/core/util/tensor_format.h"
80
81 namespace mlir {
82 namespace TF {
83
84 namespace {
85 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
86 } // namespace
87
88 //===----------------------------------------------------------------------===//
89 // AddOp
90 //===----------------------------------------------------------------------===//
91
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)92 void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
93 MLIRContext *context) {
94 results.add<AddToAddV2>(context);
95 }
96
97 //===----------------------------------------------------------------------===//
98 // AddNOp
99 //===----------------------------------------------------------------------===//
100
fold(ArrayRef<Attribute> operands)101 OpFoldResult AddNOp::fold(ArrayRef<Attribute> operands) {
102 if (operands.size() == 1) return *inputs().begin();
103
104 // Fold if there is only one single non-zero operand or all operands are zero.
105 int non_zero_index = -1;
106 auto IsKnownZero = [](Attribute attr) {
107 if (!attr) return false;
108 auto splat = attr.dyn_cast<SplatElementsAttr>();
109 if (!splat) return false;
110 Type element_ty = splat.getType().getElementType();
111 if (element_ty.isa<FloatType>())
112 return splat.getSplatValue<llvm::APFloat>().isZero();
113 if (element_ty.isa<IntegerType>())
114 return splat.getSplatValue<llvm::APInt>().getSExtValue() == 0;
115 return false;
116 };
117
118 for (auto it : llvm::enumerate(operands)) {
119 if (IsKnownZero(it.value())) continue;
120 if (non_zero_index != -1) {
121 // Don't fold if we find more than 1 non-zero operand.
122 return {};
123 }
124 non_zero_index = it.index();
125 }
126
127 // Only fold when the result shape is fully static.
128 auto result_ty = getType().dyn_cast<ShapedType>();
129 if (!result_ty || !result_ty.hasStaticShape()) return {};
130
131 if (non_zero_index == -1) {
132 return SplatElementsAttr::get(
133 result_ty,
134 operands.begin()->cast<DenseElementsAttr>().getSplatValue<Attribute>());
135 }
136
137 // Check the non-zero operand's shape matches the result shape.
138 if (result_ty == inputs()[non_zero_index].getType())
139 return inputs()[non_zero_index];
140 return {};
141 }
142
143 //===----------------------------------------------------------------------===//
144 // AddV2Op
145 //===----------------------------------------------------------------------===//
146
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)147 void AddV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
148 MLIRContext *context) {
149 results.add<AddV2OfNegLeft, AddV2OfNegRight>(context);
150 }
151
fold(ArrayRef<Attribute> operands)152 OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
153 return IdentityArithmeticOpFolder<AddV2Op>(*this, operands);
154 }
155
156 //===----------------------------------------------------------------------===//
157 // AllOp
158 //===----------------------------------------------------------------------===//
159
verify()160 LogicalResult AllOp::verify() {
161 AllOp op = *this;
162 return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
163 op.getLoc());
164 }
165
166 //===----------------------------------------------------------------------===//
167 // AnyOp
168 //===----------------------------------------------------------------------===//
169
verify()170 LogicalResult AnyOp::verify() {
171 AnyOp op = *this;
172 return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
173 op.getLoc());
174 }
175
176 //===----------------------------------------------------------------------===//
177 // AssertOp
178 //===----------------------------------------------------------------------===//
179
180 namespace {
181
182 // Removes Assert with constant true predicate.
183 struct AssertWithTrue : public OpRewritePattern<AssertOp> {
184 using OpRewritePattern<AssertOp>::OpRewritePattern;
185
matchAndRewritemlir::TF::__anonc7529f2f0311::AssertWithTrue186 LogicalResult matchAndRewrite(AssertOp op,
187 PatternRewriter &rewriter) const override {
188 ElementsAttr cst;
189 if (matchPattern(op.condition(), m_Constant(&cst))) {
190 if (cst.getValues<bool>()[0]) {
191 rewriter.eraseOp(op);
192 return success();
193 }
194 }
195 return failure();
196 }
197 };
198 } // namespace
199
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)200 void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
201 MLIRContext *context) {
202 results.add<AssertWithTrue>(context);
203 }
204
205 //===----------------------------------------------------------------------===//
206 // BatchMatMulV2Op & BatchMatMulOp
207 //===----------------------------------------------------------------------===//
208
209 template <typename OpT,
210 typename std::enable_if<llvm::is_one_of<
211 OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
Verify(OpT op)212 static LogicalResult Verify(OpT op) {
213 if (!HasRankAtLeast(op.x(), 2)) {
214 return op.emitOpError("requires lhs operand to have rank at least two");
215 }
216 if (!HasRankAtLeast(op.y(), 2)) {
217 return op.emitOpError("requires rhs operand to have rank at least two");
218 }
219
220 RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x());
221 RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y());
222
223 if (!x_ty || !y_ty) return success();
224
225 ArrayRef<int64_t> x_shape = x_ty.getShape();
226 ArrayRef<int64_t> y_shape = y_ty.getShape();
227
228 llvm::SmallVector<int64_t, 4> result_batch_shape;
229 llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
230 llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);
231
232 // Check compatibility of batch dimensions if both input shapes are known.
233 // BatchMatMul should have exactly the same batch dimensions and
234 // BatchMatMulV2 should have broadcastable batch dimensions.
235 //
236 // The last two dimensions are non-batch dimensions that don't need to
237 // participate in batch dimension compatibility check.
238 if (std::is_same<OpT, BatchMatMulOp>()) {
239 for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
240 int64_t x_dim = std::get<0>(dim_pairs);
241 int64_t y_dim = std::get<1>(dim_pairs);
242 if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
243 x_dim != y_dim) {
244 return op.emitOpError()
245 << "found mismatching batch dimensions for lhs shape " << x_ty
246 << " and rhs shape " << y_ty;
247 }
248 }
249 } else {
250 if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
251 result_batch_shape))
252 return op.emitOpError()
253 << "found incompatible broadcast batch dimensions for lhs shape "
254 << x_ty << " and rhs shape " << y_ty;
255 }
256
257 RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
258 if (!output_ty) return success();
259
260 int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank());
261 if (output_ty.getRank() != expected_output_rank)
262 return op.emitOpError()
263 << "found invalid output rank, expected " << expected_output_rank
264 << " but got " << output_ty.getRank();
265
266 // Check output batch dim with potential broadcasting.
267 ArrayRef<int64_t> output_shape = output_ty.getShape();
268 for (int i = 0; i < result_batch_shape.size(); ++i) {
269 if (output_shape[i] != ShapedType::kDynamicSize &&
270 result_batch_shape[i] != ShapedType::kDynamicSize &&
271 output_shape[i] != result_batch_shape[i])
272 return op.emitOpError()
273 << "has mismatching input batch dimension "
274 << result_batch_shape[i] << " and output batch dimension "
275 << output_shape[i];
276 }
277
278 // Check output shape for non-batch dimension, following documentation below.
279 // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
280 int64_t x_row_dim = x_shape[x_shape.size() - 2];
281 int64_t x_col_dim = x_shape[x_shape.size() - 1];
282 int64_t y_row_dim = y_shape[y_shape.size() - 2];
283 int64_t y_col_dim = y_shape[y_shape.size() - 1];
284 int64_t out_row_dim = output_shape[output_shape.size() - 2];
285 int64_t out_col_dim = output_shape[output_shape.size() - 1];
286
287 int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim;
288 int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim;
289
290 if (expected_out_row_dim != ShapedType::kDynamicSize &&
291 out_row_dim != ShapedType::kDynamicSize &&
292 out_row_dim != expected_out_row_dim)
293 return op.emitOpError()
294 << "found invalid output dimension on row, expected "
295 << expected_out_row_dim << " but got " << out_row_dim;
296 if (expected_out_col_dim != ShapedType::kDynamicSize &&
297 out_col_dim != ShapedType::kDynamicSize &&
298 out_col_dim != expected_out_col_dim)
299 return op.emitOpError()
300 << "found invalid output dimension on col, expected "
301 << expected_out_col_dim << " but got " << out_col_dim;
302
303 return success();
304 }
verify()305 LogicalResult BatchMatMulOp::verify() { return Verify(*this); }
verify()306 LogicalResult BatchMatMulV2Op::verify() { return Verify(*this); }
307
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)308 void BatchMatMulOp::getCanonicalizationPatterns(RewritePatternSet &results,
309 MLIRContext *context) {
310 results.add<BatchMatMulToV2>(context);
311 }
312
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)313 void BatchMatMulV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
314 MLIRContext *context) {
315 results.add<BatchMatMulV2ToMatMul>(context);
316 }
317
318 //===----------------------------------------------------------------------===//
319 // BatchToSpaceOp
320 //===----------------------------------------------------------------------===//
321
verify()322 LogicalResult BatchToSpaceOp::verify() {
323 BatchToSpaceOp op = *this;
324 // Op already has a constraint that block_size >= 2.
325 int64_t block_size = op.block_size();
326
327 llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
328 auto input_type = op.input().getType().cast<TensorType>();
329 if (input_type.hasRank()) {
330 if (input_type.getRank() != 4)
331 return op.emitOpError()
332 << "requires input to be a 4D tensor, but got " << input_type;
333
334 int64_t input_batch = input_type.getDimSize(0);
335 if (input_batch != ShapedType::kDynamicSize &&
336 input_batch % (block_size * block_size) != 0) {
337 return op.emitOpError()
338 << "requires input batch (dimension 0) to be evenly divisible "
339 "by (block_size * block_size), but got input batch "
340 << input_batch << " and block_size " << block_size;
341 }
342
343 input_shape.assign(input_type.getShape().begin(),
344 input_type.getShape().end());
345 }
346
347 auto crops_type = op.crops().getType().cast<TensorType>();
348 if (crops_type.hasRank()) {
349 if (crops_type.getRank() != 2)
350 return op.emitOpError()
351 << "requires crops to be a 2D tensor, but got " << crops_type;
352
353 auto dim_of_size = [&](int64_t dim, int64_t size) {
354 if (crops_type.isDynamicDim(dim)) return true;
355 return crops_type.getDimSize(dim) == size;
356 };
357 if (!dim_of_size(0, 2) || !dim_of_size(1, 2))
358 return op.emitOpError()
359 << "requires crops to be a tensor<2x2>, but got " << crops_type;
360 }
361
362 DenseIntElementsAttr crops_attr;
363 // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]],
364 // and flattened as [crop_top, crop_bottom, crop_left, crop_right]
365 llvm::SmallVector<int64_t, 4> crops_values;
366 if (matchPattern(op.crops(), m_Constant(&crops_attr))) {
367 assert(crops_attr.getNumElements() == 4 &&
368 "tf.BatchToSpace crops must have 4 elements");
369
370 auto crops_range = crops_attr.getValues<APInt>();
371 for (const auto &crops_value : crops_range) {
372 int64_t crops_value_int = crops_value.getSExtValue();
373 if (crops_value_int < 0)
374 return op.emitOpError()
375 << "requires all crop values to be nonnegative, but got "
376 << crops_attr;
377
378 crops_values.push_back(crops_value_int);
379 }
380 }
381
382 auto output_type = op.output().getType().cast<TensorType>();
383 if (output_type.hasRank()) {
384 if (output_type.getRank() != 4)
385 return op.emitOpError()
386 << "requires output to be a 4D tensor, but got " << output_type;
387
388 auto static_dims = [](int64_t dim_a, int64_t dim_b) {
389 return dim_a != ShapedType::kDynamicSize &&
390 dim_b != ShapedType::kDynamicSize;
391 };
392
393 auto output_shape = output_type.getShape();
394
395 // output batch = input batch / (block_size * block_size).
396 int64_t input_batch = input_shape[0];
397 int64_t output_batch = output_shape[0];
398 if (static_dims(input_batch, output_batch) &&
399 (output_batch * block_size * block_size) != input_batch)
400 return op.emitOpError()
401 << "requires output batch (dimension 0) to be equal to input "
402 "batch (dimension 0) / (block_size * block_size), but got "
403 "output batch "
404 << output_batch << ", input batch " << input_batch
405 << ", and block_size " << block_size;
406
407 auto check_spatial_dim = [&](int64_t spatial_dim_index,
408 llvm::StringRef dim_name,
409 llvm::StringRef crop_a_name,
410 llvm::StringRef crop_b_name) -> LogicalResult {
411 int64_t input_dim = input_shape[spatial_dim_index];
412 int64_t output_dim = output_shape[spatial_dim_index];
413 if (!static_dims(input_dim, output_dim)) return success();
414
415 int64_t input_dim_pad = input_dim * block_size;
416 // If crops are unknown, the maximum output spatial dim size is input
417 // spatial dim size * block_size, as crops can be minimum 0.
418 if (crops_values.empty() && output_dim > input_dim * block_size)
419 return op.emitOpError()
420 << "requires output " << dim_name << " (dimension "
421 << spatial_dim_index << ") to be less than or equal to input "
422 << dim_name << " (dimension " << spatial_dim_index
423 << ") * block_size, but got output " << dim_name << " "
424 << output_dim << ", input " << dim_name << " " << input_dim
425 << ", and block_size " << block_size;
426
427 if (!crops_values.empty()) {
428 // output spatial dim = input spatial dim * block_size - crops.
429 int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)];
430 int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1];
431 if (output_dim != input_dim_pad - crop_a - crop_b)
432 return op.emitOpError()
433 << "requires output " << dim_name << " (dimension "
434 << spatial_dim_index << ") to be equal to input " << dim_name
435 << " (dimension " << spatial_dim_index << ") * block_size - "
436 << crop_a_name << " - " << crop_b_name << ", but got output "
437 << dim_name << " " << output_dim << ", input " << dim_name
438 << " " << input_dim << ", " << crop_a_name << " " << crop_a
439 << ", " << crop_b_name << " " << crop_b << ", and block_size "
440 << block_size;
441 }
442
443 return success();
444 };
445
446 if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) ||
447 failed(check_spatial_dim(2, "width", "crop_left", "crop_right")))
448 return failure();
449
450 int64_t input_depth = input_shape[3];
451 int64_t output_depth = output_shape[3];
452 if (static_dims(input_depth, output_depth) && output_depth != input_depth)
453 return op.emitOpError()
454 << "requires output depth (dimension 3) to be equal to input "
455 "depth (dimension 3), but got output depth "
456 << output_depth << " and input depth " << input_depth;
457 }
458
459 return success();
460 }
461
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)462 void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet &results,
463 MLIRContext *context) {
464 results.add<BatchToSpaceToBatchToSpaceND>(context);
465 }
466
467 //===----------------------------------------------------------------------===//
468 // BatchToSpaceNDOp
469 //===----------------------------------------------------------------------===//
470
verify()471 LogicalResult BatchToSpaceNDOp::verify() {
472 BatchToSpaceNDOp op = *this;
473 auto block_shape_ty = op.block_shape().getType().cast<ShapedType>();
474 auto crops_ty = op.crops().getType().cast<ShapedType>();
475
476 if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) {
477 const int block_rank = block_shape_ty.getShape().front();
478 if (crops_ty.getRank() != 2 || crops_ty.getShape().front() != block_rank ||
479 crops_ty.getShape()[1] != 2) {
480 op.emitOpError() << "crops should have shape [" << block_rank
481 << ", 2] instead of " << crops_ty.getShape();
482 return failure();
483 }
484 }
485
486 return success();
487 }
488
489 //===----------------------------------------------------------------------===//
490 // BiasAddOp
491 //===----------------------------------------------------------------------===//
492
493 // Verifies that,
494 // * the value and bias operands have valid ranks or are unranked.
495 // * Channel dimension of the value operand and length of bias matches if they
496 // are not unknown.
497 //
verify()498 LogicalResult BiasAddOp::verify() {
499 BiasAddOp op = *this;
500 absl::string_view data_format(op.data_format().data(),
501 op.data_format().size());
502 tensorflow::TensorFormat format;
503 bool is_valid = FormatFromString(data_format, &format);
504 DCHECK(is_valid) << data_format;
505 if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
506 if (!HasRankAtLeast(op.value(), 2))
507 return op.emitOpError(
508 "requires value operand to have rank at least two with `NHWC` data "
509 "format");
510 } else {
511 // Op definition requires data_format to be either NHWC or NCHW.
512 DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
513 if (!HasRankAtLeast(op.value(), 3))
514 return op.emitOpError(
515 "requires value operand to have rank at least three with `NCHW` data "
516 "format");
517 }
518
519 if (!IsOfRankOrUnranked(op.bias(), 1))
520 return op.emitOpError("requires bias operand to have rank exactly one");
521
522 RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>();
523 RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
524 if (!bias_ty || !value_ty) return success();
525
526 int64_t feature_dim_idx =
527 tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format);
528 int64_t feature_dim = value_ty.getDimSize(feature_dim_idx);
529 int64_t bias_len = bias_ty.getDimSize(0);
530 if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) {
531 return op.emitOpError()
532 << "requires channel dimension and feature dimension to match; "
533 "found "
534 << feature_dim << " and " << bias_len << ", respectively";
535 }
536 return success();
537 }
538
UpdateDataFormat(StringRef data_format)539 LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
540 return ::mlir::TF::UpdateDataFormat(data_format, this);
541 }
542
GetOptimalLayout(const RuntimeDevices & devices)543 StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) {
544 // Keep current data format if no GPUs are available or if explicit placement
545 // does not allow to use GPU for this operation.
546 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
547 return data_format();
548
549 // Prefer NHWC for GPU devices.
550 return "NHWC";
551 }
552
553 //===----------------------------------------------------------------------===//
554 // BiasAddGradOp
555 //===----------------------------------------------------------------------===//
556
557 // Verifies that,
558 // * the out_backprop operands have valid ranks or are unranked.
559 //
verify()560 LogicalResult BiasAddGradOp::verify() {
561 BiasAddGradOp op = *this;
562 absl::string_view data_format(op.data_format().data(),
563 op.data_format().size());
564 tensorflow::TensorFormat format;
565 bool is_valid = FormatFromString(data_format, &format);
566 DCHECK(is_valid) << data_format;
567 if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
568 if (!HasRankAtLeast(op.out_backprop(), 2))
569 return op.emitOpError(
570 "requires out_backprop operand to have rank at least two with `NHWC` "
571 "data format");
572 } else {
573 // Op definition requires data_format to be either NHWC or NCHW.
574 DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
575 if (!HasRankAtLeast(op.out_backprop(), 3))
576 return op.emitOpError(
577 "requires out_backprop operand to have rank at least three with "
578 "`NCHW` data format");
579 }
580
581 return success();
582 }
583
584 //===----------------------------------------------------------------------===//
585 // BiasAddV1Op
586 //===----------------------------------------------------------------------===//
587
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)588 void BiasAddV1Op::getCanonicalizationPatterns(RewritePatternSet &results,
589 MLIRContext *context) {
590 results.add<BiasAddV1ToBiasAdd>(context);
591 }
592
593 //===----------------------------------------------------------------------===//
594 // arith::BitcastOp
595 //===----------------------------------------------------------------------===//
596
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)597 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
598 MLIRContext *context) {
599 results.add<BitcastSameType, BitcastNested>(context);
600 }
601
602 //===----------------------------------------------------------------------===//
603 // BroadcastToOp
604 //===----------------------------------------------------------------------===//
605
verify()606 LogicalResult BroadcastToOp::verify() {
607 // TODO(antiagainst): check that
608 // * The 'shape' input is an 1-D int tensor.
609 // * Each dimension pair of the source and target shapes are either equal
610 // or one of them is one.
611 return success();
612 }
613
fold(ArrayRef<Attribute> operands)614 OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
615 Value input = this->input();
616
617 // Fold broadcast if operand and result types are the same and all dimensions
618 // are statically known (no-op broadcast).
619 auto result_ty = getType().dyn_cast<ShapedType>();
620 if (!result_ty || !result_ty.hasStaticShape()) return {};
621
622 if (result_ty == input.getType()) return input;
623
624 DenseIntElementsAttr cst_attr;
625 if (!matchPattern(input, m_Constant(&cst_attr))) return {};
626 if (!cst_attr.isSplat()) return {};
627
628 return DenseElementsAttr::get(result_ty, cst_attr.getSplatValue<Attribute>());
629 }
630
631 //===----------------------------------------------------------------------===//
632 // BroadcastGradientArgsOp
633 //===----------------------------------------------------------------------===//
634
635 namespace {
636 // Returns `true` if both s0 & s1 are defined via constant op, and fills
637 // s0_shape & s1_shape.
ExtractInputConstShape(BroadcastGradientArgsOp op,DenseIntElementsAttr & s0,DenseIntElementsAttr & s1,SmallVectorImpl<int64_t> & s0_shape,SmallVectorImpl<int64_t> & s1_shape)638 bool ExtractInputConstShape(BroadcastGradientArgsOp op,
639 DenseIntElementsAttr &s0, DenseIntElementsAttr &s1,
640 SmallVectorImpl<int64_t> &s0_shape,
641 SmallVectorImpl<int64_t> &s1_shape) {
642 if (!matchPattern(op.s0(), m_Constant(&s0))) return false;
643 if (!matchPattern(op.s1(), m_Constant(&s1))) return false;
644
645 for (auto s : s0.getValues<APInt>()) s0_shape.push_back(s.getSExtValue());
646 for (auto s : s1.getValues<APInt>()) s1_shape.push_back(s.getSExtValue());
647
648 return true;
649 }
650
651 // Calculates r0 & r1 output based on inputs and calculated broadcasted shape.
652 //
653 // For given bcasted_shape, s0_shape and s1_shape, the broadcasted dimension is
654 // calculated and push back to its corresponding result, r0 or r1. For example,
655 // for s0_shape [1,4] and s1_shape [4, 4], bcasted_shape is computed to be
656 // [4,4] - this leads to the result of r0 to be [0] as the first dimension of s0
657 // is broadcasted, and r1 to be <> as no broadcasting is happening for s1.
GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,ArrayRef<int64_t> s0_shape,ArrayRef<int64_t> s1_shape,SmallVectorImpl<int64_t> & r0,SmallVectorImpl<int64_t> & r1)658 void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
659 ArrayRef<int64_t> s0_shape,
660 ArrayRef<int64_t> s1_shape,
661 SmallVectorImpl<int64_t> &r0,
662 SmallVectorImpl<int64_t> &r1) {
663 r0.clear();
664 r1.clear();
665
666 // No broadcasting is required if both the shapes are equal.
667 if (s0_shape == s1_shape) return;
668
669 for (int i = bcasted_shape.size(); i > 0; --i) {
670 int idx = bcasted_shape.size() - i;
671 int s0_idx = i > s0_shape.size() ? -1 : s0_shape.size() - i;
672 int s1_idx = i > s1_shape.size() ? -1 : s1_shape.size() - i;
673 if (s0_idx == -1) {
674 r0.push_back(idx);
675 if (s1_shape[s1_idx] == 1) r1.push_back(idx);
676 } else if (s1_idx == -1) {
677 r1.push_back(idx);
678 if (s0_shape[s0_idx] == 1) r0.push_back(idx);
679 } else if (s0_shape[s0_idx] != s1_shape[s1_idx]) {
680 if (s0_shape[s0_idx] != bcasted_shape[idx])
681 r0.push_back(idx);
682 else
683 r1.push_back(idx);
684 } else if (s0_shape[s0_idx] == 1) {
685 // This op is used to compute the gradient dimensions requiring reduction
686 // to match the input dimensions. In case both the dimensions are one,
687 // reducing the dimension has no effect. We choose to reduce such
688 // dimensions to match the TensorFlow kernel behavior. However, note that
689 // the TF behavior in this case is inconsistent with the case with the
690 // same shapes.
691 r0.push_back(idx);
692 r1.push_back(idx);
693 }
694 }
695 }
696 } // namespace
697
698 // Verifies that,
699 // * Broadcast compatability for input shapes.
700 // * Output shape dimension matches the expected dimension size for input
701 // shapes.
verify()702 LogicalResult BroadcastGradientArgsOp::verify() {
703 BroadcastGradientArgsOp op = *this;
704 SmallVector<int64_t, 4> s0_shape, s1_shape;
705 DenseIntElementsAttr s0, s1;
706 if (!ExtractInputConstShape(op, s0, s1, s0_shape, s1_shape)) return success();
707
708 // If both shape is known const, try to validate shape on them as well.
709 SmallVector<int64_t, 4> bcasted_shape;
710 if (!OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape))
711 return op.emitOpError() << "requires broadcast compatible shape tensors "
712 "for 's0' and 's1', but got "
713 << s0 << " and " << s1;
714
715 SmallVector<int64_t, 4> r0, r1;
716 GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
717 r1);
718
719 // Verify that output types are of rank one and matches the computed result
720 // shape.
721 auto r0_ty = op.r0().getType().dyn_cast<RankedTensorType>();
722 auto r1_ty = op.r1().getType().dyn_cast<RankedTensorType>();
723 if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size())
724 return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
725 << r0.size() << " but got " << r0_ty.getShape()[0];
726 if (r1_ty && r1_ty.hasStaticShape() && r1_ty.getDimSize(0) != r1.size())
727 return op.emitOpError() << "requires dimension 0 size of 'r1' to be "
728 << r1.size() << " but got " << r1_ty.getShape()[0];
729
730 return success();
731 }
732
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)733 LogicalResult BroadcastGradientArgsOp::fold(
734 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
735 SmallVector<int64_t, 4> s0_shape, s1_shape;
736 DenseIntElementsAttr s0, s1;
737 if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape))
738 return failure();
739
740 // Fold BroadcastGradientArgs into two constants if both of the inputs have
741 // known shape.
742 SmallVector<int64_t, 4> bcasted_shape;
743 // Verifier should already ensure the broadcast compatibility.
744 bool bcast_compatible =
745 OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape);
746 assert(bcast_compatible);
747 (void)bcast_compatible;
748
749 SmallVector<int64_t, 4> r0, r1;
750 GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
751 r1);
752
753 auto build_out_dense_element = [](SmallVectorImpl<int64_t> &shape,
754 Type input_type) {
755 Type element_type = input_type.cast<mlir::TensorType>().getElementType();
756 RankedTensorType type = RankedTensorType::get(
757 {static_cast<int64_t>(shape.size())}, element_type);
758 // Input could only be i32 or i64. For i32, downcast to int32_t array.
759 if (element_type.isInteger(32)) {
760 SmallVector<int32_t, 4> i32_shape;
761 for (auto s : shape) i32_shape.push_back(static_cast<int32_t>(s));
762 return DenseIntElementsAttr::get(type, i32_shape);
763 } else {
764 assert(element_type.isInteger(64));
765 return DenseIntElementsAttr::get(type, shape);
766 }
767 };
768
769 results.push_back(build_out_dense_element(r0, this->s0().getType()));
770 results.push_back(build_out_dense_element(r1, this->s1().getType()));
771
772 return success();
773 }
774
775 //===----------------------------------------------------------------------===//
776 // CaseOp
777 //===----------------------------------------------------------------------===//
778
779 class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> {
780 public:
FoldConstantCaseOp(MLIRContext * context)781 explicit FoldConstantCaseOp(MLIRContext *context)
782 : OpRewritePattern<TF::CaseOp>(context) {}
783 LogicalResult matchAndRewrite(TF::CaseOp op,
784 PatternRewriter &rewriter) const override;
785 };
786
matchAndRewrite(TF::CaseOp op,PatternRewriter & rewriter) const787 LogicalResult FoldConstantCaseOp::matchAndRewrite(
788 TF::CaseOp op, PatternRewriter &rewriter) const {
789 // Extract the constant cond value.
790 DenseIntElementsAttr branch;
791 if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();
792
793 int index = *branch.getValues<int>().begin();
794 if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;
795
796 auto func = op.branches()[index].cast<SymbolRefAttr>();
797 auto empty = rewriter.getStringAttr("");
798 ReplaceTfOpWithNewOp<PartitionedCallOp>(
799 rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func,
800 /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
801 return success();
802 }
803
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)804 void CaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
805 MLIRContext *context) {
806 results.add<FoldConstantCaseOp, DropAttributes<CaseOp>>(context);
807 }
808
VerifyCaseOpBase(Operation * op,Value branch_index)809 static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) {
810 if (!IsOfRankOrUnranked(branch_index, 0))
811 return op->emitOpError()
812 << "expects 'branch_index' to be a scalar, but got "
813 << branch_index.getType();
814 return success();
815 }
816
VerifyCaseOrIfOpBranchFunctions(SymbolTableCollection & symbol_table,Operation * op,ArrayRef<Attribute> branches,llvm::function_ref<std::string (unsigned branch_index)> branch_name)817 static LogicalResult VerifyCaseOrIfOpBranchFunctions(
818 SymbolTableCollection &symbol_table, Operation *op,
819 ArrayRef<Attribute> branches,
820 llvm::function_ref<std::string(unsigned branch_index)> branch_name) {
821 SmallVector<FunctionType, 2> branch_types;
822 branch_types.reserve(branches.size());
823
824 if (llvm::any_of(op->getOperands(),
825 [](Value value) { return value == nullptr; }))
826 return op->emitOpError("operation has null operand");
827
828 // Functions have one less operand compared to op as first operand is elided
829 // (`cond` of `tf.If` and `branch_index` of `tf.Case`).
830 TypeRangeWithDesc input{op->getOperands().drop_front().getTypes(), "input"};
831 TypeRangeWithDesc result{op->getResultTypes(), "result"};
832
833 for (auto branch : llvm::enumerate(branches)) {
834 auto branch_func = symbol_table.lookupNearestSymbolFrom<func::FuncOp>(
835 op, branch.value().cast<SymbolRefAttr>());
836 if (!branch_func)
837 return op->emitOpError()
838 << "expects " << branch_name(branch.index()) << " ("
839 << branch.value() << ") to point to a defined function";
840
841 FunctionType branch_type = branch_func.getFunctionType();
842 std::string desc = branch_name(branch.index()) + " input";
843 TypeRangeWithDesc branch_input{branch_type.getInputs(), desc};
844 if (failed(VerifyTypeRangesAreCompatible(op, branch_input, input)))
845 return failure();
846
847 desc = branch_name(branch.index()) + " result";
848 TypeRangeWithDesc branch_result{branch_type.getResults(), desc};
849 if (failed(VerifyTypeRangesAreCompatible(op, branch_result, result)))
850 return failure();
851
852 branch_types.push_back(branch_type);
853 }
854
855 // If branches have incompatible input types that means that no tensor can
856 // serve as input to all the functions. Hence, the op is invalid.
857 int expected_num_inputs = op->getNumOperands() - 1;
858 for (int i = 0; i < expected_num_inputs; ++i) {
859 SmallVector<Type, 2> branch_input_i_types;
860 branch_input_i_types.reserve(branches.size());
861 llvm::transform(
862 branch_types, std::back_inserter(branch_input_i_types),
863 [i](FunctionType &branch_type) { return branch_type.getInput(i); });
864 if (!AreCastCompatible(branch_input_i_types)) {
865 std::string input_types_str;
866 llvm::raw_string_ostream os(input_types_str);
867 llvm::interleaveComma(branch_input_i_types, os);
868 return op->emitOpError()
869 << "expects all branch input type(s) (" << os.str()
870 << ") at index " << i << " to be cast compatible";
871 }
872 }
873
874 return success();
875 }
876
verify()877 LogicalResult CaseOp::verify() {
878 CaseOp op = *this;
879 return VerifyCaseOpBase(op, op.branch_index());
880 }
881
verifySymbolUses(SymbolTableCollection & symbol_table)882 LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
883 auto branch_name = [](unsigned index) {
884 return llvm::formatv("branch #{0}", index).str();
885 };
886 return VerifyCaseOrIfOpBranchFunctions(symbol_table, *this,
887 branches().getValue(), branch_name);
888 }
889
890 //===----------------------------------------------------------------------===//
891 // CaseRegionOp
892 //===----------------------------------------------------------------------===//
893
verify()894 LogicalResult CaseRegionOp::verify() {
895 CaseRegionOp op = *this;
896 if (op.branches().empty())
897 return op.emitOpError() << "expects to have at least 1 region";
898
899 if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();
900
901 TypeRangeWithDesc results{op.getResultTypes(), "result"};
902
903 for (auto region_and_idx : llvm::enumerate(op.branches())) {
904 std::string description =
905 llvm::formatv("branch #{0} result", region_and_idx.index()).str();
906 Operation *yield = region_and_idx.value().front().getTerminator();
907 TypeRangeWithDesc branch_results{yield->getOperandTypes(), description};
908 if (failed(VerifyTypeRangesAreCompatible(op, branch_results, results)))
909 return failure();
910 }
911
912 return success();
913 }
914
915 namespace {
916 // Eliminate values that pass through the CaseRegionOp or IfRegionOp branches.
917 template <class CaseOrIfRegionOp>
918 class CaseOrIfRegionEliminatePassThrough
919 : public OpRewritePattern<CaseOrIfRegionOp> {
920 using OpRewritePattern<CaseOrIfRegionOp>::OpRewritePattern;
921
matchAndRewrite(CaseOrIfRegionOp op,PatternRewriter & rewriter) const922 LogicalResult matchAndRewrite(CaseOrIfRegionOp op,
923 PatternRewriter &rewriter) const override {
924 RegionRange branches = op.getRegions();
925 SmallVector<Type, 4> new_result_types;
926 // Maps pass through results to extern values.
927 llvm::SmallDenseMap<Value, Value, 4> result_to_extern_value;
928
929 for (auto result : op.getResults()) {
930 unsigned index = result.getResultNumber();
931 Region *first_branch = *branches.begin();
932 Operation *first_terminator = first_branch->front().getTerminator();
933 Value returned_val = first_terminator->getOperand(index);
934
935 // Pass through values would be defined outside the branch region. Keep
936 // the type of non pass through results to create a new op later, if
937 // required.
938 if (returned_val.getParentBlock() == &first_branch->front()) {
939 new_result_types.push_back(result.getType());
940 continue;
941 }
942 // Check if the same extern value is returned in each branch.
943 for (Region *region : branches.drop_front()) {
944 Operation *terminator = region->front().getTerminator();
945 if (terminator->getOperand(index) != returned_val) return failure();
946 }
947 result_to_extern_value[result] = returned_val;
948 }
949
950 // If no pass through values are found, no change is required.
951 if (result_to_extern_value.empty()) return failure();
952
953 // Create new case/if region op.
954 auto new_op = rewriter.create<CaseOrIfRegionOp>(
955 op.getLoc(), new_result_types, op.getOperand(), op->getAttrs(),
956 op.getNumRegions());
957
958 int next_index = 0;
959 for (auto result : op.getResults()) {
960 if (!result_to_extern_value.count(result)) {
961 result.replaceAllUsesWith(new_op.getResult(next_index++));
962 continue;
963 }
964 result.replaceAllUsesWith(result_to_extern_value[result]);
965 for (Region *branch : branches)
966 branch->front().getTerminator()->eraseOperand(next_index);
967 }
968
969 // Move region bodies to the new op.
970 for (auto region_index : llvm::seq<int>(0, branches.size()))
971 new_op.getRegion(region_index).takeBody(op.getRegion(region_index));
972
973 op.erase();
974 return success();
975 }
976 };
977 } // namespace
978
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)979 void CaseRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
980 MLIRContext *context) {
981 results.add<CaseOrIfRegionEliminatePassThrough<TF::CaseRegionOp>>(context);
982 }
983
984 //===----------------------------------------------------------------------===//
985 // CastOp
986 //===----------------------------------------------------------------------===//
987
fold(ArrayRef<Attribute> operands)988 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
989 // Cast with the same type is a no-op.
990 Value operand = getOperand();
991 if (getType() == operand.getType()) return operand;
992 return {};
993 }
994
995 //===----------------------------------------------------------------------===//
996 // ConcatOp and ConcatV2Op
997 //===----------------------------------------------------------------------===//
998
999 template <typename OpT,
1000 typename std::enable_if<llvm::is_one_of<
1001 OpT, ConcatOp, ConcatV2Op>::value>::type * = nullptr>
Verify(OpT op)1002 static LogicalResult Verify(OpT op) {
1003 // TODO(hinsu): Convert variadic length attributes to derived attributes.
1004 Operation::operand_range values = op.values();
1005
1006 int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
1007 Value axis = *op.getODSOperands(axis_idx).begin();
1008 if (!HasRankAtMost(axis, 1)) {
1009 return op.emitOpError(
1010 "requires axis to be of scalar type (or vector type for older "
1011 "versions)");
1012 }
1013
1014 return VerifyTypesCompatibility(values,
1015 /*mask_one_dim=*/true, op.getOperation());
1016 }
1017
verify()1018 LogicalResult ConcatOp::verify() { return Verify(*this); }
verify()1019 LogicalResult ConcatV2Op::verify() { return Verify(*this); }
1020
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1021 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
1022 MLIRContext *context) {
1023 results.add<ConvertToConcatV2>(context);
1024 }
1025
1026 namespace {
1027
1028 // Hoist coefficient-wise unary operation out of the Concat op:
1029 //
1030 // %0 = "tf.Log1p"(%arg_0)
1031 // %1 = "tf.Log1p"(%arg_1)
1032 // ...
1033 // %n = "tf.Log1p"(%arg_n)
1034 // %m = "tf.ConcatV2"(%0, %1, ..., %n, %axis)
1035 //
1036 // Rewrite it to:
1037 //
1038 // %0 = "tf.ConcatV2"(%arg_0, %arg_1, ..., %arg_n, %axis)
1039 // %1 = "tf.Log1p"(%0)
1040 class HoistCwiseUnaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
1041 public:
HoistCwiseUnaryOutOfConcat(MLIRContext * context)1042 explicit HoistCwiseUnaryOutOfConcat(MLIRContext *context)
1043 : OpRewritePattern<TF::ConcatV2Op>(context) {}
1044 LogicalResult matchAndRewrite(TF::ConcatV2Op op,
1045 PatternRewriter &rewriter) const override;
1046 };
1047
matchAndRewrite(TF::ConcatV2Op op,PatternRewriter & rewriter) const1048 LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
1049 TF::ConcatV2Op op, PatternRewriter &rewriter) const {
1050 auto loc = op.getLoc();
1051
1052 // All concat operands must be defined by ops.
1053 Operation *first_arg_op = op.values().front().getDefiningOp();
1054 if (first_arg_op == nullptr) return failure();
1055
1056 // All concat operands must be produced by the coeff-wise unary operation.
1057 if (!first_arg_op->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
1058
1059 // All concat operands must be defined by the op of same kind.
1060 bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
1061 Operation *arg_op = arg.getDefiningOp();
1062 return arg_op && arg_op->getName() == first_arg_op->getName();
1063 });
1064 if (!args_same_op) return failure();
1065
1066 // Collect unary operations operands.
1067 auto unary_operands = llvm::map_range(op.values(), [](Value arg) -> Value {
1068 return arg.getDefiningOp()->getOperand(0);
1069 });
1070 SmallVector<Value, 8> unary_ops_args(unary_operands);
1071
1072 // Concatenate unary ops operands.
1073 auto concat_unary_operands =
1074 rewriter.create<ConcatV2Op>(loc, op.getType(), unary_ops_args, op.axis());
1075
1076 // Replace original concat with an unary op.
1077 OperationState new_unary_op_state(loc, first_arg_op->getName().getStringRef(),
1078 concat_unary_operands.getResult(),
1079 op.getResult().getType(),
1080 ArrayRef<NamedAttribute>());
1081 Operation *new_unary_op = rewriter.create(new_unary_op_state);
1082
1083 rewriter.replaceOp(op, new_unary_op->getResults());
1084
1085 return success();
1086 }
1087
1088 // Hoist coefficient-wise binary operation out of the Concat op:
1089 //
1090 // %0 = tf.Mul(%lhs_0, %rhs_0)
1091 // %1 = tf.Mul(%lhs_1, %rhs_1)
1092 // ...
1093 // %n = tf.Mul(%lhs_n, %rhs_n)
1094 // %m = tf.ConcatV2(%0, %1, ..., %n, %axis)
1095 //
1096 // Rewrite it to:
1097 //
1098 // %0 = tf.ConcatV2(%lhs0, %lhs1, ..., %lhs_n, %lhs_concat_axis)
1099 // %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
1100 // %2 = tf.Mul(%0, %1)
1101 //
1102 // If a minor fraction of the Concat inputs are not of the same binary op kind
1103 // (tf.Mul in the above example), we will synthesize the binary ops for those
1104 // inputs. e.g. if we instead have %1 = %lhs_1, then we would synthesize a
1105 // tf.Mul op over it and a scalar const tensor 1.0. For now this only applies to
1106 // float32 tensors.
1107 // TODO(hongm): Implement this op synthesis optimization for other dtypes if
1108 // needed.
1109 //
1110 // Because coefficient-wise binary operations support implicit broadcasting, we
1111 // should be very careful with this optimization, and do not accidentally
1112 // produce incorrect concat operations.
1113 class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
1114 public:
HoistCwiseBinaryOutOfConcat(MLIRContext * context)1115 explicit HoistCwiseBinaryOutOfConcat(MLIRContext *context)
1116 : OpRewritePattern<TF::ConcatV2Op>(context) {}
1117 LogicalResult matchAndRewrite(TF::ConcatV2Op op,
1118 PatternRewriter &rewriter) const override;
1119
1120 private:
1121 struct HoistParams {
1122 SmallVector<Value, 8> lhs_args;
1123 SmallVector<Value, 8> rhs_args;
1124 int64_t lhs_axis;
1125 int64_t rhs_axis;
1126 Type lhs_concat_type;
1127 Type rhs_concat_type;
1128 int scalar_operand_idx; // can be 0 or 1 for the binary op's operands.
1129 };
1130
1131 // Returns parameters of a binary op hoisting out of concatenation if all of
1132 // the operands are in one of the compatible configurations.
1133 // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
1134 // except from the ones in `exceptions`. In that case, we can synthesize that
1135 // binary op kind for the values in `exceptions`.
1136 Optional<HoistParams> GetHoistParams(
1137 TF::ConcatV2Op op, int64_t axis,
1138 const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
1139 };
1140
matchAndRewrite(TF::ConcatV2Op op,PatternRewriter & rewriter) const1141 LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
1142 TF::ConcatV2Op op, PatternRewriter &rewriter) const {
1143 auto loc = op.getLoc();
1144
1145 // Axis must be a constant scalar value.
1146 DenseIntElementsAttr axis_attr;
1147 if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure();
1148 if (axis_attr.getNumElements() != 1) return failure();
1149 int64_t axis =
1150 axis_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
1151 // TODO(ezhulenev): Compute axis from rank. e.g. It might be common to concat
1152 // on the channels dim for NCHW layout as axis=-2.
1153 if (axis < 0) return failure();
1154
1155 // All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
1156 // or some other ops that we might convert to using the same op kind above
1157 // (e.g. converting op A to tf.Mul(A, 1.0))
1158 // TODO(hongm): generalize the code here to support cases where the first arg
1159 // has no defining op (e.g. might be a block arg).
1160 Operation *first_arg_op = op.values().front().getDefiningOp();
1161 if (first_arg_op == nullptr) return failure();
1162
1163 // All concat operands must be produced by the coeff-wise binary operation.
1164 if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();
1165
1166 // All concat operands must be defined by the op of same kind, except for a
1167 // minor portion which we track in `exceptions`.
1168 // Map from the operands to operand indices.
1169 llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
1170 unsigned operand_idx = 0;
1171 for (Value arg : op.values()) {
1172 Operation *arg_op = arg.getDefiningOp();
1173 if (arg_op && arg_op->getName() == first_arg_op->getName()) {
1174 ++operand_idx;
1175 continue;
1176 }
1177 exceptions[arg] = operand_idx++;
1178 }
1179 // Recall those inputs to the concat op that are not produced by a binary op
1180 // of the `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If
1181 // there are too many exceptions, it might not be cost effective to apply the
1182 // concat hoisting optimization here.
1183 // Setting the threshold to be 50% as a simple cost model heuristic. e.g. If 1
1184 // out of 2 concat inputs is an exception, we don't apply the hoist. If it's 1
1185 // out of 3, we do.
1186 const float exception_pct_threshold = 0.5;
1187 if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
1188 exceptions.size())
1189 return failure();
1190
1191 // Compute binary operands hoist parameters.
1192 auto hoist_params = GetHoistParams(op, axis, exceptions);
1193 if (!hoist_params.has_value()) return failure();
1194
1195 // Process `exceptions`: For each value there, synthesize a binary op of the
1196 // above kind, so that the concat hoisting optimization can still apply.
1197 if (!exceptions.empty()) {
1198 int identity_val;
1199 if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op))
1200 identity_val = 0;
1201 else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) ||
1202 isa<RealDivOp>(first_arg_op))
1203 identity_val = 1;
1204 else
1205 return failure();
1206 DenseElementsAttr const_attr;
1207 auto scalar_tensor_type =
1208 first_arg_op->getOperand(hoist_params->scalar_operand_idx)
1209 .getType()
1210 .dyn_cast<ShapedType>();
1211 Type scalar_dtype = scalar_tensor_type.getElementType();
1212 if (scalar_dtype.isa<FloatType>())
1213 const_attr = DenseElementsAttr::get(scalar_tensor_type,
1214 static_cast<float>(identity_val));
1215 else
1216 return failure();
1217
1218 // All checks are passes, and we now prepare for rewrite.
1219 auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
1220 for (const auto &kv : exceptions) {
1221 assert(!hoist_params->lhs_args[kv.second]);
1222 assert(!hoist_params->rhs_args[kv.second]);
1223
1224 if (hoist_params->scalar_operand_idx == 1) {
1225 hoist_params->lhs_args[kv.second] = kv.first;
1226 hoist_params->rhs_args[kv.second] = identity_const;
1227 } else {
1228 assert(hoist_params->scalar_operand_idx == 0);
1229 hoist_params->lhs_args[kv.second] = identity_const;
1230 hoist_params->rhs_args[kv.second] = kv.first;
1231 }
1232 }
1233 }
1234
1235 // Concatenates `args` along `axis`.
1236 auto pack_or_concat = [&](bool is_scalar, Type result_type, ValueRange args,
1237 int64_t axis) {
1238 // Use `PackOp` for scalar concatenation because `ConcatV2Op` doesn't
1239 // support scalar concatenation.
1240 if (is_scalar) {
1241 auto pack = rewriter.create<PackOp>(loc, result_type, args,
1242 rewriter.getI64IntegerAttr(axis));
1243 return pack.getResult();
1244 }
1245
1246 // New concatenation axis.
1247 auto axis_type = RankedTensorType::get({}, getElementTypeOrSelf(axis_attr));
1248 DenseIntElementsAttr attr;
1249 if (axis_type.getElementType().isInteger(32)) {
1250 attr = DenseIntElementsAttr::get(axis_type, static_cast<int32_t>(axis));
1251 } else {
1252 assert(axis_type.getElementType().isInteger(64));
1253 attr = DenseIntElementsAttr::get(axis_type, axis);
1254 }
1255 auto axis_const = rewriter.create<TF::ConstOp>(loc, attr);
1256
1257 auto concat =
1258 rewriter.create<ConcatV2Op>(loc, result_type, args, axis_const);
1259 return concat.getResult();
1260 };
1261
1262 // Concatenate binary ops operands on the new axis.
1263 Value lhs_concat = pack_or_concat(
1264 hoist_params->scalar_operand_idx == 0, hoist_params->lhs_concat_type,
1265 hoist_params->lhs_args, hoist_params->lhs_axis);
1266 Value rhs_concat = pack_or_concat(
1267 hoist_params->scalar_operand_idx == 1, hoist_params->rhs_concat_type,
1268 hoist_params->rhs_args, hoist_params->rhs_axis);
1269
1270 // Replace original concat with a binary op.
1271 OperationState new_binary_op_state(
1272 loc, first_arg_op->getName().getStringRef(), {lhs_concat, rhs_concat},
1273 op.getResult().getType(), ArrayRef<NamedAttribute>());
1274 Operation *new_binary_op = rewriter.create(new_binary_op_state);
1275
1276 rewriter.replaceOp(op, new_binary_op->getResults());
1277
1278 return success();
1279 }
1280
1281 Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
GetHoistParams(TF::ConcatV2Op op,int64_t axis,const llvm::SmallDenseMap<Value,unsigned,4> & exceptions) const1282 HoistCwiseBinaryOutOfConcat::GetHoistParams(
1283 TF::ConcatV2Op op, int64_t axis,
1284 const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
1285 assert(axis >= 0);
1286 // Collects lhs or rhs arguments of concat op operands.
1287 auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
1288 auto range = llvm::map_range(op.values(), [&](Value arg) {
1289 if (exceptions.count(arg)) return Value();
1290 return arg.getDefiningOp()->getOperand(operand_idx);
1291 });
1292 return {range.begin(), range.end()};
1293 };
1294
1295 // Returns true if all binary ops operands at `operand_idx` index are tensors
1296 // of `axis + 1` rank and axis dim has size `1`.
1297 auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
1298 return llvm::all_of(op.values(), [&](Value arg) -> bool {
1299 mlir::Value operand;
1300 if (exceptions.count(arg)) {
1301 // For exceptions, since we are going to synthesize a binary op that
1302 // produce the identity value, it is also required that it is a ranked
1303 // tensor with rank = `axis + 1` and axis dim has size `1`.
1304 operand = arg;
1305 } else {
1306 operand = arg.getDefiningOp()->getOperand(operand_idx);
1307 }
1308 auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1309 return ranked && ranked.getRank() == (axis + 1) &&
1310 ranked.getShape()[axis] == 1;
1311 });
1312 };
1313
1314 // Returns true if all binary ops operands at `operand_idx` index are scalars.
1315 auto is_all_scalars = [&](int operand_idx) -> bool {
1316 return llvm::all_of(op.values(), [&](Value arg) -> bool {
1317 if (exceptions.count(arg)) return true;
1318 auto operand = arg.getDefiningOp()->getOperand(operand_idx);
1319 auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1320 return ranked && ranked.hasRank() && ranked.getRank() == 0;
1321 });
1322 };
1323
1324 // Concat result type must be a ranked tensor.
1325 auto ranked = op.getType().dyn_cast<RankedTensorType>();
1326 if (!ranked) return None;
1327
1328 // TODO(ezhulenev): Add support for more valid concat patterns.
1329
1330 // Tensor + Scalar: [..., 1] + [] <- scalar
1331 // ^
1332 // \- axis is the innermost dimension.
1333 //
1334 // Concatenate tensor arguments on the same axis as the original operation,
1335 // and concatenate scalars into the vector.
1336 if (is_all_tensors(0, axis) && is_all_scalars(1)) {
1337 std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
1338 auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
1339 return HoistParams{args(0),
1340 args(1),
1341 axis,
1342 0,
1343 op.getType(),
1344 rhs_type,
1345 /*scalar_operand_idx=*/1};
1346 } else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
1347 std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
1348 auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
1349 return HoistParams{args(0),
1350 args(1),
1351 0,
1352 axis,
1353 lhs_type,
1354 op.getType(),
1355 /*scalar_operand_idx=*/0};
1356 }
1357 return None;
1358 }
1359
1360 } // namespace
1361
1362 void ConcatV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
1363 MLIRContext *context) {
1364 results.add<HoistCwiseBinaryOutOfConcat, HoistCwiseUnaryOutOfConcat>(context);
1365 }
1366
1367 //===----------------------------------------------------------------------===//
1368 // CumsumOp and CumprodOp
1369 //===----------------------------------------------------------------------===//
1370
1371 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1372 OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
1373 static LogicalResult Verify(OpT op) {
1374 if (!IsOfRankOrUnranked(op.axis(), 0))
1375 return op.emitOpError("requires scalar axis operand");
1376
1377 DenseIntElementsAttr axis_attr;
1378 if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
1379 auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
1380 if (input_ty) {
1381 int64_t rank = input_ty.getRank();
1382 assert(axis_attr.getNumElements() == 1 &&
1383 "scalar attribute should have exactly one element");
1384 int64_t axis = (*axis_attr.begin()).getSExtValue();
1385 if (axis < -rank || axis >= rank) {
1386 return op.emitError()
1387 << "axis operand should be within range [" << -rank << ", "
1388 << rank << "); actual value: " << axis;
1389 }
1390 }
1391 }
1392
1393 return success();
1394 }
1395 LogicalResult CumprodOp::verify() { return Verify(*this); }
1396 LogicalResult CumsumOp::verify() { return Verify(*this); }
1397
1398 //===----------------------------------------------------------------------===//
1399 // ConcatOffsetOp
1400 //===----------------------------------------------------------------------===//
1401
1402 LogicalResult ConcatOffsetOp::verify() {
1403 ConcatOffsetOp op = *this;
1404 if (op.N() < 2)
1405 return op.emitOpError() << "requires N to be at least 2, got " << op.N();
1406
1407 if (op.shape().size() != op.offset().size())
1408 return op.emitOpError()
1409 << "requires sizes of shapes and offsets to be the same, got sizes "
1410 << op.shape().size() << " and " << op.offset().size();
1411
1412 auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>();
1413 if (ranked_dim && ranked_dim.getRank() != 0)
1414 return op.emitOpError()
1415 << "requires concat_dim to be a scalar, got tensor of rank "
1416 << ranked_dim.getRank();
1417
1418 int64_t num_dims = -1;
1419 for (auto shape_offset_idx :
1420 llvm::enumerate(llvm::zip(op.shape(), op.offset()))) {
1421 Value shape = std::get<0>(shape_offset_idx.value());
1422 Value offset = std::get<1>(shape_offset_idx.value());
1423 const size_t idx = shape_offset_idx.index();
1424
1425 if (failed(verifyCompatibleShape(shape.getType(), offset.getType())))
1426 return op.emitOpError() << "requires operand and result " << idx
1427 << " to have compatible shapes";
1428
1429 auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
1430 if (!ranked_shape) continue;
1431
1432 if (ranked_shape.getRank() != 1)
1433 return op.emitOpError() << "requires shape tensor operand " << idx
1434 << " to be of rank 1, got tensor of rank "
1435 << ranked_shape.getRank();
1436
1437 if (!ranked_shape.hasStaticShape()) continue;
1438
1439 int64_t ranked_shape_dim = ranked_shape.getDimSize(0);
1440 if (num_dims == -1)
1441 num_dims = ranked_shape_dim;
1442 else if (ranked_shape_dim != num_dims)
1443 return op.emitOpError()
1444 << "requires shape tensor (rank 1) operand " << idx
1445 << " to be of length " << num_dims
1446 << ", got tensor (rank 1) of length " << ranked_shape_dim;
1447 }
1448
1449 return success();
1450 }
1451
1452 LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
1453 SmallVectorImpl<OpFoldResult> &results) {
1454 // ConcatOffset must have its first operand be concat_dim and at least two
1455 // shape tensors in variadic shapes operand.
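// As a worked example (illustrative only): with concat_dim = 1 and shape
// operands [2, 3] and [2, 7], the folded offsets are [0, 0] and [0, 3],
// i.e. an exclusive cumulative sum of the sizes along concat_dim.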
1456 if (operands.size() < 3) return failure();
1457
1458 // Check concat_dim is a scalar.
1459 auto concat_dim_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1460 if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0)
1461 return failure();
1462
1463 llvm::SmallVector<DenseIntElementsAttr, 4> shapes;
1464 shapes.reserve(operands.size() - 1);
1465 for (Attribute shape : llvm::drop_begin(operands, 1))
1466 if (auto shape_attr = shape.dyn_cast_or_null<DenseIntElementsAttr>())
1467 shapes.push_back(shape_attr);
1468 else
1469 return failure();
1470
1471 // Check all shapes are vectors of the same length.
1472 if (shapes.front().getType().getRank() != 1) return failure();
1473 const int64_t num_dims = shapes.front().getNumElements();
1474 for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1))
1475 if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims)
1476 return failure();
1477
1478 // Check concat_dim is within [-num_dims, num_dims).
1479 int32_t concat_dim = (*concat_dim_attr.getValues<int32_t>().begin());
1480 if (concat_dim < 0) concat_dim += num_dims;
1481 if (concat_dim >= num_dims || concat_dim < 0) return failure();
1482
1483 // Check that all elements except the one at concat_dim match across all shape tensors.
1484 SmallVector<int32_t, 4> shape0;
1485 shape0.reserve(num_dims);
1486 for (int32_t dim : shapes.front().getValues<int32_t>()) shape0.push_back(dim);
1487
1488 for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) {
1489 for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) {
1490 if (dims_and_idx.index() == concat_dim) continue;
1491
1492 if (std::get<0>(dims_and_idx.value()) !=
1493 std::get<1>(dims_and_idx.value()).getSExtValue())
1494 return failure();
1495 }
1496 }
1497
1498 // Compute an exclusive cumulative sum of elements at concat_dim.
1499 results.reserve(shapes.size());
1500 SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
1501 RankedTensorType offset_type =
1502 RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
1503 for (DenseIntElementsAttr shape : shapes) {
1504 results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
1505 cumulative_sum[concat_dim] += shape.getValues<int32_t>()[concat_dim];
1506 }
1507
1508 return success();
1509 }
1510
1511 //===----------------------------------------------------------------------===//
1512 // ConstOp
1513 //===----------------------------------------------------------------------===//
1514
1515 void ConstOp::getAsmResultNames(
1516 function_ref<void(Value, StringRef)> setNameFn) {
1517 setNameFn(getResult(), "cst");
1518 }
1519
1520 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
1521 assert(operands.empty() && "constant has no operands");
1522
1523 // Return the held attribute value.
1524 return value();
1525 }
1526
1527 // Builds a constant op with the specified attribute `value`. The result
1528 // op's type is deduced from `value`; if `value` is of scalar type, it is
1529 // wrapped in a tensor type with an empty shape.
1530 // TODO(jpienaar): This one differs from the autogenerated one as it takes an
1531 // attribute but always creates an ElementsAttr internally.
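// For example (illustrative): building with an IntegerAttr of type i32 and
// value 1 produces a tf.Const holding dense<1> : tensor<i32>.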
1532 void ConstOp::build(OpBuilder &builder, OperationState &result,
1533 Attribute value) {
1534 ShapedType type;
1535 if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1536 return ConstOp::build(builder, result, elem_attr);
1537 } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
1538 // All TensorFlow types must be tensor types. In the build() method,
1539 // we want to provide more flexibility by allowing attributes of scalar
1540 // types. But such scalar attributes need to be wrapped in an ElementsAttr
1541 // to construct valid TensorFlow constants.
1542 auto typed_attr = value.cast<TypedAttr>();
1543 type = RankedTensorType::get(/*shape=*/{}, typed_attr.getType());
1544 return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
1545 }
1546 // TODO(jpienaar): support other TensorFlow specific types.
1547 llvm_unreachable("unsupported attribute type for building tf.Const");
1548 }
1549
1550 void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
1551 Attribute value) {
1552 // Handle the case where the type and value are already tensors.
1553 if (type.isa<TensorType>() && value.isa<ElementsAttr>()) {
1554 result.addTypes(type);
1555 result.addAttribute("value", value);
1556 return;
1557 }
1558
1559 // Otherwise, default to the attribute builder.
1560 ConstOp::build(builder, result, value);
1561 assert(type == result.types[0] && "type mismatch in construction");
1562 }
1563
1564 LogicalResult ConstOp::inferReturnTypes(
1565 MLIRContext *context, Optional<Location> location, ValueRange operands,
1566 DictionaryAttr attributes, RegionRange regions,
1567 SmallVectorImpl<Type> &inferredReturnTypes) {
1568 auto value = attributes.get("value");
1569 if (!value) return emitOptionalError(location, "missing attribute 'value'");
1570 if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1571 inferredReturnTypes.assign({elem_attr.getType()});
1572 return success();
1573 }
1574 return emitOptionalError(location,
1575 "attribute 'value' failed to satisfy constraint: "
1576 "constant vector/tensor");
1577 }
1578
1579 //===----------------------------------------------------------------------===//
1580 // Conv2DOp and Conv3DOp
1581 //===----------------------------------------------------------------------===//
1582
1583 static LogicalResult VerifyConvOpAttributes(
1584 int num_dims, ArrayRef<Attribute> strides, ArrayRef<Attribute> dilations,
1585 llvm::Optional<mlir::Location> location) {
1586 int64_t strides_size = strides.size();
1587 if (strides_size != num_dims)
1588 return emitOptionalError(
1589 location, "requires strides attribute length to be ", num_dims);
1590 auto is_not_positive = [](Attribute val) {
1591 return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0;
1592 };
1593 if (llvm::any_of(strides, is_not_positive))
1594 return emitOptionalError(location, "requires positive strides");
1595
1596 int64_t dilations_size = dilations.size();
1597 if (dilations_size != num_dims)
1598 return emitOptionalError(
1599 location, "requires dilations attribute length to be ", num_dims);
1600 if (llvm::any_of(dilations, is_not_positive))
1601 return emitOptionalError(location, "requires positive dilations");
1602
1603 return success();
1604 }
1605
1606 // Verifies that:
1607 // * the number of input channels is divisible by the number of filter input
1608 // channels.
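// For example (illustrative): a grouped Conv2D with 8 input channels and a
// filter with 2 input channels passes this check (8 % 2 == 0), while 8 and 3
// input/filter channels would be rejected.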
1609 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1610 OpT, Conv2DOp, Conv3DOp>::value>::type * = nullptr>
1611 static LogicalResult Verify(OpT op) {
1612 int num_spatial_dims = std::is_same<OpT, Conv2DOp>() ? 2 : 3;
1613 int num_dims = 2 + num_spatial_dims;
1614
1615 StringRef data_format = op.data_format();
1616 tensorflow::TensorFormat format;
1617 auto data_format_is_valid = FormatFromString(data_format.str(), &format);
1618 if (!data_format_is_valid) {
1619 return emitOptionalError(op.getLoc(), "Invalid data format provided");
1620 }
1621
1622 const StringRef paddings = op.padding();
1623 tensorflow::Padding padding;
1624 auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1625 if (!padding_is_valid.ok()) {
1626 return emitOptionalError(op.getLoc(), "Invalid padding format provided");
1627 }
1628
1629 // Verifies that:
1630 // * ranks of operands and result are valid,
1631 // * length of the explicit_paddings attribute is valid and its elements are
1632 // non-negative,
1633 // * strides and dilations attributes have positive elements
1634 if (!IsOfRankOrUnranked(op.input(), num_dims) ||
1635 !IsOfRankOrUnranked(op.filter(), num_dims))
1636 return emitOptionalError(op.getLoc(), "requires operands to be ", num_dims,
1637 "D tensor");
1638
1639 if (padding == tensorflow::Padding::EXPLICIT) {
1640 ArrayRef<Attribute> explicit_padding;
1641 ArrayAttr explicit_pad =
1642 op->getAttr("explicit_paddings")
1643 .template dyn_cast_or_null<::mlir::ArrayAttr>();
1644 if (!explicit_pad) {
1645 explicit_pad = ::mlir::Builder(op->getContext()).getI64ArrayAttr({});
1646 }
1647 explicit_padding = explicit_pad.getValue();
1648
1649 if (explicit_padding.empty()) {
1650 return emitOptionalError(op.getLoc(),
1651 "requires attribute 'explicit_paddings' with "
1652 "'EXPLICIT' padding mode");
1653 }
1654 if (explicit_padding.size() != num_dims * 2) {
1655 return emitOptionalError(
1656 op.getLoc(), "requires explicit_paddings attribute length to be ",
1657 num_dims * 2);
1658 }
1659 auto is_negative = [](Attribute val) {
1660 return val.cast<IntegerAttr>().getValue().getSExtValue() < 0;
1661 };
1662 if (llvm::any_of(explicit_padding, is_negative))
1663 return emitOptionalError(op.getLoc(),
1664 "requires non negative explicit paddings");
1665 }
1666
1667 ArrayRef<Attribute> strides = op.strides().getValue();
1668 ArrayRef<Attribute> dilations = op.dilations().getValue();
1669 if (failed(
1670 VerifyConvOpAttributes(num_dims, strides, dilations, op.getLoc()))) {
1671 return failure();
1672 }
1673
1674 int64_t input_channels = -1;
1675 if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
1676 absl::string_view data_format(op.data_format().data(),
1677 op.data_format().size());
1678 tensorflow::TensorFormat format;
1679 auto is_valid = FormatFromString(data_format, &format);
1680 DCHECK(is_valid) << data_format;
1681 int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format);
1682 input_channels = ty.getDimSize(idx);
1683 }
1684
1685 int64_t filter_channels = -1;
1686 if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) {
1687 int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
1688 num_dims, tensorflow::FORMAT_HWIO);
1689 filter_channels = ty.getDimSize(idx);
1690 }
1691
1692 if (input_channels != -1 && filter_channels != -1 &&
1693 input_channels % filter_channels != 0)
1694 return op.emitOpError()
1695 << "requires the number of input channels to be divisible by the "
1696 "number of filter input channels; found "
1697 << input_channels << " and " << filter_channels << ", respectively";
1698
1699 return success();
1700 }
1701
1702 LogicalResult Conv2DOp::verify() { return Verify(*this); }
1703 LogicalResult Conv3DOp::verify() { return Verify(*this); }
1704
1705 LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
1706 auto perm = GetDataFormatPermutation(this->data_format(), data_format);
1707 if (perm.empty()) return failure();
1708
1709 // Update data_format attribute and result types.
1710 if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1711
1712 // Update convolution attributes.
1713 (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1714 (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1715 (*this)->setAttr("explicit_paddings",
1716 ShuffleArrayAttr(explicit_paddings(), perm, 2));
1717
1718 return success();
1719 }
1720
1721 // Verifies the inferred return type of the given operation.
1722 template <typename OpT,
1723 typename std::enable_if<llvm::is_one_of<
1724 OpT, Conv2DOpAdaptor, Conv3DOpAdaptor>::value>::type * = nullptr>
1725 static LogicalResult inferConvReturnTypeComponents(
1726 llvm::Optional<mlir::Location> location, OpT op,
1727 ArrayRef<Attribute> explicit_padding,
1728 llvm::SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1729 const int64_t num_spatial_dims = std::is_same<OpT, Conv2DOpAdaptor>() ? 2 : 3;
1730 const int64_t num_dims = 2 + num_spatial_dims;
1731 const Value input = op.input();
1732 const Value filter = op.filter();
1733 const TensorType input_ty = input.getType().template cast<TensorType>();
1734 const TensorType filter_ty = filter.getType().template cast<TensorType>();
1735
1736 ArrayRef<Attribute> strides = op.strides().getValue();
1737 StringRef data_format = op.data_format();
1738 ArrayRef<Attribute> dilations = op.dilations().getValue();
1739
1740 tensorflow::TensorFormat format;
1741 auto data_format_is_valid = FormatFromString(data_format.str(), &format);
1742 assert(data_format_is_valid);
1743 (void)data_format_is_valid;
1744
1745 tensorflow::Padding padding;
1746 const StringRef paddings = op.padding();
1747 auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1748 assert(padding_is_valid.ok());
1749 (void)padding_is_valid;
1750
1751 auto get_int = [](Attribute attr) {
1752 return attr.template cast<IntegerAttr>().getInt();
1753 };
1754
1755 // The output always has rank `num_dims`. All dimensions are initialized to
1756 // dynamic size and can be partially inferred.
1757 SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize);
1758 // Output batch and channel dimension can be obtained using utilities from
1759 // tensorflow/core/util/tensor_format.h.
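// As an illustrative example (shapes are hypothetical): for an NHWC input of
// type tensor<1x32x32x3xf32>, an HWIO filter of type tensor<3x3x3x8xf32>,
// all-ones strides and dilations, and SAME padding, the inferred components
// correspond to tensor<1x32x32x8xf32>.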
1760 if (input_ty.hasRank()) {
1761 return_shape[GetTensorBatchDimIndex(num_dims, format)] =
1762 input_ty.getDimSize(GetTensorBatchDimIndex(num_dims, format));
1763 }
1764 if (filter_ty.hasRank()) {
1765 return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
1766 filter_ty.getDimSize(GetFilterTensorOutputChannelsDimIndex(
1767 num_dims, tensorflow::FORMAT_HWIO));
1768 }
1769 // Spatial dimensions can be inferred only when both input and filter are
1770 // ranked because we need to get their spatial dimensions.
1771 if (input_ty.hasRank() && filter_ty.hasRank()) {
1772 // Checks the size of each of the output spatial dimensions.
1773 for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1774 const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
1775 int64_t stride = get_int(strides[dim]);
1776 int64_t expected_output_size;
1777 int64_t pad_low;
1778 int64_t pad_high;
1779 // Retrieve padding, if defined explicitly.
1780 if (padding == tensorflow::Padding::EXPLICIT) {
1781 pad_low = get_int(explicit_padding[2 * dim]);
1782 pad_high = get_int(explicit_padding[2 * dim + 1]);
1783 }
1784 // Skip if input or filter size is dynamic.
1785 if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue;
1786 // Calculate the expected_output_size.
1787 tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1788 input_ty.getDimSize(dim), filter_ty.getDimSize(i),
1789 get_int(dilations[dim]), stride, padding, &expected_output_size,
1790 &pad_low, &pad_high);
1791 // Return failure if expected_output_size could not be calculated.
1792 if (!status.ok()) return failure();
1793 return_shape[dim] = expected_output_size;
1794 }
1795 }
1796
1797 inferredReturnShapes.emplace_back(return_shape, input_ty.getElementType());
1798 return success();
1799 }
1800
1801 LogicalResult Conv2DOp::inferReturnTypeComponents(
1802 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1803 DictionaryAttr attributes, RegionRange regions,
1804 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1805 Conv2DOpAdaptor op(operands.getValues(), attributes);
1806 ArrayRef<Attribute> explicit_padding;
1807 ArrayAttr explicit_pad =
1808 attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
1809 if (!explicit_pad) {
1810 explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
1811 }
1812 explicit_padding = explicit_pad.getValue();
1813
1814 return inferConvReturnTypeComponents(location, op, explicit_padding,
1815 inferredReturnShapes);
1816 }
1817
1818 StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
1819 // Keep the current data format if no GPUs are available or if explicit
1820 // placement does not allow using a GPU for this operation.
1821 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1822 return data_format();
1823
1824 // Input must be a tensor.
1825 auto input_ty = input().getType().dyn_cast<TensorType>();
1826 if (!input_ty) return data_format();
1827
1828 // For the f16 data type, the NHWC data format is up to ~2x faster on
1829 // devices with Tensor Core support.
1830 const bool is_f16 = input_ty.getElementType().isF16();
1831 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1832
1833 // For f32/f16 data types the decision depends on the filter size in the
1834 // spatial dimensions; for other data types we keep the current data format.
1835 if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16())
1836 return data_format();
1837
1838 // Keep current data format if filter rank is unknown or not equal to 4.
1839 auto filter_ty = filter().getType().dyn_cast<RankedTensorType>();
1840 if (!filter_ty || filter_ty.getRank() != 4) return data_format();
1841
1842 const int64_t d0 = filter_ty.getDimSize(0);
1843 const int64_t d1 = filter_ty.getDimSize(1);
1844
1845 auto all_ones = [](ArrayAttr arr) -> bool {
1846 return llvm::all_of(arr, [](Attribute attr) -> bool {
1847 return attr.cast<IntegerAttr>().getInt() == 1;
1848 });
1849 };
1850
1851 // Convolutions with a 1x1 filter and all-ones strides and dilations can be
1852 // computed as a GEMM in the NHWC data format, and can be up to ~2x faster
1853 // than the same convolution in NCHW.
1854 const bool one_by_one = d0 == 1 && d1 == 1;
1855 const bool trivial_strides = all_ones(strides());
1856 const bool trivial_dilations = all_ones(dilations());
1857
1858 // TODO(ezhulenev): This might lead to excessive transposes in the final IR,
1859 // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1.
1860 // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all
1861 // users can efficiently use NHWC data format?
1862 if (one_by_one && trivial_strides && trivial_dilations) {
1863 return "NHWC";
1864 }
1865
1866 // If filter spatial dimensions are unknown or not 1x1 we prefer NCHW, because
1867 // it's the fastest option on NVIDIA GPUs with cuDNN library support.
1868 return "NCHW";
1869 }
1870
1871 //===----------------------------------------------------------------------===//
1872 // Conv2dBackpropFilterOp
1873 //===----------------------------------------------------------------------===//
1874
1875 LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) {
1876 StringRef src_data_format = this->data_format();
1877
1878 auto perm = GetDataFormatPermutation(src_data_format, data_format);
1879 if (perm.empty()) return failure();
1880
1881 // Update data_format attribute and result types.
1882 if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1883
1884 // Update convolution attributes.
1885 (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1886 (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1887 (*this)->setAttr("explicit_paddings",
1888 ShuffleArrayAttr(explicit_paddings(), perm, 2));
1889
1890 // Permute filter sizes operand.
1891 OpBuilder builder(getOperation());
1892 auto filter_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1893 getLoc(), filter_sizes(), StringAttr::get(getContext(), src_data_format),
1894 StringAttr::get(getContext(), data_format));
1895 setOperand(1, filter_sizes_permuted);
1896
1897 return success();
1898 }
1899
1900 StringRef Conv2DBackpropFilterOp::GetOptimalLayout(
1901 const RuntimeDevices &devices) {
1902 // Keep the current data format if no GPUs are available or if explicit
1903 // placement does not allow using a GPU for this operation.
1904 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1905 return data_format();
1906
1907 // Input must be a tensor.
1908 auto input_ty = input().getType().dyn_cast<TensorType>();
1909 if (!input_ty) return data_format();
1910
1911 // For the f16 data type, the NHWC data format is up to ~2x faster on
1912 // devices with Tensor Core support.
1913 const bool is_f16 = input_ty.getElementType().isF16();
1914 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1915
1916 // Otherwise always use "NCHW".
1917 return "NCHW";
1918 }
1919
1920 //===----------------------------------------------------------------------===//
1921 // Conv2DBackpropInputOp
1922 //===----------------------------------------------------------------------===//
1923
1924 LogicalResult Conv2DBackpropInputOp::verify() {
1925 Conv2DBackpropInputOp op = *this;
1926 int num_spatial_dims = 2;
1927 int num_dims = 2 + num_spatial_dims;
1928
1929 if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) ||
1930 !IsOfRankOrUnranked(op.filter(), num_dims))
1931 return op.emitOpError()
1932 << "requires operands to be " << num_dims << "D tensor";
1933 if (!IsOfRankOrUnranked(op.getResult(), num_dims))
1934 return op.emitOpError()
1935 << "requires result to be " << num_dims << "D tensor";
1936
1937 llvm::Optional<mlir::Location> location = op.getLoc();
1938 ArrayRef<Attribute> strides = op.strides().getValue();
1939 ArrayRef<Attribute> dilations = op.dilations().getValue();
1940 LogicalResult verify_result =
1941 VerifyConvOpAttributes(num_dims, strides, dilations, location);
1942 if (failed(verify_result)) {
1943 return verify_result;
1944 }
1945
1946 return success();
1947 }
1948
1949 LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) {
1950 StringRef src_data_format = this->data_format();
1951
1952 auto perm = GetDataFormatPermutation(src_data_format, data_format);
1953 if (perm.empty()) return failure();
1954
1955 // Update data_format attribute and result types.
1956 if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1957
1958 // Update convolution attributes.
1959 (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1960 (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1961 (*this)->setAttr("explicit_paddings",
1962 ShuffleArrayAttr(explicit_paddings(), perm, 2));
1963
1964 // Permute input sizes operand.
1965 OpBuilder builder(getOperation());
1966 auto input_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1967 getLoc(), input_sizes(), StringAttr::get(getContext(), src_data_format),
1968 StringAttr::get(getContext(), data_format));
1969 setOperand(0, input_sizes_permuted);
1970
1971 return success();
1972 }
1973
1974 StringRef Conv2DBackpropInputOp::GetOptimalLayout(
1975 const RuntimeDevices &devices) {
1976 // Keep the current data format if no GPUs are available or if explicit
1977 // placement does not allow using a GPU for this operation.
1978 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1979 return data_format();
1980
1981 // Filter must be a tensor.
1982 auto filter_ty = filter().getType().dyn_cast<TensorType>();
1983 if (!filter_ty) return data_format();
1984
1985 // For the f16 data type, the NHWC data format is up to ~2x faster on
1986 // devices with Tensor Core support.
1987 const bool is_f16 = filter_ty.getElementType().isF16();
1988 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1989
1990 // Otherwise always use "NCHW".
1991 return "NCHW";
1992 }
1993
1994 //===----------------------------------------------------------------------===//
1995 // Conv3DOp
1996 //===----------------------------------------------------------------------===//
1997
1998 LogicalResult Conv3DOp::inferReturnTypeComponents(
1999 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
2000 DictionaryAttr attributes, RegionRange regions,
2001 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2002 Conv3DOpAdaptor op(operands.getValues(), attributes);
2003 ArrayRef<Attribute> explicit_padding;
2004 ArrayAttr explicit_pad =
2005 attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
2006 if (!explicit_pad) {
2007 explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
2008 }
2009 explicit_padding = explicit_pad.getValue();
2010
2011 return inferConvReturnTypeComponents(location, op, explicit_padding,
2012 inferredReturnShapes);
2013 }
2014
2015 //===----------------------------------------------------------------------===//
2016 // DataFormatVecPermuteOp
2017 //===----------------------------------------------------------------------===//
2018
2019 LogicalResult DataFormatVecPermuteOp::verify() {
2020 DataFormatVecPermuteOp op = *this;
2021 auto input_ty = op.x().getType().dyn_cast<RankedTensorType>();
2022 if (!input_ty) return success();
2023
2024 int rank = input_ty.getRank();
2025 if (rank != 1 && rank != 2)
2026 return op.emitOpError("requires input of rank 1 or 2");
2027
2028 if (rank == 1) {
2029 int64_t dim0 = input_ty.getDimSize(0);
2030 if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
2031 return op.emitOpError("requires 1D input of size 4 or size 2");
2032 }
2033
2034 if (rank == 2) {
2035 int64_t dim0 = input_ty.getDimSize(0);
2036 if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
2037 return op.emitOpError(
2038 "requires first dimensions of 2D input to be of size 4");
2039
2040 int64_t dim1 = input_ty.getDimSize(1);
2041 if (dim1 != ShapedType::kDynamicSize && dim1 != 2)
2042 return op.emitOpError(
2043 "requires second dimensions of 2D input to be of size 2");
2044 }
2045
2046 return success();
2047 }
2048
2049 //===----------------------------------------------------------------------===//
2050 // DivNoNanOp
2051 //===----------------------------------------------------------------------===//
2052
2053 namespace {
2054
2055 /// Canonicalization template for tf.DivNoNan and tf.MulNoNan:
2056 /// If the op is tf.DivNoNan and the divisor is a constant tensor (with all the
2057 /// elements of any allowed type: float or complex), rewrite the op to the
2058 /// divisor if all the elements of the divisor are zero and to tf.Div if all the
2059 /// elements of the divisor are non-zero.
2060
2061 /// Similarly, if the op is tf.MulNoNan and the multiplier is a constant tensor
2062 /// (with all the elements of any allowed type: float or complex), rewrite the
2063 /// op to the multiplier if all the elements of the multiplier are zero and to
2064 /// tf.Mul if all the elements of the multiplier are non-zero.
2065
2066 /// Replace the given op with an op of type `RetT`. Upon calling
2067 /// DivNoNanOrMulNoNanConstantY for canonicalizing tf.DivNoNan, tf.DivOp is
2068 /// passed as the second argument and for canonicalizing tf.MulNoNan, tf.MulOp
2069 /// is passed as the second argument.
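/// As an illustrative sketch of the rewrite (IR names are hypothetical):
///   %y = "tf.Const"() {value = dense<0.0> : tensor<2xf32>}
///   "tf.DivNoNan"(%x, %y) is replaced by %y, whereas with
///   %y = "tf.Const"() {value = dense<2.0> : tensor<2xf32>}
///   "tf.DivNoNan"(%x, %y) is replaced by "tf.Div"(%x, %y).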
2070 template <typename OpT, typename RetT>
2071 class DivNoNanOrMulNoNanConstantY : public OpRewritePattern<OpT> {
2072 using OpRewritePattern<OpT>::OpRewritePattern;
2073
2074 LogicalResult matchAndRewrite(OpT op,
2075 PatternRewriter &rewriter) const override {
2076 static_assert(
2077 llvm::is_one_of<OpT, DivNoNanOp, MulNoNanOp>::value,
2078 "only canonicalization of tf.DivNoNan and tf.MulNoNan is supported");
2079
2080 // Returns true iff `val` (a complex constant with float real and imaginary
2081 // parts) is zero.
2082 auto complexIsZero = [](const std::complex<APFloat> val) {
2083 // Note that when `val` is of complex type, it is zero iff both
2084 // its real and imaginary parts are zero.
2085 if (val.real().isZero() && val.imag().isZero())
2086 return true;
2087 else
2088 return false;
2089 };
2090
2091 // Returns true iff `attr` contains both zero and non-zero elements
2092 // (of float or complex type).
2093 auto hasBothZeroAndNonzeroElements =
2094 [&complexIsZero](ElementsAttr attr, bool hasComplexElements) {
2095 bool foundZero = false, foundNonzero = false;
2096 if (!hasComplexElements) {
2097 for (const auto val : attr.getValues<APFloat>()) {
2098 if (val.isZero())
2099 foundZero = true;
2100 else
2101 foundNonzero = true;
2102 if (foundZero && foundNonzero) return true;
2103 }
2104 } else {
2105 for (const auto val : attr.getValues<std::complex<APFloat>>()) {
2106 if (complexIsZero(val))
2107 foundZero = true;
2108 else
2109 foundNonzero = true;
2110 if (foundZero && foundNonzero) return true;
2111 }
2112 }
2113 return false;
2114 };
2115
2116 // Note that `y` is the divisor if the op is tf.DivNoNan and it is the
2117 // multiplier if the op is tf.MulNoNan.
2118 Value y = op.y();
2119 // The below if condition is true iff `y.getDefiningOp()` is of the type
2120 // TF::ConstOp, i.e., if `y` is defined by an op and it is the tf.Const op.
2121 // In that case, `yDefOp` stores this tf.Const op.
2122 // Note that if `y` is a block argument, `y.getDefiningOp()` will return
2123 // null, which will get propagated by dyn_cast_or_null to `yDefOp`.
2124 // Further, if `y` is defined by an op other than tf.Const,
2125 // `y.getDefiningOp()` will not return null but dyn_cast_or_null will.
2126 if (auto yDefOp = dyn_cast_or_null<TF::ConstOp>(y.getDefiningOp())) {
2127 Type typeOfElementsInY = getElementTypeOrSelf(y.getType());
2128 ElementsAttr attr = yDefOp.value();
2129 bool yHasComplexElements = typeOfElementsInY.isa<ComplexType>();
2130
2131 // If `y` is a splat constant, then the op will definitely get replaced.
2132 // We check for a splat constant first, in order to optimize the
2133 // performance of this canonicalization because this check will be O(1).
2134 if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
2135 bool splatAttrIsZero = false;
2136 if (!yHasComplexElements) {
2137 if (splatAttr.getSplatValue<APFloat>().isZero())
2138 splatAttrIsZero = true;
2139 } else {
2140 if (complexIsZero(splatAttr.getSplatValue<std::complex<APFloat>>()))
2141 splatAttrIsZero = true;
2142 }
2143 if (splatAttrIsZero) {
2144 // When `y` is a zero splat constant (i.e., all the elements in `y` are
2145 // zero), replace the op (tf.DivNoNan or tf.MulNoNan) with `y`.
2146 rewriter.replaceOp(op, y);
2147 } else {
2148 // When `y` is a non-zero splat constant, replace tf.DivNoNan with
2149 // tf.Div and tf.MulNoNan with tf.Mul.
2150 rewriter.replaceOpWithNewOp<RetT>(op, op->getResult(0).getType(),
2151 op->getOperand(0),
2152 op->getOperand(1));
2153 }
2154 return success();
2155 }
2156
2157 // If `y` has both zero and non-zero elements, do nothing.
2158 if (hasBothZeroAndNonzeroElements(attr, yHasComplexElements)) {
2159 return failure();
2160 } else {
2161 // When `y` is a non-splat constant whose elements are all non-zero,
2162 // replace tf.DivNoNan with tf.Div and tf.MulNoNan with tf.Mul.
2163 rewriter.replaceOpWithNewOp<RetT>(op, op->getResult(0).getType(),
2164 op->getOperand(0), op->getOperand(1));
2165 return success();
2166 }
2167 }
2168 return failure();
2169 }
2170 };
2171 } // namespace
2172
2173 void DivNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results,
2174 MLIRContext *context) {
2175 results.add<DivNoNanOrMulNoNanConstantY<TF::DivNoNanOp, TF::DivOp>>(context);
2176 }
2177
2178 //===----------------------------------------------------------------------===//
2179 // DivOp
2180 //===----------------------------------------------------------------------===//
2181
2182 void DivOp::getCanonicalizationPatterns(RewritePatternSet &results,
2183 MLIRContext *context) {
2184 results.add<DivWithSqrtDivisor>(context);
2185 }
2186
2187 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
2188 return IdentityArithmeticOpFolder<DivOp>(*this, operands);
2189 }
2190
2191 //===----------------------------------------------------------------------===//
2192 // DynamicStitchOp
2193 //===----------------------------------------------------------------------===//
2194
2195 LogicalResult DynamicStitchOp::verify() {
2196 DynamicStitchOp op = *this;
2197 if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1");
2198
2199 if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) {
2200 if (out_ty.getRank() == 0) {
2201 return op.emitOpError("requires non scalar output");
2202 }
2203 }
2204
2205 llvm::SmallDenseSet<int64_t, 8> index_values;
2206 bool all_indices_const = true;
2207 int32_t max_index = -1;
2208 llvm::Optional<SmallVector<int64_t, 4>> inferred_item_shape;
2209 for (auto it : llvm::zip(op.indices(), op.data())) {
2210 Value index = std::get<0>(it);
2211
2212 DenseIntElementsAttr index_attr;
2213 if (matchPattern(index, m_Constant(&index_attr))) {
2214 for (int32_t index : index_attr.getValues<int32_t>()) {
2215 if (index < 0)
2216 return op.emitOpError()
2217 << "requires non-negative index values; found " << index;
2218 max_index = std::max(index, max_index);
2219 index_values.insert(index);
2220 }
2221 } else {
2222 all_indices_const = false;
2223 }
2224
2225 Value data = std::get<1>(it);
2226 RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>();
2227 RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
2228 if (!index_ty || !data_ty) continue;
2229
2230 int64_t index_rank = index_ty.getRank();
2231 ArrayRef<int64_t> data_shape = data_ty.getShape();
2232 ArrayRef<int64_t> index_shape = index_ty.getShape();
2233 if (failed(mlir::verifyCompatibleShape(index_shape,
2234 data_shape.take_front(index_rank))))
2235 return op.emitOpError() << "requires shape of data with type " << data_ty
2236 << " to have prefix matching with shape of the "
2237 "corresponding index type "
2238 << index_ty;
2239
2240 ArrayRef<int64_t> item_shape = data_shape.drop_front(index_rank);
2241 if (!inferred_item_shape) {
2242 inferred_item_shape = llvm::to_vector<4>(item_shape);
2243 continue;
2244 }
2245
2246 if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape)))
2247 return op.emitOpError() << "has inconsistent shaped data and index "
2248 "pairs; inferred item shapes ["
2249 << llvm::makeArrayRef(*inferred_item_shape)
2250 << "] and [" << item_shape << "] don't match";
2251 for (int i = 0, e = item_shape.size(); i < e; ++i) {
2252 int64_t &inferred_dim = (*inferred_item_shape)[i];
2253 int64_t dim = item_shape[i];
2254 if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim;
2255 }
2256 }
2257
2258 // If all indices are constants, then verify that they cover all indices in
2259 // the range [0, max_index] and the output type is legal.
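// For example (illustrative): indices [0, 1] and [2] with data operands of
// types tensor<2x3xf32> and tensor<1x3xf32> infer an item shape of [3], so
// the output must be compatible with tensor<3x3xf32>.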
2260 if (all_indices_const) {
2261 for (int32_t i = 0; i <= max_index; i++) {
2262 if (!index_values.count(i))
2263 return op.emitOpError() << "missing index " << i;
2264 }
2265
2266 if (inferred_item_shape) {
2267 SmallVector<int64_t, 4> expected_shape;
2268 expected_shape.push_back(max_index + 1);
2269 expected_shape.append(inferred_item_shape->begin(),
2270 inferred_item_shape->end());
2271
2272 auto out_ty = op.getType().cast<TensorType>();
2273 auto expected_out_ty =
2274 RankedTensorType::get(expected_shape, out_ty.getElementType());
2275
2276 if (!AreCastCompatible({out_ty, expected_out_ty})) {
2277 return op.emitOpError() << "has invalid output type; should be "
2278 "compatible with inferred type "
2279 << expected_out_ty;
2280 }
2281 }
2282 }
2283
2284 return success();
2285 }
2286
2287 //===----------------------------------------------------------------------===//
2288 // EinsumOp
2289 //===----------------------------------------------------------------------===//
2290
2291 // Verifies that,
2292 // * Arity of the op is at most two.
2293 //
2294 // TODO(hinsu): Verify einsum equation attribute.
2295 LogicalResult EinsumOp::verify() {
2296 EinsumOp op = *this;
2297 if (op.N() > 2) {
2298 return op.emitOpError("supports at most two operands");
2299 }
2300 return success();
2301 }
2302
2303 //===----------------------------------------------------------------------===//
2304 // EmptyOp
2305 //===----------------------------------------------------------------------===//
2306
2307 OpFoldResult EmptyOp::fold(ArrayRef<Attribute> operands) {
2308 assert(operands.size() == 1 && "empty op has one operand");
2309
2310 Attribute attr = operands.front();
2311 if (!attr) return {};
2312
2313 auto int_attr = attr.cast<DenseIntElementsAttr>();
2314 SmallVector<int64_t, 6> out_shape;
2315 for (const auto val : int_attr.getValues<int32_t>()) {
2316 out_shape.push_back(val);
2317 }
2318
2319 auto type = getResult().getType().cast<ShapedType>();
2320 auto etype = type.getElementType();
2321
2322 // We cannot fold if the result shape is not static.
2323 if (!type.hasStaticShape()) return {};
2324
2325 if (auto float_type = etype.dyn_cast<FloatType>()) {
2326 auto out_type = RankedTensorType::get(out_shape, float_type);
2327 return DenseElementsAttr::get(out_type,
2328 {APFloat(float_type.getFloatSemantics())});
2329 }
2330
2331 if (auto int_type = etype.dyn_cast<IntegerType>()) {
2332 auto out_type = RankedTensorType::get(out_shape, etype);
2333 APInt val(int_type.getWidth(), 0, int_type.getSignedness());
2334 return DenseElementsAttr::get(out_type, val);
2335 }
2336
2337 return {};
2338 }
2339
2340 //===----------------------------------------------------------------------===//
2341 // EmptyTensorListOp
2342 //===----------------------------------------------------------------------===//
2343
2344 LogicalResult EmptyTensorListOp::verify() {
2345 EmptyTensorListOp op = *this;
2346 // This is required to populate derived attributes during export in a
2347 // meaningful way. Otherwise, the element_type() query during export to
2348 // GraphDef will result in an out-of-bounds access/assert.
2349 if (handle_dtype().getSubtypes().size() != 1) {
2350 return emitOpError(
2351 "must have exactly one subtype in the result variant type");
2352 }
2353
2354 if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2355 !IsOfRankOrUnranked(op.element_shape(), 1)) {
2356 return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2357 }
2358
2359 if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) {
2360 return op.emitOpError("requires max_num_elements operand to be 0D tensor");
2361 }
2362 return success();
2363 }
2364
2365 //===----------------------------------------------------------------------===//
2366 // EnqueueTPUEmbedding ops
2367 //===----------------------------------------------------------------------===//
2368
2369 // For EnqueueTPUEmbedding ops the device ordinal corresponds to the resource
2370 // instance.
2371
2372 std::string
2373 EnqueueTPUEmbeddingArbitraryTensorBatchOp::GetResourceInstanceStr() {
2374 return std::to_string(device_ordinal());
2375 }
2376
2377 std::string EnqueueTPUEmbeddingBatchOp::GetResourceInstanceStr() {
2378 return std::to_string(device_ordinal());
2379 }
2380
2381 std::string EnqueueTPUEmbeddingIntegerBatchOp::GetResourceInstanceStr() {
2382 return std::to_string(device_ordinal());
2383 }
2384
2385 std::string EnqueueTPUEmbeddingRaggedTensorBatchOp::GetResourceInstanceStr() {
2386 return std::to_string(device_ordinal());
2387 }
2388
2389 std::string EnqueueTPUEmbeddingSparseBatchOp::GetResourceInstanceStr() {
2390 return std::to_string(device_ordinal());
2391 }
2392
2393 std::string EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceStr() {
2394 return std::to_string(device_ordinal());
2395 }
2396
2397 //===----------------------------------------------------------------------===//
2398 // EnsureShapeOp
2399 //===----------------------------------------------------------------------===//
2400
2401 OpFoldResult EnsureShapeOp::fold(llvm::ArrayRef<mlir::Attribute>) {
2402 ShapedType type = input().getType().dyn_cast<ShapedType>();
2403 if (!type || !type.hasRank()) return {};
2404 // If the shape attribute equals the input type's shape, fold to the input.
2405 llvm::Optional<llvm::ArrayRef<int64_t>> shape_constraint = shape();
2406 if (type.getShape() == shape_constraint) return input();
2407
2408 // If the input type's shape always satisfies the shape attribute, fold to
2409 // the input.
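// For example (illustrative): an input of type tensor<2x4xf32> with a shape
// attribute of [?, 4] always satisfies the constraint, so the op folds to
// its input.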
2410 if (shape_constraint.has_value() &&
2411 shape_constraint->size() == type.getShape().size()) {
2412 for (int i = 0; i < shape_constraint->size(); ++i) {
2413 if (!ShapedType::isDynamic(shape_constraint.getValue()[i]) &&
2414 type.getDimSize(i) != shape_constraint.getValue()[i]) {
2415 return {};
2416 }
2417 }
2418 return input();
2419 }
2420 // Otherwise, keep the op so the constraint can fail dynamically.
2421 return {};
2422 }
2423
2424 //===----------------------------------------------------------------------===//
2425 // EqualOp/NotEqualOp
2426 //===----------------------------------------------------------------------===//
2427
2428 LogicalResult EqualOp::verify() {
2429 EqualOp op = *this;
2430 // If inputs are allowed to have incompatible types, there is nothing to do.
2431 if (!op.incompatible_shape_error()) return success();
2432
2433 // Otherwise, check inputs are broadcastable.
2434 return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
2435 op.getOperation());
2436 }
2437
2438 void EqualOp::build(OpBuilder &builder, OperationState &result, Value x,
2439 Value y, BoolAttr incompatible_shape_error) {
2440 auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
2441 incompatible_shape_error);
2442 return build(builder, result, result_type, x, y, incompatible_shape_error);
2443 }
2444
2445 namespace {
2446
2447 // Flips the incompatible_shape_error attribute to true if the shapes are known
2448 // to be compatible.
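// For example (illustrative): a tf.Equal comparing two tensor<2xf32> values
// with a ranked tensor<2xi1> result and incompatible_shape_error = false is
// rewritten to the same op with incompatible_shape_error = true.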
2449 template <typename Ty>
2450 static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter &rewriter) {
2451 if (op.incompatible_shape_error()) {
2452 return rewriter.notifyMatchFailure(op, "the attribute is already true");
2453 }
2454
2455 // incompatible_shape_error=false implies that the op will either return a
2456 // valid result or a scalar boolean indicating the error. For unranked outputs
2457 // we don't know which one it is. TF shape inference turns unranked outputs
2458 // into ranked ones if it can statically evaluate the broadcast, see the shape
2459 // function of tf.Equal.
2460 auto ty = op.getType().template dyn_cast<RankedTensorType>();
2461 if (!ty) {
2462 return rewriter.notifyMatchFailure(op, "requires a ranked output shape");
2463 }
2464
2465 // Unless this is a scalar compare, a scalar output indicates that this will
2466 // always fail.
2467 auto x_ty = op.x().getType().template dyn_cast<RankedTensorType>();
2468 auto y_ty = op.y().getType().template dyn_cast<RankedTensorType>();
2469 if (ty.getRank() == 0 &&
2470 (!x_ty || x_ty.getRank() != 0 || !y_ty || y_ty.getRank() != 0)) {
2471 return rewriter.notifyMatchFailure(op, "output rank must match input rank");
2472 }
2473
2474 // Shapes are known to be compatible.
2475 rewriter.template replaceOpWithNewOp<Ty>(op, op.x(), op.y(),
2476 rewriter.getBoolAttr(true));
2477 return success();
2478 }
2479 } // namespace
2480
2481 void EqualOp::getCanonicalizationPatterns(RewritePatternSet &results,
2482 MLIRContext *context) {
2483 results.add(flipComatibleShapeError<EqualOp>);
2484 }
2485
2486 void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet &results,
2487 MLIRContext *context) {
2488 results.add(flipComatibleShapeError<NotEqualOp>);
2489 }
2490
2491 //===----------------------------------------------------------------------===//
2492 // ExpandDimsOp
2493 //===----------------------------------------------------------------------===//
2494
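// Infers the result type of tf.ExpandDims. For example (illustrative): an
// input of type tensor<2x3xf32> with a constant dim of -1 yields
// tensor<2x3x1xf32>.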
2495 Type InferExpandDimsOpType(Value input, Value dim) {
2496 Type element_ty = input.getType().cast<TensorType>().getElementType();
2497 auto unranked_ty = UnrankedTensorType::get(element_ty);
2498
2499 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
2500 if (!input_ty) return unranked_ty;
2501
2502 DenseIntElementsAttr dim_attr;
2503 if (!matchPattern(dim, m_Constant(&dim_attr)) ||
2504 dim_attr.getNumElements() != 1)
2505 return unranked_ty;
2506 int64_t dim_val = (*dim_attr.begin()).getSExtValue();
2507 int64_t input_rank = input_ty.getRank();
2508
2509 if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty;
2510 if (dim_val < 0) dim_val += input_rank + 1;
2511
2512 SmallVector<int64_t, 4> shape = llvm::to_vector<4>(input_ty.getShape());
2513 shape.insert(shape.begin() + dim_val, 1);
2514 return RankedTensorType::get(shape, element_ty);
2515 }
2516
2517 void ExpandDimsOp::build(OpBuilder &builder, OperationState &result,
2518 Value input, Value dim) {
2519 return build(builder, result, InferExpandDimsOpType(input, dim), input, dim);
2520 }
2521
2522 //===----------------------------------------------------------------------===//
2523 // Expm1Op
2524 //===----------------------------------------------------------------------===//
2525
2526 LogicalResult Expm1Op::inferReturnTypeComponents(
2527 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
2528 DictionaryAttr attributes, RegionRange regions,
2529 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2530 ShapeAdaptor adaptor = operands.getShape(0);
2531 ShapedTypeComponents component(adaptor.getElementType());
2532 if (adaptor.hasRank()) adaptor.getDims(component);
2533 inferredReturnShapes.push_back(component);
2534 return success();
2535 }
2536
2537 //===----------------------------------------------------------------------===//
2538 // FakeQuantWithMinMaxArgsOp
2539 //===----------------------------------------------------------------------===//
2540 LogicalResult FakeQuantWithMinMaxArgsOp::verify() {
2541 FakeQuantWithMinMaxArgsOp op = *this;
2542 // TODO(fengliuai): move the following to a utility method.
2543 const llvm::fltSemantics &semantics = op.min().getSemantics();
2544 float rmin, rmax;
2545 if (&semantics == &APFloat::IEEEsingle()) {
2546 rmin = op.min().convertToFloat();
2547 rmax = op.max().convertToFloat();
2548 } else {
2549 rmin = op.min().convertToDouble();
2550 rmax = op.max().convertToDouble();
2551 }
2552 // Range boundaries must be valid.
2553 if (rmin >= rmax) {
2554 return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
2555 "," + Twine(std::to_string(rmax)) + "]");
2556 }
2557 int64_t num_bits = op.num_bits();
2558 if (num_bits < 2 || num_bits > 16) {
2559 return op.emitOpError(
2560 "requires num_bits to be between 2 and 16, inclusive");
2561 }
2562 return success();
2563 }
2564
2565 //===----------------------------------------------------------------------===//
2566 // FakeQuantWithMinMaxVarsOp
2567 //===----------------------------------------------------------------------===//
2568 LogicalResult FakeQuantWithMinMaxVarsOp::verify() {
2569 FakeQuantWithMinMaxVarsOp op = *this;
2570 auto min = GetRankedTensorTypeForOperand(op.min());
2571 if (min && !IsOfRankedFloatTensorType(min, 0))
2572 return op.emitOpError("requires min to be a 0d float tensor");
2573
2574 auto max = GetRankedTensorTypeForOperand(op.max());
2575 if (max && !IsOfRankedFloatTensorType(max, 0))
2576 return op.emitOpError("requires max to be a 0d float tensor");
2577
2578 int64_t num_bits = op.num_bits();
2579 if (num_bits < 2 || num_bits > 16) {
2580 return op.emitOpError(
2581 "requires num_bits to be between 2 and 16, inclusive");
2582 }
2583 return success();
2584 }
2585
2586 //===----------------------------------------------------------------------===//
2587 // FakeQuantWithMinMaxVarsPerChannelOp
2588 //===----------------------------------------------------------------------===//
2589 LogicalResult FakeQuantWithMinMaxVarsPerChannelOp::verify() {
2590 FakeQuantWithMinMaxVarsPerChannelOp op = *this;
2591 auto min = GetRankedTensorTypeForOperand(op.min());
2592 if (min && !IsOfRankedFloatTensorType(min, 1))
2593 return op.emitOpError("requires min to be a 1d float tensor");
2594
2595 auto max = GetRankedTensorTypeForOperand(op.max());
2596 if (max && !IsOfRankedFloatTensorType(max, 1))
2597 return op.emitOpError("requires max to be a 1d float tensor");
2598
2599 Value inputs = op.inputs();
2600 if (!HasRankAtLeast(inputs, 1))
2601 return op.emitError("requires inputs to be at least 1d float tensor");
2602
2603 int64_t num_bits = op.num_bits();
2604 if (num_bits < 2 || num_bits > 16) {
2605 return op.emitOpError(
2606 "requires num_bits to be between 2 and 16, inclusive");
2607 }
2608
2609 auto inputs_type = inputs.getType().dyn_cast<RankedTensorType>();
2610 if (!inputs_type) return success();
2611 int depth = inputs_type.getDimSize(inputs_type.getRank() - 1);
2612 if ((min && min.getDimSize(0) != depth) ||
2613 (max && max.getDimSize(0) != depth)) {
2614 return op.emitOpError(
2615 "requires min and max to have same size as last dimension of inputs");
2616 }
2617
2618 return success();
2619 }
2620
2621 //===----------------------------------------------------------------------===//
2622 // FillOp
2623 //===----------------------------------------------------------------------===//
2624
2625 LogicalResult FillOp::verify() {
2626 FillOp op = *this;
2627 if (!IsOfRankOrUnranked(op.dims(), 1))
2628 return op.emitOpError() << "requires dims to be a 1D tensor";
2629 if (!IsOfRankOrUnranked(op.value(), 0))
2630 return op.emitOpError() << "requires value to be a scalar";
2631
2632 return success();
2633 }
2634
2635 static ShapedType InferFillOpType(Value dims, Value value) {
2636 Type etype = value.getType().cast<ShapedType>().getElementType();
2637
2638 DenseIntElementsAttr dims_attr;
2639 if (matchPattern(dims, m_Constant(&dims_attr))) {
2640 llvm::SmallVector<int64_t, 4> shape;
2641 shape.reserve(dims_attr.getNumElements());
2642 for (const APInt dim : dims_attr.getValues<APInt>()) {
2643 shape.push_back(dim.getSExtValue());
2644 }
2645 return RankedTensorType::get(shape, etype);
2646 }
2647
2648 if (auto shape_op = dims.getDefiningOp<ShapeOp>()) {
2649 if (auto t = shape_op.input().getType().dyn_cast<ShapedType>()) {
2650 return t;
2651 }
2652 }
2653
2654 return UnrankedTensorType::get(etype);
2655 }
2656
2657 void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
2658 Value value) {
2659 FillOp::build(builder, result, InferFillOpType(dims, value), dims, value);
2660 }
2661
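// Folds tf.Fill when both operands are constant. For example (illustrative):
// dims of dense<[2, 3]> : tensor<2xi32> and a value of dense<7.0> : tensor<f32>
// fold to dense<7.0> : tensor<2x3xf32>.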
2662 OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
2663 assert(operands.size() == 2 && "fill op has two operands");
2664
2665 auto type = getType().cast<ShapedType>();
2666 // DenseElementsAttr that is used in this folder only supports int and float
2667 // types.
2668 // TODO(hinsu): Handle complex types once there is a attribute kind for
2669 // complex.
2670 if (!type.getElementType().isIntOrFloat()) return {};
2671
2672 auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
2673 if (!value) return {};
2674
2675 if (type.hasStaticShape())
2676 return DenseElementsAttr::get(type, value.getValues<Attribute>()[0]);
2677
2678 auto dims = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2679 if (!dims) return {};
2680
2681 llvm::SmallVector<int64_t, 4> shape;
2682 shape.reserve(dims.getNumElements());
2683 for (const APInt dim : dims.getValues<APInt>()) {
2684 shape.push_back(dim.getSExtValue());
2685 }
2686 type = RankedTensorType::get(shape, type.getElementType());
2687
2688 return DenseElementsAttr::get(type, value.getValues<Attribute>()[0]);
2689 }
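
// Illustrative example (not part of the original source): with operands
// `dims = dense<[2, 2]> : tensor<2xi32>` and `value = dense<1.0> : tensor<f32>`,
// the folder above produces `dense<1.0> : tensor<2x2xf32>`, so a tf.Fill with
// constant operands folds away entirely.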

//===----------------------------------------------------------------------===//
// FusedBatchNormGradOp
//===----------------------------------------------------------------------===//

// TODO(b/150954845): Add benchmarks to verify that layout preference didn't
// change in the latest GPU generations.

LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) {
  return ::mlir::TF::UpdateDataFormat(data_format, this);
}

StringRef FusedBatchNormGradV3Op::GetOptimalLayout(
    const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
    return data_format();

  // For the f16 data type, the NHWC data format is up to ~2x faster on devices
  // with Tensor Core support.
  auto x_ty = x().getType().cast<TensorType>();
  const bool is_f16 = x_ty.getElementType().isF16();
  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";

  // For all other data types prefer NCHW.
  return "NCHW";
}

//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//

LogicalResult FusedBatchNormOp::verify() {
  FusedBatchNormOp op = *this;
  auto x = GetRankedTensorTypeForOperand(op.x());
  if (x && !IsOfRankedFloatTensorType(x, 4))
    return op.emitOpError("requires x to be a 4D float tensor");

  auto scale = GetRankedTensorTypeForOperand(op.scale());
  if (scale && !IsOfRankedFloatTensorType(scale, 1))
    return op.emitOpError("requires scale to be a 1D float tensor");

  auto offset = GetRankedTensorTypeForOperand(op.offset());
  if (offset && !IsOfRankedFloatTensorType(offset, 1))
    return op.emitOpError("requires offset to be a 1D float tensor");

  auto mean = GetRankedTensorTypeForOperand(op.mean());
  if (mean && !IsOfRankedFloatTensorType(mean, 1))
    return op.emitOpError("requires mean to be a 1D float tensor");

  auto variance = GetRankedTensorTypeForOperand(op.variance());
  if (variance && !IsOfRankedFloatTensorType(variance, 1))
    return op.emitOpError("requires variance to be a 1D float tensor");

  // TODO(antiagainst): check attributes

  return success();
}

//===----------------------------------------------------------------------===//
// FusedBatchNormV2Op / FusedBatchNormV3Op
//===----------------------------------------------------------------------===//

template <class Op>
static LogicalResult InferenceFoldOperandsPermutation(
    ArrayRef<int64_t> permutation, Op *op) {
  // FusedBatchNorm in training mode is a layout-sensitive operation, and
  // should already have been assigned an optimal data format.
  if (op->is_training()) return failure();
  return ::mlir::TF::FoldOperandsPermutation(permutation, op);
}

template <class Op>
static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) {
  // In inference mode FusedBatchNorm is not sensitive to data layout.
  if (!op->is_training()) return op->data_format();

  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation()))
    return op->data_format();

  // For the f16 data type, the NHWC data format is up to ~2x faster on devices
  // with Tensor Core support.
  auto x_ty = op->x().getType().template cast<TensorType>();
  const bool is_f16 = x_ty.getElementType().isF16();
  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";

  // For all other data types prefer NCHW.
  return "NCHW";
}
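
// Illustrative example (not part of the original source): a training-mode
// FusedBatchNormV3 whose `x` operand has element type f16, placed on a GPU
// with Tensor Cores, is assigned "NHWC" by the helper above, while an f32
// input on the same device is assigned "NCHW"; in inference mode the current
// data_format attribute is simply kept.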

LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation(
    ArrayRef<int64_t> permutation) {
  return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
}

LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) {
  return ::mlir::TF::UpdateDataFormat(data_format, this);
}

StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) {
  return ::mlir::TF::GetOptimalLayout(devices, this);
}

LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
    ArrayRef<int64_t> permutation) {
  return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
}

LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
  return ::mlir::TF::UpdateDataFormat(data_format, this);
}

StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
  return ::mlir::TF::GetOptimalLayout(devices, this);
}

//===----------------------------------------------------------------------===//
// GatherV2Op
//===----------------------------------------------------------------------===//

LogicalResult GatherV2Op::verify() {
  GatherV2Op op = *this;
  int64_t batch_dims = op.batch_dims();
  if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) {
    int64_t rank = ty.getRank();
    if (batch_dims > rank || batch_dims < -rank)
      return op.emitOpError()
             << "batch_dims (" << batch_dims << ") must be in range [" << -rank
             << ", " << rank + 1 << ")";
    if (batch_dims < 0) batch_dims += rank;
  }

  if (!HasRankAtMost(op.axis(), 1))
    return op.emitOpError("requires axis to have rank at most 1");

  DenseIntElementsAttr axis_attr;
  if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
    int64_t axis = (*axis_attr.begin()).getSExtValue();
    if (auto ty = op.params().getType().dyn_cast<RankedTensorType>()) {
      int64_t rank = ty.getRank();
      if (axis >= rank || axis < -rank)
        return op.emitOpError() << "axis (" << axis << ") must be in range ["
                                << -rank << ", " << rank << ")";
      if (axis < 0) axis += rank;
    }

    if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) {
      return op.emitOpError() << "requires axis (" << axis
                              << ") to be greater than or equal to batch_dims ("
                              << batch_dims << ")";
    }
  }
  return success();
}
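
// Illustrative example (not part of the original source): for params of rank 3,
// indices of rank 2, `batch_dims = -1` and a constant `axis = -2`, the checks
// above normalize batch_dims to 1 and axis to 1, which is accepted; a constant
// `axis = 0` with `batch_dims = 1` is rejected because axis < batch_dims.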

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherToV2>(context);
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//

LogicalResult IfOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
  auto branch_name = [](unsigned index) -> std::string {
    return index == 0 ? "'then_branch'" : "'else_branch'";
  };
  return VerifyCaseOrIfOpBranchFunctions(
      symbol_table, *this, {then_branchAttr(), else_branchAttr()}, branch_name);
}

//===----------------------------------------------------------------------===//
// IfOp canonicalization.
//===----------------------------------------------------------------------===//

namespace {
class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
 public:
  explicit FoldConstantIfOp(MLIRContext *context)
      : OpRewritePattern<TF::IfOp>(context) {}
  LogicalResult matchAndRewrite(TF::IfOp op,
                                PatternRewriter &rewriter) const override;

 private:
  template <typename T>
  struct CallOpType {
    using CallOp = T;
  };
};

LogicalResult FoldConstantIfOp::matchAndRewrite(
    TF::IfOp op, PatternRewriter &rewriter) const {
  // Extract the constant cond value.
  DenseIntElementsAttr cond_attr;
  if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();

  // Cond value must be a scalar.
  if (cond_attr.getNumElements() != 1) return failure();

  // Select a branch function.
  bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
  FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();

  // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
  auto rewrite = [&](auto op_type) {
    auto empty = rewriter.getStringAttr("");
    ReplaceTfOpWithNewOp<typename decltype(op_type)::CallOp>(
        rewriter, op, op.getResultTypes(), op.input(), func,
        /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
  };

  if (op.is_stateless())
    rewrite(CallOpType<PartitionedCallOp>{});
  else
    rewrite(CallOpType<StatefulPartitionedCallOp>{});

  return success();
}
}  // anonymous namespace
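
// Illustrative example (not part of the original source): a stateless
//   %r = "tf.If"(%cond, %arg) {then_branch = @then, else_branch = @else, ...}
// whose %cond folds to the scalar constant `true` is rewritten above into
//   %r = "tf.PartitionedCall"(%arg) {f = @then, ...}
// so the else branch becomes dead; a stateful tf.If is rewritten to
// "tf.StatefulPartitionedCall" instead.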

void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  results.add<FoldConstantIfOp, DropAttributes<IfOp>>(context);
}

//===----------------------------------------------------------------------===//
// IfRegionOp
//===----------------------------------------------------------------------===//

LogicalResult IfRegionOp::verifyRegions() {
  IfRegionOp op = *this;
  TypeRange then_types =
      op.then_branch().front().getTerminator()->getOperandTypes();
  TypeRange else_types =
      op.else_branch().front().getTerminator()->getOperandTypes();

  TypeRangeWithDesc results{op.getResultTypes(), "result"};
  TypeRangeWithDesc then_results{then_types, "then result"};
  TypeRangeWithDesc else_results{else_types, "else result"};

  if (failed(VerifyTypeRangesAreCompatible(op, then_results, results)))
    return failure();
  if (failed(VerifyTypeRangesAreCompatible(op, else_results, results)))
    return failure();
  return success();
}

namespace {
class FoldConstantIfRegionOp : public OpRewritePattern<TF::IfRegionOp> {
 public:
  explicit FoldConstantIfRegionOp(MLIRContext *context)
      : OpRewritePattern<TF::IfRegionOp>(context) {}
  LogicalResult matchAndRewrite(TF::IfRegionOp op,
                                PatternRewriter &rewriter) const override;
};

LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
    TF::IfRegionOp op, PatternRewriter &rewriter) const {
  // Extract the constant cond value.
  DenseIntElementsAttr cond_attr;
  if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();

  // The IfRegion condition should always be a scalar. Select the region to
  // fold to.
  bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
  Region &region = cond ? op.then_branch() : op.else_branch();

  // If the IfRegion is stateless but the region being inlined itself is not
  // stateless, then inlining the region could cause a loss of information.
  // However, it's probably better to fold the IfRegion instead of keeping the
  // dead branch around.

  // Inline the region in place of the IfRegion op, and forward the yield
  // inputs to the IfRegion op results. This is possible only if the yield
  // types match the result types.
  auto yield = cast<YieldOp>(region.front().getTerminator());
  auto updated_results = llvm::to_vector<4>(yield.getOperands());

  // If the yield types do not match the IfRegion result types, add appropriate
  // casts.
  rewriter.setInsertionPoint(yield);
  for (auto it : llvm::zip(op.getResultTypes(), updated_results)) {
    auto &updated_result = std::get<1>(it);
    Type result_type = std::get<0>(it);
    if (result_type != updated_result.getType()) {
      updated_result =
          rewriter.create<TF::CastOp>(op.getLoc(), result_type, updated_result,
                                      /*Truncate=*/rewriter.getBoolAttr(false));
    }
  }
  // Inline the region into the block containing the IfRegion.
  rewriter.mergeBlockBefore(&region.front(), op);
  rewriter.eraseOp(yield);
  rewriter.replaceOp(op, updated_results);
  return success();
}
}  // anonymous namespace
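
// Illustrative example (not part of the original source): when %cond is a
// constant scalar `false`, the body of the "tf.IfRegion" else_branch region is
// spliced into the enclosing block, the operands of its terminating "tf.Yield"
// (cast where their types differ from the op result types) replace the
// IfRegion results, and the IfRegion op itself is erased.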

void IfRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<FoldConstantIfRegionOp,
              CaseOrIfRegionEliminatePassThrough<TF::IfRegionOp>>(context);
}

//===----------------------------------------------------------------------===//
// InvertPermutationOp
//===----------------------------------------------------------------------===//

// Verifies that the input is 1D.
LogicalResult InvertPermutationOp::verify() {
  InvertPermutationOp op = *this;
  auto x_type = op.x().getType().cast<TensorType>();
  if (!x_type.hasRank()) return success();
  if (x_type.getShape().size() != 1)
    return op.emitOpError() << "requires input x to be 1-dimensional";

  return success();
}

//===----------------------------------------------------------------------===//
// LeakyReluOp
//===----------------------------------------------------------------------===//

OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "leaky relu has one operand");

  // leaky_relu(x, alpha: 1) -> x
  if (alpha().convertToFloat() == 1.0f) return getOperand();

  auto calculate = [&](FloatAttr arg) {
    APFloat val = arg.getValue();
    if (val.isNegative()) val = alpha() * val;
    return FloatAttr::get(arg.getType(), val);
  };

  if (auto arg = operands[0].dyn_cast_or_null<FloatAttr>()) {
    return calculate(arg);
  } else if (auto arg = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
    if (auto elementAttr = arg.getSplatValue<Attribute>().dyn_cast<FloatAttr>())
      return DenseElementsAttr::get(arg.getType(), calculate(elementAttr));
  }
  return {};
}
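
// Illustrative example (not part of the original source): with `alpha = 0.2`,
// a splat constant operand `dense<-1.0> : tensor<4xf32>` folds to
// `dense<-0.2> : tensor<4xf32>`; with `alpha = 1.0` the fold simply forwards
// the operand, since leaky_relu is then the identity.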

//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//

void LogOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<LogOfSoftmax, LogToLog1p>(context);
}

//===----------------------------------------------------------------------===//
// LogicalNotOp
//===----------------------------------------------------------------------===//

void LogicalNotOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results
      .add<LogicalNotOfEqual, LogicalNotOfNotEqual, LogicalNotOfGreater,
           LogicalNotOfGreaterEqual, LogicalNotOfLess, LogicalNotOfLessEqual>(
          context);
}

//===----------------------------------------------------------------------===//
// MatrixBandPartOp
//===----------------------------------------------------------------------===//

LogicalResult MatrixBandPartOp::verify() {
  MatrixBandPartOp op = *this;
  if (!HasRankAtLeast(op.input(), 2)) {
    return op.emitOpError()
           << "requires `input` to have rank of at least 2, but found "
           << op.input().getType();
  }
  if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
    return op.emitOpError()
           << "requires `num_lower` to have 0 dimensions, but found "
           << op.num_lower().getType();
  }
  if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
    return op.emitOpError()
           << "requires `num_upper` to have 0 dimensions, but found "
           << op.num_upper().getType();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MatrixDiag Ops
//===----------------------------------------------------------------------===//

void MatrixDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MatrixDiagToV3>(context);
}

//===----------------------------------------------------------------------===//
// MatrixSetDiagOp
//===----------------------------------------------------------------------===//

void MatrixSetDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<MatrixSetDiagToV3>(context);
}

//===----------------------------------------------------------------------===//
// MatrixSetDiagV2Op
//===----------------------------------------------------------------------===//

void MatrixSetDiagV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<MatrixSetDiagV2ToV3>(context);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

void MaxOp::build(OpBuilder &builder, OperationState &result, Value input,
                  Value reduction_indices, BoolAttr keep_dims) {
  Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims);
  build(builder, result, out_ty, input, reduction_indices, keep_dims);
}

//===----------------------------------------------------------------------===//
// MaximumOp
//===----------------------------------------------------------------------===//

void MaximumOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<MaximumOfZeroToRelu>(context);
}

//===----------------------------------------------------------------------===//
// MaxPoolOp
//===----------------------------------------------------------------------===//

LogicalResult MaxPoolOp::FoldOperandsPermutation(
    ArrayRef<int64_t> permutation) {
  return ::mlir::TF::FoldOperandsPermutation(
      permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
}

LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
  StringRef src_data_format = data_format();

  auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
  if (perm.empty()) return failure();

  // Update data_format attribute and result types.
  if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
    return failure();

  stridesAttr(ShuffleArrayAttr(strides(), perm));
  explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
  ksizeAttr(ShuffleArrayAttr(ksize(), perm));

  return success();
}
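
// Illustrative example (not part of the original source): switching a MaxPool
// from "NHWC" to "NCHW" shuffles the layout-dependent attributes with the same
// permutation applied to the operands, e.g. `ksize = [1, 3, 3, 1]` becomes
// `ksize = [1, 1, 3, 3]` and `strides = [1, 2, 2, 1]` becomes
// `strides = [1, 1, 2, 2]`.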

StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
    return data_format();

  // Defaults to NCHW.
  return "NCHW";
}

//===----------------------------------------------------------------------===//
// MaxPoolGradOp
//===----------------------------------------------------------------------===//

LogicalResult MaxPoolGradOp::verify() {
  MaxPoolGradOp op = *this;
  if (!IsOfRankOrUnranked(op.orig_input(), 4)) {
    return op.emitOpError() << "requires orig_input to be rank 4";
  }
  if (!IsOfRankOrUnranked(op.orig_output(), 4)) {
    return op.emitOpError() << "requires orig_output to be rank 4";
  }
  if (!IsOfRankOrUnranked(op.grad(), 4)) {
    return op.emitOpError() << "requires grad to be rank 4";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Reduction indices must be defined by a constant operation.
  auto reduction_op =
      dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
  if (!reduction_op) return failure();

  auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
  if (!reductions_value) return failure();

  // Prepare new reduction indices according to the operand permutation.
  SmallVector<int32_t, 4> shuffled_reduction;
  llvm::transform(reductions_value.getValues<APInt>(),
                  std::back_inserter(shuffled_reduction),
                  [&](APInt idx) { return permutation[idx.getSExtValue()]; });

  // Add a constant operation with the new reduction indices.
  OpBuilder builder(getOperation());
  auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
                                          builder.getIntegerType(32));
  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
  auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);

  // Use the new reduction indices.
  setOperand(1, shuffled_reduction_op);

  return success();
}
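
// Illustrative example (not part of the original source): each constant
// reduction index `i` is remapped to `permutation[i]`, so if the folded
// permutation maps dimension 1 to 2 and dimension 2 to 3 (as when spatial
// dimensions move between layouts), constant reduction indices `[1, 2]` are
// rewritten to `[2, 3]` and stored in a fresh tf.Const operand.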

//===----------------------------------------------------------------------===//
// MulNoNanOp
//===----------------------------------------------------------------------===//

void MulNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DivNoNanOrMulNoNanConstantY<TF::MulNoNanOp, TF::MulOp>>(context);
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<MulOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// HashTableOp
//===----------------------------------------------------------------------===//
void HashTableOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<HashTableAndInitializeTableToV2>(context);
  results.add<HashTableAndLookupTableSizeToV2>(context);
  results.add<HashTableAndLookupTableFindToV2>(context);
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

LogicalResult BitcastOp::verify() {
  BitcastOp op = *this;
  auto input_type = op.input().getType().cast<ShapedType>();
  auto output_type = op.output().getType().cast<ShapedType>();
  auto input_element_type = input_type.getElementType();
  auto output_element_type = output_type.getElementType();

  // We only handle float and int element types in the verifier currently.
  // TODO(hanxiongwang): handle more element type checks besides int and float
  // in the verifier.
  if (input_type.hasStaticShape() && output_type.hasStaticShape() &&
      input_element_type.isIntOrFloat() && output_element_type.isIntOrFloat()) {
    const auto input_element_type_bitwidth =
        input_element_type.getIntOrFloatBitWidth();
    const auto output_element_type_bitwidth =
        output_element_type.getIntOrFloatBitWidth();

    auto is_output_shape_valid_with_small_input_element_type_bitwidth = [&]() {
      if (output_element_type_bitwidth % input_element_type_bitwidth != 0) {
        op.emitOpError() << "output element bitwidth is not multiple "
                         << "of input element bitwidth";
        return failure();
      }
      if (input_type.getShape().size() != output_type.getShape().size() + 1) {
        op.emitOpError() << "rank of input tensor is "
                         << input_type.getShape().size()
                         << ". rank of output tensor is expected to be "
                         << input_type.getShape().size() - 1 << ", instead of "
                         << output_type.getShape().size() << ".";
        return failure();
      }
      const auto rightmost_dim_size_divisor =
          output_element_type_bitwidth / input_element_type_bitwidth;
      if (input_type.getShape().empty() ||
          input_type.getShape().back() != rightmost_dim_size_divisor) {
        op.emitOpError()
            << "input rightmost dimension size is not equal to the divisor. "
            << "the last dimension of input is expected to be "
            << rightmost_dim_size_divisor;
        return failure();
      }
      for (auto idx = 0; idx < output_type.getShape().size(); idx++) {
        if (input_type.getShape()[idx] != output_type.getShape()[idx]) {
          op.emitOpError()
              << "the " << idx << "th dim of output tensor is "
              << output_type.getShape()[idx]
              << ". It is not equal to the one in input tensor, which is "
              << input_type.getShape()[idx];
          return failure();
        }
      }
      return success();
    };

    auto is_output_shape_valid_with_small_output_element_type_bitwidth =
        [&]() {
      if (input_element_type_bitwidth % output_element_type_bitwidth != 0) {
        op.emitOpError() << "input element bitwidth is not multiple "
                         << "of output element bitwidth";
        return failure();
      }
      if (input_type.getShape().size() + 1 != output_type.getShape().size()) {
        op.emitOpError() << "rank of input tensor is "
                         << input_type.getShape().size()
                         << ". rank of output tensor is expected to be "
                         << input_type.getShape().size() + 1 << ", instead of "
                         << output_type.getShape().size() << ".";
        return failure();
      }
      const auto rightmost_dim_size_divisor =
          input_element_type_bitwidth / output_element_type_bitwidth;
      if (output_type.getShape().back() != rightmost_dim_size_divisor) {
        op.emitOpError()
            << "output rightmost dimension size is not equal to the divisor. "
            << "the last dimension of output is expected to be "
            << rightmost_dim_size_divisor;
        return failure();
      }
      for (auto idx = 0; idx < input_type.getShape().size(); idx++) {
        if (input_type.getShape()[idx] != output_type.getShape()[idx]) {
          op.emitOpError()
              << "the " << idx << "th dim of output tensor is "
              << output_type.getShape()[idx]
              << ". It is not equal to the one in input tensor, which is "
              << input_type.getShape()[idx];
          return failure();
        }
      }
      return success();
    };

    auto is_output_shape_valid_with_equal_bitwidth = [&]() {
      if (input_type.getShape().equals(output_type.getShape())) {
        return success();
      }
      op.emitOpError()
          << "output tensor shape shall be equal to input tensor shape";
      return failure();
    };

    if (input_element_type_bitwidth < output_element_type_bitwidth) {
      return is_output_shape_valid_with_small_input_element_type_bitwidth();
    } else if (input_element_type_bitwidth > output_element_type_bitwidth) {
      return is_output_shape_valid_with_small_output_element_type_bitwidth();
    } else {
      return is_output_shape_valid_with_equal_bitwidth();
    }
  }
  return success();
}
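
// Illustrative example (not part of the original source): bitcasting
// tensor<4x4xi8> to an i32 output must produce tensor<4xi32> (the rank drops
// by one and the trailing input dimension must equal 32 / 8 = 4), while
// bitcasting tensor<4xi32> to an i8 output must produce tensor<4x4xi8>; for
// equal bitwidths the input and output shapes must match exactly.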

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc"