/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for legalizing HLO to TensorFlow.

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/broadcast_utils.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace mlir {
namespace TF {
namespace {

using mhlo::DotDimensionNumbersAttr;

// Replaces `region`'s terminator with TF::Yield.
void ReplaceReturnOp(Region &region, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);

  for (auto &block : region.getBlocks()) {
    Operation *terminator = block.getTerminator();
    auto return_op = llvm::dyn_cast_or_null<mhlo::ReturnOp>(terminator);
    if (return_op == nullptr) continue;

    rewriter.setInsertionPoint(return_op);
    rewriter.replaceOpWithNewOp<TF::YieldOp>(return_op,
                                             return_op->getOperands());
  }
}

// If `value` is a splat constant, returns success and sets `splat_value`
// to the splat constant value.
// `SplatValueType` can be `APInt` or `APFloat`.
template <typename SplatValueType>
LogicalResult GetConstantSplatValue(Value value, SplatValueType &splat_value) {
  DenseElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr)) || !attr.isSplat()) {
    return failure();
  }

  splat_value = attr.getSplatValue<SplatValueType>();
  return success();
}
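
// For illustration (values chosen here for exposition only): for a value
// defined by `mhlo.constant dense<5> : tensor<4xi32>`,
// GetConstantSplatValue<APInt> succeeds and sets `splat_value` to 5; it fails
// for non-constant or non-splat values.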

struct PermutationAndShape {
  DenseIntElementsAttr permutation;
  ShapedType shape;
};

// Returns a DenseIntElementsAttr for a permutation and the shape after
// applying the permutation to a given shape through a transpose.
PermutationAndShape GetPermutationAndTransposedShape(
    llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
    ConversionPatternRewriter &rewriter) {
  assert(permutation_array.size() == input_type.getRank());
  llvm::SmallVector<int64_t> transposed_shape(permutation_array.size());
  for (int64_t i = 0; i < permutation_array.size(); ++i) {
    transposed_shape[i] = input_type.getDimSize(permutation_array[i]);
  }
  auto transposed_type =
      RankedTensorType::get(transposed_shape, input_type.getElementType());
  DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
      RankedTensorType::get(permutation_array.size(), rewriter.getI64Type()),
      permutation_array);
  return {permutation, transposed_type};
}
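
// Worked example (illustrative shapes): with input_type tensor<2x3x4xf32> and
// permutation_array [0, 2, 1], the returned permutation attribute is
// dense<[0, 2, 1]> : tensor<3xi64> and the transposed shape is
// tensor<2x4x3xf32>, since transposed_shape[i] = input_shape[permutation[i]].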

// Returns the inverse permutation array for a permutation array.
llvm::SmallVector<int64_t> GetInversePermutationArray(
    llvm::ArrayRef<int64_t> permutation_array) {
  llvm::SmallVector<int64_t> inverse_permutation_array(
      permutation_array.size());
  const auto permutation_array_size = permutation_array.size();
  for (int64_t i = 0; i < permutation_array_size; ++i) {
    inverse_permutation_array[permutation_array[i]] = i;
  }
  return inverse_permutation_array;
}
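
// Worked example: for permutation_array [2, 0, 1] the inverse is [1, 2, 0],
// because inverse[permutation[i]] = i; composing the two permutations yields
// the identity.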

// Returns the DenseIntElementsAttr for an inverse permutation given a
// permutation_array.
DenseIntElementsAttr GetInversePermutation(
    llvm::ArrayRef<int64_t> permutation_array,
    ConversionPatternRewriter &rewriter) {
  SmallVector<int64_t, 4> inverse_permutation_array =
      GetInversePermutationArray(permutation_array);
  return DenseIntElementsAttr::get(
      RankedTensorType::get(inverse_permutation_array.size(),
                            rewriter.getI64Type()),
      inverse_permutation_array);
}

// Returns a DenseIntElementsAttr for an inverse permutation and the shape
// after applying the inverse permutation to a given shape through a transpose.
PermutationAndShape GetInversePermutationAndShape(
    llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
    ConversionPatternRewriter &rewriter) {
  SmallVector<int64_t, 4> inverse_permutation_array =
      GetInversePermutationArray(permutation_array);
  return GetPermutationAndTransposedShape(inverse_permutation_array,
                                          input_type, rewriter);
}

// Common functionality for ConvertConvOp classes.
template <int SupportedSpatialDims>
struct ConvertNdConvOp {
  bool IsSupportedConvOp(mhlo::ConvolutionOp conv_op) const {
    if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.getType().cast<ShapedType>().hasStaticShape())
      return false;

    // All ones in "lhs_dilation" means this "mhlo.conv" op should be
    // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp".
    if (conv_op.lhs_dilation().has_value()) {
      auto lhs_dilation = conv_op.lhs_dilation().getValue();
      if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue<int64_t>() != 1)
        return false;
    }

    if (!conv_op.window_strides().has_value() || conv_op.window_strides()
                                                         .getValue()
                                                         .getType()
                                                         .cast<ShapedType>()
                                                         .getRank() != 1)
      return false;

    auto num_spatial_dims =
        conv_op.dimension_numbers().getInputSpatialDimensions().size();
    // TODO(b/158636600): Currently we don't support 3D Convolution.
    if (num_spatial_dims != SupportedSpatialDims) return false;

    return true;
  }
};

// Convert a 1-D convolution into a 2-D convolution (which TF supports) so that
// it can be rewritten by the pattern `Convert2DConvOp`.
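//
// Illustrative shape flow (example dims, not requirements): an NWC input
// tensor<16x32x256xf32> with a WIO kernel tensor<3x256x256xf32> is reshaped to
// tensor<16x32x256x1xf32>, transposed to NWHC form tensor<16x32x1x256xf32>
// (the kernel likewise to WHIO), convolved as a 2-D NWHC * WHIO -> NWHC
// convolution, and the result is transposed and reshaped back to the original
// 1-D layout.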
class Convert1DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
                        ConvertNdConvOp<1> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    //
    // Check that input is a supported 1d convolution.
    //

    if (!IsSupportedConvOp(conv_op) || conv_op->getNumResults() != 1)
      return rewriter.notifyMatchFailure(conv_op, "unsupported conv op.");

    const mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();

    // Group convolution is not supported yet.
    const int64_t input_feature_dimension = dnums.getInputFeatureDimension();
    const int64_t input_channels =
        conv_op.lhs().getType().cast<ShapedType>().getDimSize(
            input_feature_dimension);
    const int64_t feature_group_count = conv_op.feature_group_count();
    if (feature_group_count != 1 && feature_group_count != input_channels)
      return rewriter.notifyMatchFailure(conv_op,
                                         "Group convolution is not supported.");

    //
    // Transpose and reshape the input and kernel
    //

    // Reshape input image to add a new spatial dimension.
    auto image_type = conv_op.lhs().getType().cast<ShapedType>();
    SmallVector<int64_t, 4> image_2d_shape(image_type.getShape().begin(),
                                           image_type.getShape().end());
    image_2d_shape.push_back(1);
    auto image_2d_type =
        RankedTensorType::get(image_2d_shape, image_type.getElementType());
    auto image_2d_op = rewriter.create<mhlo::ReshapeOp>(
        conv_op.getLoc(), image_2d_type, conv_op.lhs());

    // Transpose image to get it into NWHC form (where H is the added dim).
    SmallVector<int64_t, 4> image_permutation = {
        dnums.getInputBatchDimension(), dnums.getInputSpatialDimensions()[0],
        3,  // The trailing dim that we added.
        dnums.getInputFeatureDimension()};
    auto image_permutation_and_shape = GetPermutationAndTransposedShape(
        image_permutation, image_2d_type, rewriter);
    auto transposed_image_2d_op = rewriter.create<mhlo::TransposeOp>(
        conv_op.getLoc(), image_permutation_and_shape.shape,
        image_2d_op->getResult(0), image_permutation_and_shape.permutation);

    // Reshape kernel to add a new spatial dimension.
    auto kernel_type = conv_op.rhs().getType().cast<ShapedType>();
    SmallVector<int64_t, 4> kernel_2d_shape;
    for (int64_t dim : kernel_type.getShape()) {
      kernel_2d_shape.push_back(dim);
    }
    kernel_2d_shape.push_back(1);
    auto kernel_2d_type =
        RankedTensorType::get(kernel_2d_shape, kernel_type.getElementType());
    auto kernel_2d_op = rewriter.create<mhlo::ReshapeOp>(
        conv_op.getLoc(), kernel_2d_type, conv_op.rhs());

    // Transpose kernel to get it into WHIO form (where H is the added dim).
    SmallVector<int64_t, 4> kernel_permutation = {
        dnums.getKernelSpatialDimensions()[0],
        3,  // The trailing dim that we added.
        dnums.getKernelInputFeatureDimension(),
        dnums.getKernelOutputFeatureDimension()};
    auto kernel_permutation_and_shape = GetPermutationAndTransposedShape(
        kernel_permutation, kernel_2d_type, rewriter);
    auto transposed_kernel_2d_op = rewriter.create<mhlo::TransposeOp>(
        conv_op.getLoc(), kernel_permutation_and_shape.shape,
        kernel_2d_op->getResult(0), kernel_permutation_and_shape.permutation);

    //
    // Create 2d equivalents for 1d convolution attributes.
    //

    // Window Strides
    SmallVector<int64_t, 2> window_strides_2d_array;
    for (const auto v : conv_op.window_strides()->getValues<int64_t>()) {
      window_strides_2d_array.emplace_back(v);
    }
    window_strides_2d_array.push_back(1);
    auto window_strides_2d = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()),
        window_strides_2d_array);

    // Padding
    SmallVector<int64_t, 4> padding_2d_array;
    for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
      padding_2d_array.emplace_back(v);
    }
    // The newly added spatial dimension requires zero left and right padding.
    padding_2d_array.push_back(0);
    padding_2d_array.push_back(0);
    auto padding_2d = DenseIntElementsAttr::get(
        RankedTensorType::get({2, 2}, rewriter.getI64Type()), padding_2d_array);

    // LHS dilation
    SmallVector<int64_t, 4> lhs_dilation_array_2d;
    for (const auto v :
         conv_op.lhs_dilation().getValue().getValues<int64_t>()) {
      lhs_dilation_array_2d.emplace_back(v);
    }
    lhs_dilation_array_2d.push_back(1);
    auto lhs_dilation_2d = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()),
        lhs_dilation_array_2d);

    // RHS dilation
    SmallVector<int64_t, 4> rhs_dilation_array_2d;
    for (const auto v :
         conv_op.rhs_dilation().getValue().getValues<int64_t>()) {
      rhs_dilation_array_2d.emplace_back(v);
    }
    rhs_dilation_array_2d.push_back(1);
    auto rhs_dilation_2d = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()),
        rhs_dilation_array_2d);

    // Window reversal is unsupported.
    if (conv_op.window_reversal().has_value() &&
        conv_op.window_reversal()->getValues<bool>()[0] == true)
      return failure();
    auto window_reversal_2d = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()),
        SmallVector<int64_t>({0, 0}));

    // Precision config
    if (!conv_op.precision_config().has_value()) return failure();

    // Dimension numbers reflect the form of the 2d conv op NWHC * WHIO -> NWHC
    auto dnums_2d =
        mhlo::ConvDimensionNumbersAttr::get(rewriter.getContext(),
                                            /*inputBatchDimension=*/0,
                                            /*inputFeatureDimension=*/3,
                                            /*inputSpatialDimensions=*/{1, 2},
                                            /*kernelInputFeatureDimension=*/2,
                                            /*kernelOutputFeatureDimension=*/3,
                                            /*kernelSpatialDimensions=*/{0, 1},
                                            /*outputBatchDimension=*/0,
                                            /*outputFeatureDimension=*/3,
                                            /*outputSpatialDimensions=*/{1, 2});
    //
    // Generate a 2-D convolution
    //

    // Determine the 2-D convolution output shape.
    auto output_type = conv_op->getResult(0).getType().cast<ShapedType>();
    SmallVector<int64_t, 4> output_2d_shape;
    for (int64_t dim : output_type.getShape()) {
      output_2d_shape.push_back(dim);
    }
    output_2d_shape.push_back(1);
    auto output_2d_type =
        RankedTensorType::get(output_2d_shape, output_type.getElementType());
    SmallVector<int64_t, 4> output_permutation = {
        dnums.getOutputBatchDimension(), dnums.getOutputSpatialDimensions()[0],
        3,  // The trailing dim that we added.
        dnums.getOutputFeatureDimension()};
    auto transposed_output_2d_shape =
        GetPermutationAndTransposedShape(output_permutation, output_2d_type,
                                         rewriter)
            .shape;

    auto conv2d_op = rewriter.create<mhlo::ConvolutionOp>(
        conv_op.getLoc(), transposed_output_2d_shape,
        transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(),
        window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d,
        window_reversal_2d, dnums_2d, conv_op.feature_group_count(),
        conv_op.batch_group_count(), *conv_op.precision_config());

    OpResult conv2d_output = conv2d_op->getResult(0);
    auto conv2d_output_type = conv2d_output.getType().cast<ShapedType>();

    //
    // Transpose and reshape the output
    //

    // Since the output is in NWHC form, we need to undo the permutation we
    // have effectively applied.
    auto output_permutation_and_shape = GetInversePermutationAndShape(
        output_permutation, conv2d_output_type, rewriter);
    auto transposed_output_2d_op = rewriter.create<mhlo::TransposeOp>(
        conv_op.getLoc(), output_permutation_and_shape.shape, conv2d_output,
        output_permutation_and_shape.permutation);

    // Drop the trailing spatial dimension from the output.
    rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
        conv_op, output_type, transposed_output_2d_op.getResult());
    return success();
  }
};

class Convert2DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
                        ConvertNdConvOp<2> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (!IsSupportedConvOp(conv_op)) {
      return failure();
    }

    // Constructs strides array.
    // For example, [2, 3] -> [1, 2, 3, 1].
    SmallVector<int64_t, 4> strides({1});
    for (const auto v :
         conv_op.window_strides().getValue().getValues<int64_t>()) {
      strides.emplace_back(v);
    }
    strides.emplace_back(1);

    // Constructs dilation array.
    SmallVector<int64_t, 4> dilation;
    if (auto rhs_dilation = conv_op.rhs_dilation()) {
      // For example, [2, 3] -> [1, 2, 3, 1].
      dilation.emplace_back(1);
      dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
                      rhs_dilation.getValue().getValues<int64_t>().end());
      dilation.emplace_back(1);
    } else {
      // Default value
      dilation = {1, 1, 1, 1};
    }

    mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
    const int input_feature_dimension = dnums.getInputFeatureDimension();
    const int input_channels =
        conv_op.lhs().getType().cast<ShapedType>().getDimSize(
            input_feature_dimension);
    int feature_group_count = conv_op.feature_group_count();

    if (feature_group_count != 1 && feature_group_count != input_channels) {
      // Group convolution is not supported yet.
      return failure();
    }

    const int num_spatial_dims = dnums.getInputSpatialDimensions().size();
    const bool is_depthwise_conv = input_channels == feature_group_count;
    std::string padding;
    SmallVector<int64_t, 8> explicit_padding;
    if (!conv_op.padding().has_value() ||
        (conv_op.padding().getValue().isSplat() &&
         conv_op.padding()->getSplatValue<int64_t>() == 0)) {
      padding = "VALID";
    } else {
      SmallVector<int64_t, 4> padding_array;
      for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
        padding_array.emplace_back(v);
      }

      if (IsSamePadding(conv_op, num_spatial_dims, strides, dilation,
                        padding_array)) {
        // Check if padding is "SAME".
        padding = "SAME";
      } else {
        padding = "EXPLICIT";
        explicit_padding.push_back(0);
        explicit_padding.push_back(0);
        explicit_padding.append(padding_array);
        explicit_padding.push_back(0);
        explicit_padding.push_back(0);
      }
    }

    CreateConvOp(conv_op, strides, padding, explicit_padding, dilation,
                 is_depthwise_conv, input_channels, num_spatial_dims, rewriter);
    return success();
  };

 private:
  bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims,
                     ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
                     ArrayRef<int64_t> padding_array) const {
    mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
    auto input_spatial_dim = dnums.getInputSpatialDimensions();
    auto kernel_spatial_dim = dnums.getKernelSpatialDimensions();
    for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
      int dim = i + 1;
      int64_t output_size;
      int64_t pad_low_int64;
      int64_t pad_high_int64;
      tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
          conv_op.lhs().getType().cast<ShapedType>().getDimSize(
              input_spatial_dim[i]),
          conv_op.rhs().getType().cast<ShapedType>().getDimSize(
              kernel_spatial_dim[i]),
          dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size,
          &pad_low_int64, &pad_high_int64);
      if (!status.ok()) return false;
      if (padding_array[2 * i] != pad_low_int64 ||
          padding_array[2 * i + 1] != pad_high_int64)
        return false;
    }

    return true;
  }
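
  // Worked example (illustrative numbers): for input size 7, kernel size 3,
  // stride 2, and dilation 1, the SAME scheme gives output size
  // ceil(7 / 2) = 4 and total padding max((4 - 1) * 2 + 3 - 7, 0) = 2, i.e.
  // pad_low = 1 and pad_high = 1; an mhlo padding of [1, 1] on that dimension
  // therefore maps to "SAME".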

  // Returns true if the op needs reformat.
  bool NeedsReformatTypeAndPermutation(int batch_dim, int feature_dim,
                                       int spatial_dim_start,
                                       int default_batch_dim,
                                       int default_feature_dim,
                                       int default_spatial_dim_start) const {
    return batch_dim != default_batch_dim ||
           feature_dim != default_feature_dim ||
           spatial_dim_start != default_spatial_dim_start;
  }

  // Gets reformat type and permutation attribute. Call this function only if
  // NeedsReformatTypeAndPermutation returns true.
  std::pair<RankedTensorType, DenseIntElementsAttr>
  GetReformatTypeAndPermutation(int batch_dim, int feature_dim,
                                int spatial_dim_start, int default_batch_dim,
                                int default_feature_dim,
                                int default_spatial_dim_start,
                                int num_spatial_dims, RankedTensorType type,
                                ConversionPatternRewriter &rewriter) const {
    auto shape = type.getShape();
    llvm::SmallVector<int64_t, 4> permutation_array(num_spatial_dims + 2);
    permutation_array[default_batch_dim] = batch_dim;
    permutation_array[default_feature_dim] = feature_dim;
    llvm::SmallVector<int64_t, 4> transposed_shape(num_spatial_dims + 2);
    transposed_shape[default_batch_dim] = shape[batch_dim];
    transposed_shape[default_feature_dim] = shape[feature_dim];
    for (int i : llvm::seq<int>(0, num_spatial_dims)) {
      permutation_array[default_spatial_dim_start + i] = spatial_dim_start + i;
      transposed_shape[default_spatial_dim_start + i] =
          shape[spatial_dim_start + i];
    }
    auto new_type =
        RankedTensorType::get(transposed_shape, type.getElementType());
    auto permutation = DenseIntElementsAttr::get(
        RankedTensorType::get({type.getRank()}, rewriter.getI64Type()),
        permutation_array);
    return {new_type, permutation};
  }
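
  // Worked example (illustrative): converting an NCHW tensor<1x8x4x4xf32> to
  // the default NHWC layout (batch_dim=0, feature_dim=1, spatial_dim_start=2
  // versus defaults 0, 3, 1, with num_spatial_dims=2) yields the permutation
  // [0, 2, 3, 1] and the transposed type tensor<1x4x4x8xf32>.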

  Value FormatToNHWC(Value value, int batch_dim, int feature_dim,
                     ArrayRef<int64_t> spatial_dimensions,
                     int default_batch_dim, int default_feature_dim,
                     int default_spatial_dim_start, int num_spatial_dims,
                     ConversionPatternRewriter &rewriter) const {
    auto type = value.getType().cast<RankedTensorType>();
    DenseIntElementsAttr permutation;
    const int spatial_dim_start = spatial_dimensions.front();
    if (!NeedsReformatTypeAndPermutation(
            batch_dim, feature_dim, spatial_dim_start, default_batch_dim,
            default_feature_dim, default_spatial_dim_start)) {
      // Transpose is not needed because the current format is "NHWC".
      return value;
    }
    std::pair<RankedTensorType &, DenseIntElementsAttr &>(type, permutation) =
        GetReformatTypeAndPermutation(batch_dim, feature_dim, spatial_dim_start,
                                      default_batch_dim, default_feature_dim,
                                      default_spatial_dim_start,
                                      num_spatial_dims, type, rewriter);
    return rewriter.create<mhlo::TransposeOp>(value.getLoc(), type, value,
                                              permutation);
  }

  // Slices the input `value` if there are negative padding values in
  // `explicit_padding`.
  Value SliceNegativePadding(Value value, ArrayRef<int64_t> explicit_padding,
                             ConversionPatternRewriter &rewriter) const {
    // If no padding is negative return the input as is.
    if (llvm::all_of(explicit_padding, [](int64_t pad) { return pad >= 0; })) {
      return value;
    }

    auto input_type = value.getType().cast<RankedTensorType>();
    auto input_shape = input_type.getShape();

    llvm::SmallVector<int64_t, 4> start;
    llvm::SmallVector<int64_t, 4> size;
    start.reserve(explicit_padding.size() / 2);
    size.reserve(explicit_padding.size() / 2);
    for (int i = 0, e = explicit_padding.size() / 2; i < e; ++i) {
      int64_t pre_padding = explicit_padding[2 * i];
      int64_t post_padding = explicit_padding[2 * i + 1];
      int64_t pre_slice = pre_padding < 0 ? -pre_padding : 0;
      int64_t post_slice = post_padding < 0 ? -post_padding : 0;
      start.push_back(pre_slice);
      size.push_back(input_shape[i] - pre_slice - post_slice);
    }

    auto start_attr = rewriter.create<ConstOp>(
        value.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get({static_cast<int64_t>(start.size())},
                                  rewriter.getI64Type()),
            start));
    auto size_attr = rewriter.create<ConstOp>(
        value.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get({static_cast<int64_t>(size.size())},
                                  rewriter.getI64Type()),
            size));
    auto output_type = RankedTensorType::get(size, input_type.getElementType());

    return rewriter.create<SliceOp>(value.getLoc(), output_type, value,
                                    start_attr, size_attr);
  }
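
  // Worked example (illustrative): for a tensor<1x8x8x3xf32> input with
  // explicit_padding [0, 0, -1, 0, 0, 0, 0, 0] (NHWC lo/hi pairs), the
  // negative pad on the height dimension becomes a slice with start
  // [0, 1, 0, 0] and size [1, 7, 8, 3]; the caller then clamps the padding
  // vector to non-negative values.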

  void CreateConvOp(mhlo::ConvolutionOp conv_op, ArrayRef<int64_t> strides,
                    StringRef padding, ArrayRef<int64_t> explicit_padding,
                    ArrayRef<int64_t> dilation, bool is_depthwise_conv,
                    int input_channels, int num_spatial_dims,
                    ConversionPatternRewriter &rewriter) const {
    mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
    // Transposes lhs and rhs if their formats are not NHWC.
    Value lhs = FormatToNHWC(
        conv_op.lhs(), dnums.getInputBatchDimension(),
        dnums.getInputFeatureDimension(), dnums.getInputSpatialDimensions(),
        /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
        /*default_spatial_dim_start=*/1, num_spatial_dims, rewriter);
    Value rhs = FormatToNHWC(
        conv_op.rhs(), dnums.getKernelInputFeatureDimension(),
        dnums.getKernelOutputFeatureDimension(),
        dnums.getKernelSpatialDimensions(),
        /*default_batch_dim=*/num_spatial_dims,
        /*default_feature_dim=*/num_spatial_dims + 1,
        /*default_spatial_dim_start=*/0, num_spatial_dims, rewriter);

    // Emulate negative padding with a slice and remove negative values from
    // the padding vector.
    Value sliced_lhs = SliceNegativePadding(lhs, explicit_padding, rewriter);
    auto new_padding = llvm::to_vector<4>(llvm::map_range(
        explicit_padding, [](int64_t dim) { return dim > 0 ? dim : 0; }));

    auto conv_output_type = conv_op.getType().cast<RankedTensorType>();
    DenseIntElementsAttr permutation;
    const bool need_transpose_output = NeedsReformatTypeAndPermutation(
        dnums.getOutputBatchDimension(), dnums.getOutputFeatureDimension(),
        dnums.getOutputSpatialDimensions().front(),
        /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
        /*default_spatial_dim_start=*/1);
    if (need_transpose_output) {
      std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
                                                            permutation) =
          GetReformatTypeAndPermutation(
              dnums.getOutputBatchDimension(),
              dnums.getOutputFeatureDimension(),
              dnums.getOutputSpatialDimensions().front(),
              /*default_batch_dim=*/0,
              /*default_feature_dim=*/num_spatial_dims + 1,
              /*default_spatial_dim_start=*/1, num_spatial_dims,
              conv_output_type, rewriter);
    }
    Value output;
    if (is_depthwise_conv) {
      // Reshapes filter format to [filter_height, filter_width, in_channels,
      // channel_multiplier] from HLO's [filter_height, filter_width, 1,
      // in_channels * channel_multiplier] format.
      auto filter_type = rhs.getType().cast<ShapedType>();
      llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
      llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
                                                    hlo_filter_shape.end());
      tf_filter_shape[2] = input_channels;
      tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
      auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
          rhs.getLoc(),
          RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
          rhs);

      output = rewriter.create<DepthwiseConv2dNativeOp>(
          conv_op.getLoc(), conv_output_type, sliced_lhs, reshaped_filter,
          rewriter.getI64ArrayAttr(strides),
          /*padding=*/rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr(new_padding),
          /*data_format=*/rewriter.getStringAttr("NHWC"),
          /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    } else {
      output = rewriter.create<Conv2DOp>(
          conv_op.getLoc(), conv_output_type, sliced_lhs, rhs,
          rewriter.getI64ArrayAttr(strides),
          /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
          /*padding=*/rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr(new_padding),
          /*data_format=*/rewriter.getStringAttr("NHWC"),
          /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    }

    if (need_transpose_output) {
      // Converts from "NHWC" format back to the original output format.
      std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
                                                            permutation) =
          GetReformatTypeAndPermutation(
              /*batch_dim=*/0, /*feature_dim=*/num_spatial_dims + 1,
              /*spatial_dim_start=*/1, dnums.getOutputBatchDimension(),
              dnums.getOutputFeatureDimension(),
              *dnums.getOutputSpatialDimensions().begin(), num_spatial_dims,
              conv_output_type, rewriter);
      output = rewriter.create<mhlo::TransposeOp>(
          conv_op.getLoc(), conv_op.getType(), output, permutation);
    }
    rewriter.replaceOp(conv_op, {output});
  }
};

class ConvertNonTrivialConvOp
    : public OpConversionPattern<mhlo::ConvolutionOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (IsSupportedConvOp(conv_op, rewriter).failed()) {
      return rewriter.notifyMatchFailure(
          conv_op,
          "cannot be converted to ConvBackpropInputOp or "
          "ResizeBilinearOp");
    }

    // tf.ResizeBilinearOp is preferred over tf.Conv2DBackpropInputOp since
    // the former has better portability, especially in inference use cases.
    bool align_corners;
    llvm::SmallVector<int, 2> output_sizes;
    if (MatchResizeOp(conv_op, align_corners, output_sizes, rewriter)
            .succeeded()) {
      CreateResizeBilinearOp(conv_op, output_sizes, align_corners, rewriter);
      return success();
    }

    // Constructs strides array from lhs_dilation.
    // For example, [2, 3] -> [1, 2, 3, 1].
    SmallVector<int64_t, 4> strides({1});
    strides.append(
        conv_op.lhs_dilation().getValue().getValues<int64_t>().begin(),
        conv_op.lhs_dilation().getValue().getValues<int64_t>().end());
    strides.emplace_back(1);

    // Constructs dilation array.
    SmallVector<int64_t, 4> dilation;
    if (auto rhs_dilation = conv_op.rhs_dilation()) {
      // For example, [2, 3] -> [1, 2, 3, 1].
      dilation.emplace_back(1);
      dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
                      rhs_dilation.getValue().getValues<int64_t>().end());
      dilation.emplace_back(1);
    } else {
      // Default value
      dilation = {1, 1, 1, 1};
    }

    mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
    std::string padding;
    if (!conv_op.padding().has_value() ||
        (conv_op.padding().getValue().isSplat() &&
         conv_op.padding()->getSplatValue<int64_t>() == 0)) {
      padding = "VALID";
    } else {
      auto spatial_dims = dnums.getInputSpatialDimensions();
      int num_spatial_dims =
          std::accumulate(spatial_dims.begin(), spatial_dims.end(), 1LL,
                          std::multiplies<int64_t>{});
      if (!IsSamePadding(conv_op, num_spatial_dims, strides)) {
        return rewriter.notifyMatchFailure(
            conv_op, "requires padding to be SAME or VALID");
      }
      padding = "SAME";
    }

    // Converts int64_t to int32_t.
    llvm::SmallVector<int, 4> input_shape;
    for (int64_t dim : conv_op.getType().cast<RankedTensorType>().getShape()) {
      input_shape.push_back(dim);
    }
    auto input_sizes = rewriter.create<ConstOp>(
        conv_op.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get({static_cast<int64_t>(input_shape.size())},
                                  rewriter.getI32Type()),
            input_shape));
    // Mirror the filter in the spatial dimensions.
    auto filter = rewriter.create<mhlo::ReverseOp>(
        conv_op.getLoc(), conv_op.rhs(),
        rewriter.getI64TensorAttr(dnums.getKernelSpatialDimensions()));
    rewriter.replaceOpWithNewOp<Conv2DBackpropInputOp>(
        conv_op, conv_op.getType(), input_sizes, filter, conv_op.lhs(),
        rewriter.getI64ArrayAttr(strides),
        /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
        /*padding=*/rewriter.getStringAttr(padding),
        /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
        /*data_format=*/rewriter.getStringAttr("NHWC"),
        /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    return success();
  };

 private:
  bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims,
                     ArrayRef<int64_t> strides) const {
    for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
      int dim = i + 1;
      int stride = strides[dim];
      int input_size = conv_op.getType().cast<ShapedType>().getDimSize(dim);
      int output_size =
          conv_op.lhs().getType().cast<ShapedType>().getDimSize(dim);
      if (output_size != (input_size + stride - 1) / stride) {
        return false;
      }
    }

    return true;
  }
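
  // Worked example (illustrative): for a transposed convolution whose result
  // has spatial size 16 with lhs_dilation (stride) 2, the lhs spatial size
  // must equal ceil(16 / 2) = 8 for the padding to qualify as "SAME".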

  LogicalResult IsSupportedConvOp(mhlo::ConvolutionOp conv_op,
                                  ConversionPatternRewriter &rewriter) const {
    if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.getType().cast<ShapedType>().hasStaticShape())
      return rewriter.notifyMatchFailure(conv_op, "requires static shape");
    mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
    const int input_feature_dimension = dnums.getInputFeatureDimension();
    const int input_channels =
        conv_op.lhs().getType().cast<ShapedType>().getDimSize(
            input_feature_dimension);
    int feature_group_count = conv_op.feature_group_count();

    if (feature_group_count != 1 && feature_group_count != input_channels) {
      // Group convolution is not supported yet.
      return rewriter.notifyMatchFailure(conv_op,
                                         "doesn't support group convolution");
    }

    // Checks that lhs_dilation is non-trivial.
    if (!conv_op.lhs_dilation().has_value()) {
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires lhs_dilation attribute");
    }
    auto lhs_dilation = conv_op.lhs_dilation().getValue();
    if (lhs_dilation.isSplat() && lhs_dilation.getSplatValue<int64_t>() == 1)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires non-trivial lhs_dilation");

    if (!conv_op.window_strides().has_value() || conv_op.window_strides()
                                                         .getValue()
                                                         .getType()
                                                         .cast<ShapedType>()
                                                         .getRank() != 1)
      return rewriter.notifyMatchFailure(
          conv_op, "requires window_strides to be a 1-D tensor");

    int num_spatial_dims = dnums.getInputSpatialDimensions().size();
    // TODO(chhe): Currently we don't support 3D Convolution.
    if (num_spatial_dims != 2)
      return rewriter.notifyMatchFailure(conv_op,
                                         "doesn't support more than 2D");

    // TODO(chhe): Support data formats other than "NHWC".
    // Checks format [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f].
    if (dnums.getInputBatchDimension() != 0 ||
        dnums.getInputFeatureDimension() != num_spatial_dims + 1)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires input format [b, 0, 1, f]");
    auto input_spatial_dimensions = dnums.getInputSpatialDimensions();
    for (auto p : llvm::enumerate(input_spatial_dimensions)) {
      if (p.value() != p.index() + 1)
        return rewriter.notifyMatchFailure(
            conv_op, "requires input format [b, 0, 1, f]");
    }

    // Checks output dimensions.
    if (dnums.getOutputBatchDimension() != 0 ||
        conv_op.dimension_numbers().getOutputFeatureDimension() !=
            num_spatial_dims + 1)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires output format [b, 0, 1, f]");
    auto output_spatial_dimensions = dnums.getOutputSpatialDimensions();
    for (auto p : llvm::enumerate(output_spatial_dimensions)) {
      if (p.value() != p.index() + 1)
        return rewriter.notifyMatchFailure(
            conv_op, "requires output format [b, 0, 1, f]");
    }

    // Checks kernel dimensions.
    if (dnums.getKernelInputFeatureDimension() != num_spatial_dims + 1 ||
        dnums.getKernelOutputFeatureDimension() != num_spatial_dims)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires kernel format [0, 1, o, i]");
    auto kernel_spatial_dimensions = dnums.getKernelSpatialDimensions();
    for (auto p : llvm::enumerate(kernel_spatial_dimensions)) {
      if (p.value() != p.index())
        return rewriter.notifyMatchFailure(
            conv_op, "requires kernel format [0, 1, o, i]");
    }

    return success();
  }

  void CreateResizeBilinearOp(mhlo::ConvolutionOp conv_op,
                              llvm::ArrayRef<int32_t> output_sizes,
                              bool align_corners,
                              ConversionPatternRewriter &rewriter) const {
    Value output_sizes_attr = rewriter.create<ConstOp>(
        conv_op.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get({static_cast<int64_t>(output_sizes.size())},
                                  rewriter.getI32Type()),
            output_sizes));
    // The value of half_pixel_centers cannot be inferred from the IR, and XLA
    // only supports half_pixel_centers=True as of 01/11/2022. Here
    // half_pixel_centers=False is hardcoded.
    Value output = rewriter.create<ResizeBilinearOp>(
        conv_op.getLoc(), conv_op.getType(), conv_op.lhs(), output_sizes_attr,
        /*align_corners=*/rewriter.getBoolAttr(align_corners),
        /*half_pixel_centers=*/rewriter.getBoolAttr(false));
    rewriter.replaceOp(conv_op, {output});
  }

  LogicalResult MatchResizeOp(mhlo::ConvolutionOp conv_op, bool &align_corners,
                              llvm::SmallVector<int, 2> &output_sizes,
                              ConversionPatternRewriter &rewriter) const {
    mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
    auto input_spatial_dimensions = dnums.getInputSpatialDimensions();
    auto kernel_spatial_dimensions = dnums.getKernelSpatialDimensions();
    auto output_spatial_dimensions = dnums.getOutputSpatialDimensions();
    if (input_spatial_dimensions.size() != 2 ||
        output_spatial_dimensions.size() != 2 ||
        kernel_spatial_dimensions.size() != 2 ||
        input_spatial_dimensions[0] != output_spatial_dimensions[0] ||
        input_spatial_dimensions[1] != output_spatial_dimensions[1])
      return rewriter.notifyMatchFailure(
          conv_op, "can only be converted to 2D resize op");

    // When "lhs_dilation" is 2-D and contains at least one "1", and all
    // "rhs_dilation" values are "1"s, this "mhlo.conv" op can potentially be
    // converted to "tf.ResizeBilinearOp".
    if (!conv_op.rhs_dilation().has_value() || !conv_op.padding().has_value())
      return rewriter.notifyMatchFailure(
          conv_op, "resize op requires rhs_dilation and padding");

    auto lhs_dilation = conv_op.lhs_dilation().getValue();
    auto rhs_dilation = conv_op.rhs_dilation().getValue();
    auto window_strides = conv_op.window_strides().getValue();
    auto padding = conv_op.padding().getValue();
    if (lhs_dilation.getNumElements() != 2 || !rhs_dilation.isSplat() ||
        rhs_dilation.getSplatValue<int64_t>() != 1 ||
        window_strides.getNumElements() != 2 || padding.getNumElements() != 4)
      return rewriter.notifyMatchFailure(
          conv_op, "resize op requires [2] dilations and [2,2] padding");
    auto lhs_dilation_values = lhs_dilation.getValues<int64_t>();
    auto window_strides_values = window_strides.getValues<int64_t>();
    auto padding_values = padding.getValues<int64_t>();

    // Cast the dimension sizes to int.
    auto lhs_type = conv_op.lhs().getType().cast<ShapedType>();
    llvm::SmallVector<int> input_sizes = {
        static_cast<int>(lhs_type.getDimSize(input_spatial_dimensions[0])),
        static_cast<int>(lhs_type.getDimSize(input_spatial_dimensions[1]))};
    output_sizes = {static_cast<int>(conv_op.getType().getDimSize(
                        output_spatial_dimensions[0])),
                    static_cast<int>(conv_op.getType().getDimSize(
                        output_spatial_dimensions[1]))};

    // This is based on the method in
    // compiler/tf2xla/kernels/image_resize_ops.cc.
    auto can_convert_to_bilinear = [](bool align_corners, int64_t dilation,
                                      int64_t padding, int64_t stride,
                                      int64_t input_spatial,
                                      int64_t output_spatial) {
      int64_t input_spatial_size =
          align_corners ? input_spatial - 1 : input_spatial;
      int64_t output_spatial_size =
          align_corners ? output_spatial - 1 : output_spatial;

      int64_t gcd =
          tensorflow::MathUtil::GCD(static_cast<uint64_t>(input_spatial_size),
                                    static_cast<uint64_t>(output_spatial_size));
      if ((input_spatial_size % gcd != 0) ||
          (input_spatial_size / gcd != stride) || (dilation - 1 != padding)) {
        return false;
      }

      return true;
    };
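
    // Worked example (illustrative): a 2x bilinear upscale from spatial size
    // 8 to 16 with align_corners=false gives gcd(8, 16) = 8, so the match
    // requires window stride 8 / 8 = 1 and lhs_dilation 2 with padding
    // 2 - 1 = 1 on that dimension.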

    // Exactly one of the lhs_dilation values must be 1; the non-1 dimension
    // is the resize dimension.
    if (lhs_dilation_values[0] != 1 && lhs_dilation_values[1] == 1) {
      if (can_convert_to_bilinear(
              /*align_corners=*/true, lhs_dilation_values[0], padding_values[0],
              window_strides_values[0], input_sizes[0], output_sizes[0])) {
        align_corners = true;
        return success();
      }
      if (can_convert_to_bilinear(
              /*align_corners=*/false, lhs_dilation_values[0],
              padding_values[0], window_strides_values[0], input_sizes[0],
              output_sizes[0])) {
        align_corners = false;
        return success();
      }
    }

    if (lhs_dilation_values[0] == 1 && lhs_dilation_values[1] != 1) {
      if (can_convert_to_bilinear(
              /*align_corners=*/true, lhs_dilation_values[1], padding_values[2],
              window_strides_values[1], input_sizes[1], output_sizes[1])) {
        align_corners = true;
        return success();
      }
      if (can_convert_to_bilinear(
              /*align_corners=*/false, lhs_dilation_values[1],
              padding_values[2], window_strides_values[1], input_sizes[1],
              output_sizes[1])) {
        align_corners = false;
        return success();
      }
    }

    return rewriter.notifyMatchFailure(conv_op,
                                       "cannot be converted to resize op");
  }
};

class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::SliceOp slice_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    auto begin =
        rewriter.create<ConstOp>(slice_op.getLoc(), slice_op.start_indices());
    auto end =
        rewriter.create<ConstOp>(slice_op.getLoc(), slice_op.limit_indices());
    auto strides =
        rewriter.create<ConstOp>(slice_op.getLoc(), slice_op.strides());
    rewriter.replaceOpWithNewOp<StridedSliceOp>(
        slice_op, slice_op.getType(), slice_op.operand(), begin, end, strides);
    return success();
  }
};

class ConvertDynamicSliceOp : public OpConversionPattern<mhlo::DynamicSliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicSliceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    ShapedType input_type = op.operand().getType().cast<ShapedType>();
    if (!input_type.hasStaticShape()) return failure();
    Type start_indices_element_type = op.start_indices()
                                          .front()
                                          .getType()
                                          .cast<ShapedType>()
                                          .getElementType();

    // The mhlo dynamic_slice's start_indices can be either signed/unsigned
    // int32/int64. However, TF only takes in either i32 or i64 types for
    // begin, so we will always put a cast.
    Type signed_start_indices_element_type;
    if (start_indices_element_type.isInteger(32)) {
      signed_start_indices_element_type = rewriter.getI32Type();
    } else {
      signed_start_indices_element_type = rewriter.getI64Type();
    }

    // Clamp indices to [0, input_size - output_size]
    llvm::SmallVector<Value, 4> start_indices_vector;
    start_indices_vector.reserve(op.start_indices().size());
    Value clamp_min = rewriter.create<ConstOp>(
        op.getLoc(),
        rewriter.getIntegerAttr(signed_start_indices_element_type, 0));
    for (uint64_t i = 0, e = op.start_indices().size(); i < e; ++i) {
      // Always put a cast there.
      auto start = op.start_indices()[i];
      auto cast_type = start.getType().cast<ShapedType>().clone(
          signed_start_indices_element_type);
      auto cast_op = rewriter.create<CastOp>(op.getLoc(), cast_type, start);
      Value clamp_max = rewriter.create<ConstOp>(
          op.getLoc(), rewriter.getIntegerAttr(
                           signed_start_indices_element_type,
                           input_type.getShape()[i] -
                               op.slice_sizes().getValues<int64_t>()[i]));
      Value clamped_index = rewriter.create<mhlo::ClampOp>(
          op.getLoc(), cast_type, clamp_min, cast_op, clamp_max);
      start_indices_vector.push_back(clamped_index);
    }

    // Pack individual start indices to start indices tensor.
    Type start_indices_type = RankedTensorType::get(
        {static_cast<int64_t>(start_indices_vector.size())},
        signed_start_indices_element_type);
    Value start_indices_op = rewriter.create<PackOp>(
        op.getLoc(), start_indices_type, ValueRange(start_indices_vector));

    Value slice_sizes_op =
        rewriter.create<ConstOp>(op.getLoc(), op.slice_sizes());
    rewriter.replaceOpWithNewOp<SliceOp>(op, op.getType(), op.operand(),
                                         start_indices_op, slice_sizes_op);
    return success();
  };
};

// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
  values.insert(values.end(), range.begin(), range.end());
}

// Appends all elements in `range` and `ranges` to `values`.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
            RangeTs &&...ranges) {
  values.insert(values.end(), range.begin(), range.end());
  Append(values, ranges...);
}

// Returns the number of elements in `range`.
template <typename Range>
size_t Size(Range &&range) {
  return range.size();
}

// Returns the total number of elements in a variadic number of `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&...ranges) {
  return range.size() + Size(std::forward<RangeTs>(ranges)...);
}

// Concats all elements in `ranges` and returns a small vector as a result.
template <typename ValueT, typename... RangeTs>
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
  llvm::SmallVector<ValueT, 4> results;
  results.reserve(Size(std::forward<RangeTs>(ranges)...));
  Append(results, std::forward<RangeTs>(ranges)...);
  return results;
}
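
// Usage sketch (illustrative names): Concat<int64_t>(batch_dims, spatial_dims)
// with batch_dims = [0] and spatial_dims = [1, 2] returns [0, 1, 2], reserving
// the full capacity up front via Size(...) to avoid reallocation.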

// A struct to hold axes and sizes for a set of dimensions.
struct DimensionVector {
  llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
  llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }

  llvm::SmallVector<int64_t, 4> axes;
  llvm::SmallVector<int64_t, 4> sizes;
};

// Create a single const integer.
Value BuildIntConstOp(ImplicitLocOpBuilder &builder,
                      ConversionPatternRewriter &rewriter, int64_t const_value,
                      Type type) {
  Value result_const =
      builder.create<ConstOp>(rewriter.getIntegerAttr(type, const_value));
  return result_const;
}

// Create a const integer vector tensor (1-dim).
Value BuildIntArrayConstOp(ImplicitLocOpBuilder &builder,
                           ConversionPatternRewriter &rewriter,
                           ArrayRef<int64_t> const_value, Type type) {
  DenseIntElementsAttr const_value_raw;
  if (type == rewriter.getI64Type()) {
    const_value_raw = rewriter.getI64TensorAttr(const_value);
  } else {
    // Convert I64 const array to I32.
    llvm::SmallVector<int32_t> const_i32_vec;
    for (auto element : const_value) {
      const_i32_vec.push_back(static_cast<int32_t>(element));
    }
    const_value_raw = rewriter.getI32TensorAttr(const_i32_vec);
  }
  Value result_const = builder.create<ConstOp>(const_value_raw);
  return result_const;
}

// Create a tensor that is reshaped from input.
Value BuildReshapeOp(ImplicitLocOpBuilder &builder,
                     ConversionPatternRewriter &rewriter, Value input,
                     ArrayRef<int64_t> shape, Type idx_type,
                     Type element_type) {
  Value shape_cst = BuildIntArrayConstOp(builder, rewriter, shape, idx_type);
  Value reshaped_input = builder.create<ReshapeOp>(
      RankedTensorType::get(shape, element_type), input, shape_cst);
  return reshaped_input;
}

// Create a tensor which is equal to input[begin: begin + size].
Value BuildSliceOp(ImplicitLocOpBuilder &builder,
                   ConversionPatternRewriter &rewriter, Value input,
                   Value begin, ArrayRef<int64_t> shape, Type idx_type,
                   Type element_type) {
  Value shape_cst = BuildIntArrayConstOp(builder, rewriter, shape, idx_type);
  Value slice_result = builder.create<SliceOp>(
      RankedTensorType::get(shape, element_type), input, begin, shape_cst);
  return slice_result;
}

class ConvertDynamicUpdateSliceOp
    : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    ShapedType operand_type = op.operand().getType().cast<ShapedType>();
    ShapedType update_type =
        op.update().getType().dyn_cast_or_null<ShapedType>();
    ShapedType start_indices_type =
        op.start_indices().front().getType().dyn_cast_or_null<ShapedType>();
    if (update_type == nullptr || start_indices_type == nullptr)
      return rewriter.notifyMatchFailure(
          op, "update and start_indices should have ShapedType");
    if (!operand_type.hasStaticShape() || !update_type.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "shape of operand and update should be static");

    Type idx_type = start_indices_type.getElementType();
    int64_t shape_dim = operand_type.getRank();
    auto operand_shape = operand_type.getShape();
    auto update_shape = update_type.getShape();

    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
    Value zero_cst = BuildIntConstOp(builder, rewriter, 0, idx_type);
    Value one_cst = BuildIntConstOp(builder, rewriter, 1, idx_type);
    // Clamp start indices in [0, operand_size - update_size].
    llvm::SmallVector<Value> start_indices_vector;
    Append(start_indices_vector, op.start_indices());
    auto shape_tensor_type = RankedTensorType::get({shape_dim}, idx_type);
    Value start_indices_tensor =
        builder.create<PackOp>(shape_tensor_type, start_indices_vector);
    Value operand_shape_cst =
        BuildIntArrayConstOp(builder, rewriter, operand_shape, idx_type);
    Value update_shape_cst =
        BuildIntArrayConstOp(builder, rewriter, update_shape, idx_type);
    Value max_start_indices =
        builder.create<SubOp>(operand_shape_cst, update_shape_cst);
    Value start_indices_clip_max =
        builder.create<MinimumOp>(start_indices_tensor, max_start_indices);
    Value clamped_start_indices =
        builder.create<MaximumOp>(start_indices_clip_max, zero_cst);

    // Do dynamic_update_slice on the flattened operand and update, with the
    // aid of the tf.TensorScatterUpdate op. It takes three parameters:
    // flat_operand, indices, and flat_update. The indices are computed as
    // follows:
    // 1. Construct a range (0, n_operand). It assigns an id number to each
    //    element position in the operand.
    // 2. Reshape the range to the shape of the operand.
    // 3. Compute the id numbers of the update positions by taking a slice
    //    from clamped_start_indices to clamped_start_indices + update_size.
    // 4. Flatten the update id numbers to obtain the indices.
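    //
    // Worked example (illustrative): for a 2x3 operand and a 1x2 update with
    // start indices (1, 1), the range reshapes to [[0, 1, 2], [3, 4, 5]], the
    // slice at (1, 1) of size 1x2 yields [[4, 5]], and the flattened indices
    // [[4], [5]] tell tf.TensorScatterUpdate which positions of the flattened
    // operand to overwrite.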
    int64_t n_operand = operand_type.getNumElements();
    Value n_operand_cst =
        BuildIntConstOp(builder, rewriter, n_operand, idx_type);
    Value range_flat =
        builder.create<RangeOp>(zero_cst, n_operand_cst, one_cst);
    Value range = BuildReshapeOp(builder, rewriter, range_flat, operand_shape,
                                 idx_type, idx_type);
    Value update_indices_raw =
        BuildSliceOp(builder, rewriter, range, clamped_start_indices,
                     update_shape, idx_type, idx_type);
    int64_t n_update = update_type.getNumElements();
    Type element_type = operand_type.getElementType();
    Value update_indices = BuildReshapeOp(builder, rewriter, update_indices_raw,
                                          {n_update, 1}, idx_type, idx_type);
    Value operand_flat = BuildReshapeOp(builder, rewriter, op.operand(),
                                        {n_operand}, idx_type, element_type);
    Value update_flat = BuildReshapeOp(builder, rewriter, op.update(),
                                       {n_update}, idx_type, element_type);
    Value flat_result = builder.create<TensorScatterUpdateOp>(
        operand_flat, update_indices, update_flat);

    // Reshape back before return.
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, operand_type, flat_result,
                                           operand_shape_cst);
    return success();
  };
};

// Returns "true" when Value $iota is obtained from the following mlir code:
//
// $iota = "mhlo.iota"(){iota_dimension = $dimensions[0]},
//
// where $dimensions must have size 1 and iota can have rank>=1.
// It is usually used for matching a rank-1 iota, since the IotaOp will be
// folded to IotaOp + BroadcastInDimOp except when the result shape is rank 1.
bool MatchSingleIota(DenseIntElementsAttr dimensions, Value iota) {
  auto iota_op = dyn_cast_or_null<mhlo::IotaOp>(iota.getDefiningOp());
  if (!iota_op || dimensions.getNumElements() != 1) return false;
  auto dim = *dimensions.value_begin<APInt>();
  return dim == iota_op.iota_dimension();
}

// It matches %iota generated from the following mlir code:
//
// %iota_r1 = "mhlo.iota"(){iota_dimension = 0} : () -> tensor<Lxi32>
// %iota = "mhlo.broadcast_in_dim"(%iota_r1){
//     broadcast_dimensions = dense<[$dimensions[0]]>},
//
// where $dimensions is of size 1. It usually comes from an IotaOp that is
// folded to IotaOp (rank 1) + BroadCastInDimOp.
bool MatchIotaBroadCastInDim(DenseIntElementsAttr dimensions, Value iota) {
  auto iota_broadcast =
      dyn_cast_or_null<mhlo::BroadcastInDimOp>(iota.getDefiningOp());
  if (!iota_broadcast || iota_broadcast.broadcast_dimensions() != dimensions)
    return false;
  if (!isa_and_nonnull<mhlo::IotaOp>(iota_broadcast.operand().getDefiningOp()))
    return false;
  return true;
}

// Matches %iota generated from the following code (rank 3 example):
//
// %iota_r1 = "mhlo.iota"(){iota_dimension = 0 : i32} : () -> tensor<44xi32>
// %iota = "mhlo.reshape"(%iota_r1): (tensor<44xi32>) -> tensor<1x1x44xi32>
//
// where $dimensions is of size 1 and $dimensions[0] = 2.
//
// In general, matches a 1-D iota with multiple dimensions of size 1 added
// through a reshape.
bool MatchReshapedIota(DenseIntElementsAttr dimensions, Value iota) {
  if (dimensions.getNumElements() != 1) return false;
  auto reshape_op = dyn_cast_or_null<mhlo::ReshapeOp>(iota.getDefiningOp());
  if (!reshape_op) return false;
  auto operand_type =
      reshape_op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operand_type || !operand_type.hasStaticShape()) return false;
  auto reshape_type = reshape_op.getType().cast<RankedTensorType>();

  // Reshape can take a 1-D iota input and add extra dims of size one.
  if (operand_type.getRank() != 1) return false;
  if (!dyn_cast_or_null<mhlo::IotaOp>(reshape_op.operand().getDefiningOp()))
    return false;

  int64_t iota_dim = (*dimensions.value_begin<APInt>()).getSExtValue();
  for (int64_t i = 0, e = reshape_type.getRank(); i < e; ++i) {
    if (i == iota_dim) {
      if (reshape_type.getDimSize(i) != operand_type.getDimSize(0))
        return false;
    } else if (reshape_type.getDimSize(i) != 1) {
      return false;
    }
  }
  return true;
}

// It matches %iota generated from the following mlir code:
//
// %iota_r1 = mhlo.constant dense<[0, 1, 2, ..., L]>
// %iota = "mhlo.broadcast_in_dim"(%iota_r1){
//     broadcast_dimensions = dense<[$dimensions[0]]>},
//
// where $dimensions is of size 1. It usually comes from an IotaOp that is
// folded to ConstOp (folded rank 1 iota) + BroadCastInDimOp.
bool MatchConstIotaBroadCastInDim(DenseIntElementsAttr dimensions, Value iota) {
  if (dimensions.getNumElements() != 1) return false;
  auto iota_broadcast =
      dyn_cast_or_null<mhlo::BroadcastInDimOp>(iota.getDefiningOp());
  if (!iota_broadcast || iota_broadcast.broadcast_dimensions() != dimensions)
    return false;
  DenseElementsAttr range_const;
  if (!matchPattern(iota_broadcast.operand(), m_Constant(&range_const)))
    return false;
  int index = 0;
  for (auto value : range_const.getValues<APInt>()) {
    if (value != index++) return false;
  }
  return true;
}

// Facilitates access to 1-d backing data for a tensor so that values in a 1-d
// slice of the tensor can be accessed as if part of an ArrayView.
class StridedArrayViewBase {
 protected:
  StridedArrayViewBase(ArrayRef<int64_t> shape, ArrayRef<int64_t> index,
                       int64_t axis) {
    assert(shape.size() == index.size());
    assert(axis < shape.size());
    assert(axis >= 0);
    assert(index[axis] == 0);
    offset_ = IndexToOffset(shape, index);
    stride_ = StrideForAxis(shape, axis);
    size_ = shape[axis];
  }

  // Returns the size of the 1-d slice across the tensor.
  int64_t size() const { return size_; }

  // Calculates the next index in a tensor excluding a specified axis.
  //
  // Returns the next index where one exists.
  // If there is no valid next index, returns `llvm::None`.
  //
  // `index` should have the same size as `shape`.
  // Each value `dim` in `index` should be in [0, shape[dim]).
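  // For example, for shape [2, 3] with fixed_axis 1, successive calls starting
  // from [0, 0] yield [1, 0] and then llvm::None.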
  static llvm::Optional<SmallVector<int64_t>> NextTensorIndex(
      SmallVector<int64_t> index, ArrayRef<int64_t> shape, int64_t fixed_axis) {
#ifndef NDEBUG
    assert(shape.size() == index.size());
    assert(fixed_axis < shape.size());
    assert(fixed_axis >= 0);
    assert(index[fixed_axis] == 0);
    for (size_t i = 0; i < shape.size(); ++i) {
      assert(index[i] < shape[i]);
      assert(index[i] >= 0);
    }
#endif  // NDEBUG
    for (int64_t dim = shape.size() - 1; dim >= 0; --dim) {
      if (dim == fixed_axis) continue;
      ++index[dim];
      if (index[dim] < shape[dim]) return std::move(index);
      index[dim] = 0;
    }
    return llvm::None;
  }

 protected:
  // Calculates how many values to skip across a 1-D contiguous array that
  // holds backing data for a given shape to access the value at a given index
  // along a StridedArrayView across a higher dimensional shape.
  //
  // The index `i` must be in [0, shape[axis]) where `shape` is the shape of
  // the tensor and `axis` is the axis along the tensor that the
  // StridedArrayView indexes along.
  int64_t OffsetForIndex(int64_t i) const { return offset_ + i * stride_; }

 private:
  // Calculates how many values to skip across a 1-D contiguous array that
  // holds backing data for a given shape to access the next value along a
  // given axis.
  //
  // `axis` should be a valid dimension in `shape`.
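  // For example, for shape [2, 3, 4] and axis 1, the stride is 4: skipping
  // four contiguous values advances the index along axis 1 by one.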
  static int64_t StrideForAxis(ArrayRef<int64_t> shape, int64_t axis) {
    int64_t stride = 1;  // Start with the trailing dimension.
    for (int64_t dim = shape.size() - 1; dim > axis; --dim) {
      stride *= shape[dim];
    }
    return stride;
  }

  // Calculates how many values to skip across a 1-D contiguous array that
  // holds backing data for a given shape to access data at a specified index.
  //
  // `index` should have the same size as `shape`.
  // Each value `dim` in `index` should be in [0, shape[dim]).
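  // For example, for shape [2, 3, 4] and index [1, 2, 3] the offset is
  // 1 * 12 + 2 * 4 + 3 = 23.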
  static int64_t IndexToOffset(ArrayRef<int64_t> shape,
                               ArrayRef<int64_t> index) {
#ifndef NDEBUG
    assert(shape.size() == index.size());
    for (size_t i = 0; i < shape.size(); ++i) {
      assert(index[i] < shape[i]);
      assert(index[i] >= 0);
    }
#endif  // NDEBUG
    int64_t offset = 0;
    int64_t stride = 1;
    for (int64_t dim = shape.size() - 1; dim >= 0; --dim) {
      offset += index[dim] * stride;
      stride *= shape[dim];
    }
    return offset;
  }

  int64_t offset_;
  int64_t stride_;
  int64_t size_;
};

template <typename T>
class StridedArrayView;  // Class requires specialization.

// Wraps a DenseIntElementsAttr that holds backing data for a tensor so that
// int64_t values in a 1-d slice of the tensor can be accessed as if part of an
// ArrayView.
template <>
class StridedArrayView<DenseIntElementsAttr> : StridedArrayViewBase {
 public:
  StridedArrayView(const DenseIntElementsAttr &data, ArrayRef<int64_t> shape,
                   ArrayRef<int64_t> index, int64_t axis)
      : StridedArrayViewBase(shape, index, axis), data_(data) {
    int64_t element_count = 1;
    for (int64_t i = 0, e = shape.size(); i < e; ++i) {
      element_count *= shape[i];
    }
    assert(data.getNumElements() == element_count);
  }

  using StridedArrayViewBase::NextTensorIndex;
  using StridedArrayViewBase::size;

  int64_t operator[](int64_t i) const {
    return data_.getValues<APInt>()[OffsetForIndex(i)].getSExtValue();
  }

 private:
  const DenseIntElementsAttr &data_;
};

// Matches %iota generated from the following mlir code (rank 2 example):
//
// %iota = mhlo.constant dense<[[0, 1, 2, ..., L],
//                              [0, 1, 2, ..., L]
//                              ...
//                              [0, 1, 2, ..., L]]>,
// where $dimensions is of size 1.
//
// StridedArrayViews are used to check the iota property across the constant
// data so that the iota dimension does not need to be the (inner) z-dimension.
bool MatchIotaConst(DenseIntElementsAttr dimensions, Value iota) {
  DenseIntElementsAttr iota_const_attr;
  if (!matchPattern(iota, m_Constant(&iota_const_attr))) return false;

  auto iota_type = iota_const_attr.getType();
  auto iota_shape = iota_type.getShape();
  auto reduce_dim = (*dimensions.value_begin<APInt>()).getSExtValue();
  if (reduce_dim < 0) reduce_dim += iota_type.getRank();

  auto index =
      llvm::Optional<SmallVector<int64_t>>(std::in_place, iota_type.getRank());
  while (index.has_value()) {
    StridedArrayView<DenseIntElementsAttr> array_view(
        iota_const_attr, iota_shape, *index, reduce_dim);
    for (int64_t i = 0; i < array_view.size(); ++i) {
      if (array_view[i] != i) return false;
    }
    index = StridedArrayView<DenseIntElementsAttr>::NextTensorIndex(
        std::move(*index), iota_shape, reduce_dim);
  }

  return true;
}

// The following five different forms of mhlo::iota will be matched:
// 1. IotaOp.
// 2. IotaOp + BroadCastInDim.
// 3. IotaOp + Reshape.
// 4. Constant (folded Iota) + BroadCastInDim.
// 5. Constant (folded result).
// Moreover, the dimensions attribute has to match the iota_dimension.
bool MatchIota(DenseIntElementsAttr dimensions, Value iota) {
  return MatchSingleIota(dimensions, iota) ||
         MatchIotaBroadCastInDim(dimensions, iota) ||
         MatchReshapedIota(dimensions, iota) ||
         MatchConstIotaBroadCastInDim(dimensions, iota) ||
         MatchIotaConst(dimensions, iota);
}

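// Returns true if `comparator` contains exactly one block with exactly two
// ops: a mhlo.compare with direction GT on the block's first two arguments,
// followed by a mhlo.return of the comparison result. This is the comparator
// shape expected when lowering a descending sort to tf.TopKV2.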
bool MatchTopKComparator(Region &comparator) {
  if (!comparator.hasOneBlock()) return false;
  Block &comparator_blk = comparator.front();
  using OpListType = llvm::iplist<Operation>;
  OpListType &operations = comparator_blk.getOperations();
  if (operations.size() != 2) return false;
  auto compare_op = dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
  auto return_op = dyn_cast_or_null<mhlo::ReturnOp>(&operations.back());
  if (!compare_op || !return_op) return false;
  // TODO(xuanyuanluo): Support mhlo::ComparisonDirection::LT direction.
  if (compare_op.comparison_direction() != mhlo::ComparisonDirection::GT)
    return false;
  if (compare_op.lhs() != comparator_blk.getArgument(0) ||
      compare_op.rhs() != comparator_blk.getArgument(1))
    return false;
  return return_op.getOperands().front() == compare_op.getResult();
}

// In general, we convert the following form of sort to tf.TopK:
//
// %result = "mhlo.sort" (%keys, %indices) ({
//  ^bb0(%key_0, %key_1, %index_0, %index_1):
//     %1 = "mhlo.compare"(%key_0, %key_1) {mhlo::ComparisonDirection::GT}
//        -> tensor<i1>
//  }),
//
// where the indices are obtained from an IotaOp (possibly folded).
class ConvertSortToTfTopk : public OpConversionPattern<mhlo::SortOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::SortOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (op->getOperands().size() != 2)
      return rewriter.notifyMatchFailure(
          op, "only match for the case where operands are of size 2");
    auto keys = op.operands()[0];
    auto indices = op.operands()[1];
    auto keys_ty = keys.getType().dyn_cast_or_null<ShapedType>();
    auto indices_ty = indices.getType().dyn_cast_or_null<ShapedType>();
    if (!keys_ty || !keys_ty.hasStaticShape() ||
        !keys_ty.getElementType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op,
          "only match for the case where the first operand has a static "
          "int/float shapeType");
    if (!indices_ty || !indices_ty.hasStaticShape() ||
        !indices_ty.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          op,
          "only match for the case where the second operand is an I32 "
          "shapeType");
    auto sort_dim = op.dimension();
    auto k = indices_ty.getDimSize(sort_dim);
    auto rank = keys_ty.getRank();
    if (sort_dim != rank - 1 || k < 1)
      return rewriter.notifyMatchFailure(
          op, "only match for sort dim = rank - 1 and DimSize >= 1");

    // In the following, we'll check that the indices are obtained from an
    // iota.
    auto sort_dim_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getI64Type()), {sort_dim});
    if (!MatchIota(sort_dim_attr, indices))
      return rewriter.notifyMatchFailure(
          op, "the second operand is supposed to be obtained from IOTA");
    if (!MatchTopKComparator(op.comparator()))
      return rewriter.notifyMatchFailure(op, "only match for GT comparator");
    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
    Value k_cst = BuildIntConstOp(builder, rewriter, k, rewriter.getI32Type());
    rewriter.replaceOpWithNewOp<TopKV2Op>(op, keys.getType(), indices.getType(),
                                          keys, k_cst);
    return success();
  }
};

// A class that holds information about the dimensions of dot_general operands.
class DotDimensionsInfo {
 public:
  DotDimensionsInfo(ShapedType type, ArrayRef<int64_t> batch_dimensions,
                    ArrayRef<int64_t> contracting_dimensions) {
    const int rank = type.getRank();
    for (const int dim : batch_dimensions) {
      batch_dimensions_.axes.push_back(dim);
      batch_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (const int dim : contracting_dimensions) {
      contracting_dimensions_.axes.push_back(dim);
      contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (int dim = 0; dim < rank; ++dim) {
      if (llvm::count(contracting_dimensions_.axes, dim) > 0 ||
          llvm::count(batch_dimensions_.axes, dim) > 0) {
        continue;
      }
      out_dimensions_.axes.push_back(dim);
      out_dimensions_.sizes.push_back(type.getDimSize(dim));
    }
  }

  const DimensionVector &batch_dimensions() const { return batch_dimensions_; }
  const DimensionVector &contracting_dimensions() const {
    return contracting_dimensions_;
  }
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, and hence will be propagated to the output shape.
  const DimensionVector &out_dimensions() const { return out_dimensions_; }

  // Returns the total dimension size after flattening all contracting
  // dimensions.
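  // For example, contracting dimension sizes {3, 4} flatten to a single
  // dimension of size 12.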
  int64_t FlattenedContractingDimensionSize() const {
    return std::accumulate(contracting_dimensions_.sizes.begin(),
                           contracting_dimensions_.sizes.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

  // Returns the total dimension size after flattening all out dimensions.
  int64_t FlattenedOutDimensionSize() const {
    return std::accumulate(out_dimensions_.sizes.begin(),
                           out_dimensions_.sizes.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

 private:
  DimensionVector batch_dimensions_;
  DimensionVector contracting_dimensions_;
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, and hence will be propagated to the output shape.
  DimensionVector out_dimensions_;
};

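// Lowers a dot-style contraction to tf.BatchMatMulV3 by transposing and
// reshaping each operand so that the contraction becomes a single batch
// matrix multiply. For example, a dot_general with lhs tensor<2x5x3xf32>,
// rhs tensor<2x3x7xf32>, batching dimensions {0} / {0} and contracting
// dimensions {2} / {1} becomes a batch matmul of shape [2, 5, 7], which is
// then reshaped to the result type.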
Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs,
                 DotDimensionNumbersAttr dot_dimension_numbers,
                 ShapedType result_type, mlir::Location loc) {
  auto lhs_type = lhs.getType().cast<ShapedType>();
  auto rhs_type = rhs.getType().cast<ShapedType>();
  const int lhs_rank = lhs_type.getRank();
  const int rhs_rank = rhs_type.getRank();

  // Collects lhs and rhs dimensions information.
  DotDimensionsInfo lhs_dot_dimensions_info(
      lhs_type, dot_dimension_numbers.getLhsBatchingDimensions(),
      dot_dimension_numbers.getLhsContractingDimensions());
  DotDimensionsInfo rhs_dot_dimensions_info(
      rhs_type, dot_dimension_numbers.getRhsBatchingDimensions(),
      dot_dimension_numbers.getRhsContractingDimensions());

  // Transposes lhs shape to be in the order of {batch_dimensions,
  // out_dimensions, contracting_dimensions}.
  llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      lhs_dot_dimensions_info.out_dimensions().AxesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      lhs_dot_dimensions_info.out_dimensions().SizesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
  auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
      lhs,
      DenseIntElementsAttr::get(
          RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
          lhs_permutation));

  // Transposes rhs shape to be in the order of {batch_dimensions,
  // contracting_dimensions, out_dimensions}.
  llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
      rhs_dot_dimensions_info.out_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
      rhs_dot_dimensions_info.out_dimensions().SizesArray());
  auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
      rhs,
      DenseIntElementsAttr::get(
          RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
          rhs_permutation));

  // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
  llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
  auto lhs_flattened = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
      lhs_transposed.getResult());

  // Reshapes rhs to flatten out_dimensions and contracting_dimensions.
  llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto rhs_flattened = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
      rhs_transposed.getResult());

  // Creates a matmul op of `lhs_flattened` and `rhs_flattened`.
  llvm::SmallVector<int64_t, 4> matmul_shape =
      Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
                      llvm::ArrayRef<int64_t>{
                          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
                      llvm::ArrayRef<int64_t>{
                          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto matmul = rewriter.create<TF::BatchMatMulV3Op>(
      loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
      lhs_flattened.getResult(), rhs_flattened.getResult());
  auto reshaped =
      rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
  return reshaped.getResult();
}

// Converts mhlo.dot to tf.BatchMatMul. Reshape ops will be inserted when
// necessary.
Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
  auto dot_op = cast<mhlo::DotOp>(old_op);
  auto lhs_rank = dot_op.lhs().getType().cast<ShapedType>().getRank();
  auto dot_dimension_numbers =
      DotDimensionNumbersAttr::get(rewriter.getContext(),
                                   /*lhs_batching_dimensions=*/{},
                                   /*rhs_batching_dimensions=*/{},
                                   /*lhs_contracting_dimensions=*/
                                   {lhs_rank == 1 ? 0 : 1},
                                   /*rhs_contracting_dimensions=*/{0});
  return ConvertDot(rewriter, dot_op.lhs(), dot_op.rhs(), dot_dimension_numbers,
                    dot_op.getResult().getType().cast<ShapedType>(),
                    dot_op.getLoc());
}

// Converts mhlo.dot_general to tf.BatchMatMul. Reshape or Transpose ops will
// also be inserted to convert to a well-formed matrix multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
  auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
  return ConvertDot(rewriter, dot_general_op.lhs(), dot_general_op.rhs(),
                    dot_general_op.dot_dimension_numbers(),
                    dot_general_op.getResult().getType().cast<ShapedType>(),
                    dot_general_op.getLoc());
}

// Checks if the specified region is a binary reduction function that takes two
// inputs, passes them to an instance of the specified reduction op, and then
// returns the result.
template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
  Block &body = function.front();
  if (body.getNumArguments() != 2) return failure();

  mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!return_op) return failure();
  if (return_op.getNumOperands() != 1) return failure();

  ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
      return_op.getOperands().front().getDefiningOp());
  if (!reduce_op) return failure();
  if (reduce_op.lhs() != body.getArgument(0) ||
      reduce_op.rhs() != body.getArgument(1))
    return failure();

  return success();
}

// Checks if the specified region is a binary reduction function that takes two
// inputs and returns the second input. Functions like this are used by
// update-scatter-like ops.
template <>
LogicalResult MatchBinaryReduceFunction<void>(mlir::Region &function) {
  Block &body = function.front();
  if (body.getNumArguments() != 2) return failure();

  mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!return_op) return failure();
  if (return_op.getNumOperands() != 1) return failure();
  if (return_op.getOperands().front() != body.getArgument(1)) return failure();
  return success();
}

// Replaces the BinaryOp with a combination of TfBinaryOp and TfReduceOp if the
// init value doesn't match the expectation of TfReduceOp.
template <typename TfReduceOp, typename TfBinOp>
LogicalResult rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op, Value input,
                                       ConstOp reduction_indices,
                                       ConversionPatternRewriter &rewriter) {
  Value reduce_result = rewriter.create<TfReduceOp>(
      reduce_op.getLoc(), reduce_op.getType(0), input, reduction_indices,
      /*keep_dim=*/rewriter.getBoolAttr(false));
  rewriter.replaceOpWithNewOp<TfBinOp>(reduce_op, reduce_op.getType(0),
                                       reduce_result,
                                       reduce_op.init_values()[0]);
  return success();
}

// Cannot replace the BinaryOp if the init value doesn't match the expectation
// of TfReduceOp and there is no corresponding TfBinaryOp.
template <>
LogicalResult rewriteNonMatchInitValue<TF::MaxOp, void>(
    mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices,
    ConversionPatternRewriter &rewriter) {
  return failure();
}

template <>
LogicalResult rewriteNonMatchInitValue<TF::MinOp, void>(
    mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices,
    ConversionPatternRewriter &rewriter) {
  return failure();
}

// Converts an mhlo.reduce op with an mhlo binary operation into a TensorFlow
// reduction operation. If the initial value can be ignored, then convert it
// into a single TfReduceOp. Otherwise, convert it into a TfReduceOp followed by
// a TfBinaryOp.
// For example:
// 1) A mhlo::ReduceOp on value `x` with a mhlo::AndOp and a constant initial
// value `true` is converted to a TF::Any on value `x`.
// 2) A mhlo::ReduceOp on value `x` with a mhlo::AndOp with a non-constant
// initial value `y` is converted to a TF::Any on value `x`, followed by a
// TF::And with initial value `y`.
template <typename BinaryOp, typename TfReduceOp, typename TfBinaryOp = void>
class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReduceOp reduce_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (failed(MatchReduceOpOperand(reduce_op))) return failure();

    if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
      return failure();

    auto operand = reduce_op.operands()[0];

    // Get the reduction dimensions.
    DenseIntElementsAttr dimension = reduce_op.dimensions();
    SmallVector<int64_t, 4> reduce_dims;
    for (const int64_t &dim : dimension.getValues<int64_t>()) {
      reduce_dims.emplace_back(dim);
    }
    auto dim_type = RankedTensorType::get(
        {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
    auto reduction_indices = rewriter.create<ConstOp>(
        reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));

    // In the `MatchReduceOpOperand` function, we already matched that the
    // "mhlo::ReduceOp" has only one operand, one init_value and one result.

    // If the init value matches the init value expected for the target
    // TfReduceOp, then replace the BinaryOp with a TfReduceOp. Otherwise,
    // replace the BinaryOp with a TfBinaryOp and a TfReduceOp.
    if (succeeded(MatchInitValue(reduce_op.init_values()[0]))) {
      rewriter.replaceOpWithNewOp<TfReduceOp>(
          reduce_op, reduce_op.getType(0), operand, reduction_indices,
          /*keep_dim=*/rewriter.getBoolAttr(false));
      return success();
    }
    return rewriteNonMatchInitValue<TfReduceOp, TfBinaryOp>(
        reduce_op, operand, reduction_indices, rewriter);
  }

 private:
  // Checks that the init value matches the init value expected for the
  // target TfReduceOp.
  virtual LogicalResult MatchInitValue(Value init_value) const = 0;

  // This function tries to match that the "mhlo::ReduceOp" has only one
  // operand, one init_value and one result.
  LogicalResult MatchReduceOpOperand(mhlo::ReduceOp reduce_op) const {
    if (reduce_op.operands().size() != 1 ||
        reduce_op.init_values().size() != 1 ||
        reduce_op.getResults().size() != 1)
      return failure();

    if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
      return failure();
    if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
    return success();
  }
};

class ConvertReduceOpToTfSum
    : public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp, TF::AddOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    auto type = init_value.getType().cast<ShapedType>().getElementType();
    if (type.isa<FloatType>()) {
      APFloat const_value(.0);
      if (failed(GetConstantSplatValue(init_value, const_value)) ||
          !const_value.isZero())
        return failure();
    } else if (type.isa<IntegerType>() && type.isSignlessInteger()) {
      APInt const_value;
      if (failed(GetConstantSplatValue(init_value, const_value)) ||
          !const_value.isZero())
        return failure();
    } else {
      return failure();
    }

    return success();
  }
};

class ConvertReduceOpToTfMax
    : public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    auto type = init_value.getType().cast<ShapedType>().getElementType();
    if (type.isa<FloatType>()) {
      APFloat const_value(.0);
      if (failed(GetConstantSplatValue(init_value, const_value)) ||
          !const_value.isInfinity() || !const_value.isNegative())
        return failure();
    } else if (type.isa<IntegerType>() && type.isSignlessInteger()) {
      APInt const_value;
      if (failed(GetConstantSplatValue(init_value, const_value)) ||
          !const_value.isMinSignedValue())
        return failure();
    } else {
      return failure();
    }
    return success();
  }
};

class ConvertReduceOpToTfMin
    : public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    auto type = init_value.getType().cast<ShapedType>().getElementType();

    if (type.isa<FloatType>()) {
      APFloat const_value(.0);
      if (failed(GetConstantSplatValue(init_value, const_value)) ||
          !const_value.isInfinity() || const_value.isNegative())
        return failure();
    } else if (type.isa<IntegerType>() && type.isSignlessInteger()) {
      APInt const_value;
      if (failed(GetConstantSplatValue(init_value, const_value)) ||
          !const_value.isMaxSignedValue())
        return failure();
    } else {
      return failure();
    }
    return success();
  }
};

class ConvertReduceOpToTfAll
    : public ConvertReduceOpToTfOp<mhlo::AndOp, TF::AllOp, TF::LogicalAndOp> {
 public:
  using ConvertReduceOpToTfOp<mhlo::AndOp, TF::AllOp,
                              TF::LogicalAndOp>::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseIntElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.getType().getElementType().isInteger(1) ||
        !init_attr.isSplat() || !init_attr.getSplatValue<BoolAttr>().getValue())
      return failure();
    return success();
  }
};

class ConvertReduceOpToTfAny
    : public ConvertReduceOpToTfOp<mhlo::OrOp, TF::AnyOp, TF::LogicalOrOp> {
 public:
  using ConvertReduceOpToTfOp<mhlo::OrOp, TF::AnyOp,
                              TF::LogicalOrOp>::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseIntElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.getType().getElementType().isInteger(1) ||
        !init_attr.isSplat() || init_attr.getSplatValue<BoolAttr>().getValue())
      return failure();
    return success();
  }
};

template <typename TfReduce, typename TfArgReduce>
class ConvertReduceOpToTfArgMinMax
    : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp reduce_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (reduce_op.operands().size() != 2) return failure();
    if (reduce_op.dimensions().getNumElements() != 1) return failure();

    // Check that the operand init is the expected value.
    DenseElementsAttr operand_init;
    if (!matchPattern(reduce_op.init_values().front(),
                      m_Constant(&operand_init)))
      return failure();
    if (!IsValueInitValue(operand_init)) return failure();

    // Check that the iota init is zero.
    DenseElementsAttr iota_init;
    if (!matchPattern(reduce_op.init_values().back(), m_Constant(&iota_init)))
      return failure();
    if (iota_init.getValues<APInt>()[0] != 0) return failure();

    // Verify that the second argument is an Iota op along the same dimension
    // as the reduction.
    Value iota = reduce_op.operands().back();
    if (!MatchIota(reduce_op.dimensions(), iota)) return failure();

    // Match the reduction computation.
    const bool is_float = operand_init.getElementType().isa<FloatType>();
    if (failed(matchReduceComputation(reduce_op.body(), is_float)))
      return failure();

    Value operand = reduce_op.operands().front();
    int64_t axis = reduce_op.dimensions().getValues<int64_t>()[0];

    auto dim_type = RankedTensorType::get({1}, rewriter.getI64Type());
    auto reduction_indices = rewriter.create<ConstOp>(
        reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr({axis}));

    // Generate both a Max and an ArgMax, as the mhlo op returns both while in
    // TF we have separate ops for them. If only one of them is used, the
    // other one will be garbage collected later.
    auto tf_reduce_op = rewriter.create<TfReduce>(
        reduce_op.getLoc(), reduce_op->getResult(0).getType(), operand,
        reduction_indices,
        /*keep_dim=*/rewriter.getBoolAttr(false));
    auto tf_argreduce_op = rewriter.create<TfArgReduce>(
        reduce_op.getLoc(), reduce_op->getResult(1).getType(), operand,
        reduction_indices);

    rewriter.replaceOp(reduce_op, {tf_reduce_op, tf_argreduce_op});
    return success();
  }

  // Pattern matches the following reduction function for ArgMax/ArgMin:
  // %0 = compare{GT}(%lhs_value, %rhs_value)
  // %1 = compare{NE}(%lhs_value, %lhs_value)
  // %2 = or(%0, %1)
  // %3 = select(%2, %lhs_value, %rhs_value)
  // %4 = compare{EQ}(%lhs_value, %rhs_value)
  // %5 = compare{LT}(%lhs_index, %rhs_index)
  // %6 = and(%4, %5)
  // %7 = or(%2, %6)
  // %8 = select(%7, %lhs_index, %rhs_index)
  // return %3, %8
  // Also note that %1 may be folded away if %lhs_value is of an integer type.
  LogicalResult matchReduceComputation(Region &computation,
                                       bool is_float) const {
    Block &body = computation.front();
    if (body.getNumArguments() != 4) return failure();

    mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
    if (!return_op || return_op.getNumOperands() != 2) return failure();

    mhlo::SelectOp value_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
        return_op.getOperand(0).getDefiningOp());
    if (!value_select || value_select.on_true() != body.getArgument(0) ||
        value_select.on_false() != body.getArgument(2))
      return failure();

    if (is_float) {
      mhlo::OrOp value_or = llvm::dyn_cast_or_null<mhlo::OrOp>(
          value_select.getOperand(0).getDefiningOp());
      if (!value_or) return failure();

      mhlo::CompareOp value_gt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
          value_or.lhs().getDefiningOp());
      if (!value_gt || value_gt.comparison_direction() != CompareDirection() ||
          value_gt.lhs() != body.getArgument(0) ||
          value_gt.rhs() != body.getArgument(2))
        return failure();

      mhlo::CompareOp value_ne = llvm::dyn_cast_or_null<mhlo::CompareOp>(
          value_or.rhs().getDefiningOp());
      if (!value_ne ||
          value_ne.comparison_direction() != mhlo::ComparisonDirection::NE ||
          value_ne.lhs() != body.getArgument(0) ||
          value_ne.rhs() != body.getArgument(0))
        return failure();
    } else {
      mhlo::CompareOp value_gt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
          value_select.getOperand(0).getDefiningOp());
      if (!value_gt || value_gt.comparison_direction() != CompareDirection() ||
          value_gt.lhs() != body.getArgument(0) ||
          value_gt.rhs() != body.getArgument(2))
        return failure();
    }

    mhlo::SelectOp index_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
        return_op.getOperand(1).getDefiningOp());
    if (!index_select || index_select.on_true() != body.getArgument(1) ||
        index_select.on_false() != body.getArgument(3))
      return failure();

    mhlo::OrOp index_or =
        llvm::dyn_cast_or_null<mhlo::OrOp>(index_select.pred().getDefiningOp());

    if (!index_or || index_or.lhs() != value_select.pred()) return failure();

    mhlo::AndOp index_and =
        llvm::dyn_cast_or_null<mhlo::AndOp>(index_or.rhs().getDefiningOp());
    if (!index_and) return failure();

    mhlo::CompareOp value_eq = llvm::dyn_cast_or_null<mhlo::CompareOp>(
        index_and.lhs().getDefiningOp());
    if (!value_eq ||
        value_eq.comparison_direction() != mhlo::ComparisonDirection::EQ ||
        value_eq.lhs() != body.getArgument(0) ||
        value_eq.rhs() != body.getArgument(2))
      return failure();

    mhlo::CompareOp index_lt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
        index_and.rhs().getDefiningOp());
    if (!index_lt ||
        index_lt.comparison_direction() != mhlo::ComparisonDirection::LT ||
        index_lt.lhs() != body.getArgument(1) ||
        index_lt.rhs() != body.getArgument(3))
      return failure();

    return success();
  }

  virtual mhlo::ComparisonDirection CompareDirection() const = 0;

  virtual bool IsValueInitValue(const DenseElementsAttr &attr) const = 0;
};

class ConvertReduceOpToTfArgmax
    : public ConvertReduceOpToTfArgMinMax<TF::MaxOp, TF::ArgMaxOp> {
 public:
  using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;

  mhlo::ComparisonDirection CompareDirection() const override {
    return mhlo::ComparisonDirection::GT;
  }
  bool IsValueInitValue(const DenseElementsAttr &attr) const override {
    auto element_type = attr.getType().getElementType();
    if (attr.getNumElements() != 1 || !element_type.isIntOrFloat() ||
        element_type.isInteger(1))
      return false;
    if (element_type.isa<FloatType>()) {
      auto value = *attr.value_begin<APFloat>();
      return value.isNegative() && value.isInfinity();
    } else {
      auto value = *attr.value_begin<APInt>();
      return element_type.isUnsignedInteger() ? value.isMinValue()
                                              : value.isMinSignedValue();
    }
  }
};

class ConvertReduceOpToTfArgmin
    : public ConvertReduceOpToTfArgMinMax<TF::MinOp, TF::ArgMinOp> {
 public:
  using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;

  mhlo::ComparisonDirection CompareDirection() const override {
    return mhlo::ComparisonDirection::LT;
  }
  bool IsValueInitValue(const DenseElementsAttr &attr) const override {
    auto element_type = attr.getType().getElementType();
    if (attr.getNumElements() != 1 || !element_type.isIntOrFloat() ||
        element_type.isInteger(1))
      return false;
    if (element_type.isa<FloatType>()) {
      auto value = *attr.value_begin<APFloat>();
      return !value.isNegative() && value.isInfinity();
    } else {
      auto value = *attr.value_begin<APInt>();
      return element_type.isUnsignedInteger() ? value.isMaxValue()
                                              : value.isMaxSignedValue();
    }
  }
};

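// Converts an mhlo.iota into a tf.Range, followed by a tf.Reshape and a
// tf.BroadcastTo when the result rank is greater than 1. For example, an iota
// with iota_dimension = 1 and result type tensor<3x4xi32> becomes
// tf.Range(0, 4, 1), reshaped to tensor<1x4xi32> and broadcast to
// tensor<3x4xi32>.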
class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::IotaOp iota_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    RankedTensorType type =
        iota_op.getType().dyn_cast_or_null<RankedTensorType>();
    // TF::RangeOp doesn't support UI16.
    if (!type || type.getElementType().isUnsignedInteger(16)) return failure();

    const uint64_t dimension = iota_op.iota_dimension();
    Type element_type = type.getElementType();
    Attribute start, limit, delta;
    if (element_type.isa<FloatType>()) {
      start = rewriter.getFloatAttr(element_type, 0.0);
      limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]);
      delta = rewriter.getFloatAttr(element_type, 1.0);
    } else if (element_type.isa<IntegerType>()) {
      start = rewriter.getIntegerAttr(element_type, 0);
      limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]);
      delta = rewriter.getIntegerAttr(element_type, 1);
    } else {
      return failure();
    }

    auto range_type =
        RankedTensorType::get({type.getShape()[dimension]}, element_type);
    Value start_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), start);
    Value limit_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), limit);
    Value delta_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), delta);
    Value result = rewriter.create<TF::RangeOp>(iota_op.getLoc(), range_type,
                                                start_op, limit_op, delta_op);

    if (type.getRank() > 1) {
      std::vector<int64_t> reshape_shape(type.getRank(), 1);
      reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension];
      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
      Value reshape_shape_op = rewriter.create<TF::ConstOp>(
          iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape));
      result = rewriter.create<TF::ReshapeOp>(iota_op.getLoc(), reshape_type,
                                              result, reshape_shape_op);

      Value broadcast_shape_op = rewriter.create<TF::ConstOp>(
          iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape()));
      result = rewriter.create<TF::BroadcastToOp>(iota_op.getLoc(), type,
                                                  result, broadcast_shape_op);
    }

    rewriter.replaceOp(iota_op, result);
    return success();
  }
};

// A helper function for ConvertMaxPoolOp and ConvertAvgPoolOp. Returns true
// if the given ReduceWindowOp is a spatial pooling without dilation. If it
// returns true, it also outputs the window strides and the TF padding mode
// ("VALID" or "SAME").
bool IsSpatialPoolingWithoutDilation(
    mhlo::ReduceWindowOp rw, llvm::SmallVectorImpl<int64_t> *window_strides,
    std::string *padding_mode) {
  // tf.max_pool or tf.avg_pool need at least 3 dimensions (batch, spatial,
  // channel).
  const uint64_t rank = rw.window_dimensions().size();
  if (rank <= 2) return false;

  if (rw.window_strides().has_value()) {
    window_strides->insert(window_strides->end(),
                           rw.window_strides()->getValues<int64_t>().begin(),
                           rw.window_strides()->getValues<int64_t>().end());
  } else {
    window_strides->resize(rank, 1);
  }

  llvm::SmallVector<int64_t, 10> padding;
  if (rw.padding().has_value()) {
    padding.insert(padding.begin(), rw.padding()->getValues<int64_t>().begin(),
                   rw.padding()->getValues<int64_t>().end());
  } else {
    padding.resize(2 * rank, 0);
  }

  // Check that we don't do any reduction along the batch (first) and channel
  // (last) dimensions.
  const uint64_t batch_dim = 0;
  const uint64_t channel_dim = rank - 1;
  if (rw.window_dimensions().getValues<int64_t>()[batch_dim] != 1 ||
      rw.window_dimensions().getValues<int64_t>()[channel_dim] != 1 ||
      (*window_strides)[batch_dim] != 1 ||
      (*window_strides)[channel_dim] != 1 || padding[2 * batch_dim] != 0 ||
      padding[2 * batch_dim + 1] != 0 || padding[2 * channel_dim] != 0 ||
      padding[2 * channel_dim + 1] != 0)
    return false;

  if (rw.window_dilations().has_value() &&
      !(rw.window_dilations()->isSplat() &&
        rw.window_dilations()->getSplatValue<APInt>() == 1))
    return false;

  if (rw.base_dilations().has_value() &&
      !(rw.base_dilations()->isSplat() &&
        rw.base_dilations()->getSplatValue<APInt>() == 1))
    return false;

  if (llvm::all_of(padding, [](int64_t i) { return i == 0; })) {
    *padding_mode = "VALID";
    return true;
  }

  // Check that the individual padding values correspond to SAME padding from
  // TensorFlow.
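  // For example, with input size 5, window size 3 and stride 2, the output
  // size is 3, so padding_size = (3 - 1) * 2 + 3 - 5 = 2, split as 1 before
  // and 1 after the spatial dimension.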
  auto operand_type = rw.operands()[0].getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      rw.getResult(0).getType().dyn_cast<RankedTensorType>();
  if (!operand_type || !output_type) return false;

  for (uint64_t i = 1; i < rank - 1; ++i) {
    int64_t padding_size =
        (output_type.getShape()[i] - 1) * (*window_strides)[i] +
        rw.window_dimensions().getValues<int64_t>()[i] -
        operand_type.getShape()[i];
    if (padding[2 * i] != tensorflow::MathUtil::FloorOfRatio(
                              padding_size, static_cast<int64_t>(2)) ||
        padding[2 * i + 1] != tensorflow::MathUtil::CeilOfRatio(
                                  padding_size, static_cast<int64_t>(2)))
      return false;
  }

  *padding_mode = "SAME";
  return true;
}

// Convert a reduce_window operation into a cumulative operation where possible
// for a given binary operation.
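// For example, a reduce_window over mhlo.add on a tensor<8xf32> operand with
// window dimensions dense<8>, padding dense<[[7, 0]]> and a zero init value
// computes a cumulative sum along axis 0 and converts to tf.Cumsum.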
template <class BinaryOp, class TfCumOp>
class ConvertLoweredCumOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  virtual bool IsInitValue(const DenseElementsAttr &attr) const = 0;

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp rw, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (rw.getNumResults() != 1 || rw.operands().size() != 1 ||
        rw.init_values().size() != 1)
      return failure();

    if (failed(MatchBinaryReduceFunction<BinaryOp>(rw.body())))
      return failure();

    // Ensure that the initial values are as expected.
    auto const_op = llvm::dyn_cast_or_null<mhlo::ConstantOp>(
        rw.init_values()[0].getDefiningOp());
    if (!const_op) return failure();
    auto const_op_dense_value = const_op.value().cast<DenseElementsAttr>();
    if (!const_op_dense_value || !IsInitValue(const_op_dense_value)) {
      return failure();
    }

    auto operand_type = rw.operands()[0].getType().cast<ShapedType>();

    // For a cumulative op, require a tensor of 1s for each dimension in
    // the operand.
    auto is_splat_int64_ones =
        [&rewriter,
         &operand_type](const ::llvm::Optional<DenseIntElementsAttr> &attr) {
          // According to the definition, the default values of these
          // attributes are all ones when unspecified.
          if (!attr.has_value()) return true;
          if (attr->getType().getShape()[0] != operand_type.getRank())
            return false;
          if (!attr->isSplat()) return false;
          if (attr->getElementType() != rewriter.getIntegerType(64))
            return false;
          if (attr->getSplatValue<APInt>().getSExtValue() != 1) return false;
          return true;
        };
    if (!is_splat_int64_ones(rw.base_dilations()) ||
        !is_splat_int64_ones(rw.window_dilations()) ||
        !is_splat_int64_ones(rw.window_strides()))
      return failure();

    // Determine which axis is being used for the cumulative operation.
    //
    // For a cumulative op, window_dimensions should be of the form:
    //  dense<[1, 1, N, 1]>
    // where N is the same as the size of the corresponding input dimension
    // and there is a 1-entry for each input dimension not being operated
    // over.
    const auto &window_dimensions = rw.window_dimensions();
    if (window_dimensions.size() != operand_type.getRank()) return failure();
    int64_t cumulative_axis = -1;
    for (int64_t i = 0, e = window_dimensions.size(); i < e; ++i) {
      int64_t window_dimension = window_dimensions.getValues<int64_t>()[i];
      if (window_dimension == 1) continue;
      // Cumulative axis already set.
      if (cumulative_axis != -1) return failure();
      // Potential cumulative axis is not the right size.
      if (window_dimension != operand_type.getShape()[i]) return failure();
      cumulative_axis = i;
    }

    if (cumulative_axis == -1) {
      return rewriter.notifyMatchFailure(rw, "no reduced dimension is found.");
    }

    // For a cumulative op, padding (expressed as a list of left-padding and
    // right-padding pairs) should be of the form:
    //  dense<[[0, 0], [0, 0], [N-1, 0], [0, 0]]>
    // where N is the size of the input dimension being operated over.
    if (!rw.padding()) return failure();
    const auto &padding = rw.padding()->getValues<int64_t>();
    if (padding.size() != operand_type.getRank() * 2) return failure();
    int64_t padding_value = operand_type.getShape()[cumulative_axis] - 1;
    for (int64_t dim = 0; dim < operand_type.getRank(); ++dim) {
      int64_t left_padding = padding[2 * dim];
      int64_t right_padding = padding[2 * dim + 1];
      if (dim == cumulative_axis) {
        if (left_padding != padding_value) return failure();
      } else {
        if (left_padding != 0) return failure();
      }
      if (right_padding != 0) return failure();
    }

    auto axis = rewriter.create<TF::ConstOp>(
        rw->getLoc(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), cumulative_axis));

    rewriter.replaceOpWithNewOp<TfCumOp>(rw, rw.getType(0), rw.operands()[0],
                                         axis, /*exclusive=*/false,
                                         /*reverse=*/false);
    return success();
  }
};

class ConvertLoweredCumSumOp
    : public ConvertLoweredCumOp<mhlo::AddOp, TF::CumsumOp> {
  using ConvertLoweredCumOp::ConvertLoweredCumOp;
  bool IsInitValue(const DenseElementsAttr &attr) const override {
    auto element_type = attr.getType().getElementType();
    if (attr.getNumElements() != 1 || !element_type.isIntOrFloat())
      return false;
    if (element_type.isa<FloatType>()) {
      auto value = *attr.value_begin<APFloat>();
      return value.isZero();
    }
    auto value = *attr.value_begin<APInt>();
    return value.isZero();
  }
};

class ConvertLoweredCumProdOp
    : public ConvertLoweredCumOp<mhlo::MulOp, TF::CumprodOp> {
  using ConvertLoweredCumOp::ConvertLoweredCumOp;
  bool IsInitValue(const DenseElementsAttr &attr) const override {
    auto element_type = attr.getType().getElementType();
    if (attr.getNumElements() != 1 || !element_type.isIntOrFloat())
      return false;
    if (element_type.isa<FloatType>()) {
      auto value = *attr.value_begin<APFloat>();
      return value.isExactlyValue(1.0);
    }
    auto value = *attr.value_begin<APInt>();
    return value.getSExtValue() == 1;
  }
};
2502
2503 // Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
2504 // operation when they cleanly map to 2D or 3D average pool with VALID or SAME
2505 // padding:
2506 // * div(reduce_sum_window(x), constant(sizeof(window)))
2507 // * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
 public:
  explicit ConvertAvgPoolOp(MLIRContext *context)
      : OpConversionPattern(context, /*benefit=*/10) {}

  LogicalResult matchAndRewrite(
      mhlo::DivOp div_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    auto rw =
        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
    if (!rw || rw->getNumResults() != 1) return failure();

    // Check that the reduce-window is a sum-reduce-window.
    if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
      return failure();

    // Check that this is a floating point reduce window with a rank of 4 or 5.
    const RankedTensorType rw_type =
        rw.getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
        rw_type.getRank() <= 3 || rw_type.getRank() > 5)
      return failure();

    // Check that the Div op doesn't do broadcasting on the output of the
    // reduce window.
    if (div_op.getType() != rw_type) return failure();

    // If the init value isn't zero then it can't be an average pool.
    if (!isFloatZero(rw.init_values()[0])) return failure();

    llvm::SmallVector<int64_t, 5> window_strides;
    std::string padding_mode;
    if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
      return rewriter.notifyMatchFailure(
          div_op, "not the root of spatial pooling without dilation");
    }

    DenseFPElementsAttr divisor;
    if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
      // If the divisor is a constant then check that it matches the number
      // of elements inside the window, which is required for a VALID AvgPool.
      if (!divisor.isSplat()) return failure();
      int64_t window_size = 1;
      for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
        window_size *= w;
      }
      if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
        return failure();

      if (padding_mode != "VALID") {
        return failure();
      }

      return replaceWithAvgPool(
          div_op, rw.operands()[0],
          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
          window_strides, "VALID", rewriter);
    }

    auto rw_rhs =
        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
    if (rw_rhs && rw_rhs.getNumResults() == 1) {
      // Check that RHS is a sum-reduce-window.
      if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
        return failure();

      // Check that the RHS is a reduce_window over a constant 1 operand with 0
      // as the init value.
      DenseFPElementsAttr rhs_operand;
      if (!isFloatZero(rw_rhs.init_values()[0]) ||
          !matchPattern(rw_rhs.operands()[0], m_Constant(&rhs_operand)) ||
          !rhs_operand.isSplat() ||
          !rhs_operand.getSplatValue<APFloat>().isExactlyValue(1.0))
        return failure();

      // Check that the two reduce windows have the same window configuration.
      if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
          rw.window_strides() != rw_rhs.window_strides() ||
          rw.window_dilations() != rw_rhs.window_dilations() ||
          rw.base_dilations() != rw_rhs.base_dilations() ||
          rw.padding() != rw_rhs.padding())
        return failure();

      return replaceWithAvgPool(
          div_op, rw.operands()[0],
          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
          window_strides, padding_mode, rewriter);
    }

    return failure();
  }

 private:
  bool isFloatZero(Value value) const {
    DenseFPElementsAttr initial_value;
    return matchPattern(value, m_Constant(&initial_value)) &&
           initial_value.getNumElements() == 1 &&
           initial_value.getValues<APFloat>()[0].isZero();
  }

  LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
                                   llvm::ArrayRef<int64_t> ksizes,
                                   llvm::ArrayRef<int64_t> kstrides,
                                   llvm::StringRef padding,
                                   ConversionPatternRewriter &rewriter) const {
    if (ksizes.size() == 4) {
      rewriter.replaceOpWithNewOp<AvgPoolOp>(
          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NHWC"));
      return success();
    } else if (ksizes.size() == 5) {
      rewriter.replaceOpWithNewOp<AvgPool3DOp>(
          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NDHWC"));
      return success();
    }
    return failure();
  }
};

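// Converts an mhlo.reduce_window max-reduction into a tf.MaxPool or
// tf.MaxPool3D op, provided the window describes a 2D or 3D spatial pooling
// with VALID or SAME padding, no dilation, and a -infinity init value.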
class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp rw, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // Check that the reduce-window is a max-reduce-window.
    if (failed(MatchBinaryReduceFunction<mhlo::MaxOp>(rw.body())))
      return failure();

    // Check that this is a floating point reduce window with a rank of 4 or 5.
    const RankedTensorType rw_type =
        rw.getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
        rw_type.getRank() <= 3 || rw_type.getRank() > 5)
      return failure();

    if (!isFloatMinusInfinity(rw.init_values()[0])) {
      return failure();
    }

    llvm::SmallVector<int64_t, 5> window_strides;
    std::string padding_mode;
    if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
      return rewriter.notifyMatchFailure(
          rw, "not the root of spatial pooling without dilation");
    }

    return replaceWithMaxPool(
        rw, rw.operands()[0],
        llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
        window_strides, padding_mode, rewriter);
  }

 private:
  bool isFloatMinusInfinity(Value value) const {
    DenseFPElementsAttr float_value;
    if (!matchPattern(value, m_Constant(&float_value))) {
      return false;
    }

    if (float_value.getNumElements() != 1) {
      return false;
    }

    APFloat element = float_value.getValues<APFloat>()[0];
    if (!element.isInfinity()) {
      return false;
    }
    if (!element.isNegative()) {
      return false;
    }

    return true;
  }

  LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
                                   llvm::ArrayRef<int64_t> ksizes,
                                   llvm::ArrayRef<int64_t> kstrides,
                                   llvm::StringRef padding,
                                   ConversionPatternRewriter &rewriter) const {
    if (ksizes.size() == 4) {
      rewriter.replaceOpWithNewOp<MaxPoolOp>(
          op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
          rewriter.getStringAttr("NHWC"));
      return success();
    } else if (ksizes.size() == 5) {
      rewriter.replaceOpWithNewOp<MaxPool3DOp>(
          op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NDHWC"));
      return success();
    }
    return failure();
  }
};

class LegalizeHloToTf : public TF::LegalizeHloToTfPassBase<LegalizeHloToTf> {
  /// Performs the legalization to the TF dialect.
  void runOnOperation() override;
};

// Returns the shape of the given value in a Constant Op.
arith::ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
  ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape();
  auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                         rewriter.getIntegerType(64));
  auto attr = DenseElementsAttr::get(attr_type, shape);
  return rewriter.create<arith::ConstantOp>(value.getLoc(), attr_type, attr);
}

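// Returns whether `sign` is the sign value of `a`: -1.0 for negative `a` and
// 1.0 otherwise; for NaN or zero `a`, `sign` must compare equal to `a`.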
bool IsSign(APFloat a, APFloat sign) {
  if (a.isNaN() || a.isZero()) return a == sign;
  if (a.isNegative()) return sign.isExactlyValue(-1.0);
  return sign.isExactlyValue(1.0);
}

// Returns whether the splat constant `sgn_cst` holds the elementwise sign of
// the float tensor `floatv`.
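// For example (values illustrative), floatv = dense<[-2.0, -0.5]> with
// sgn_cst = splat(-1.0) is a match, while sgn_cst = splat(1.0) is not.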
bool FloatTensorIsSign(PatternRewriter &rewriter, ElementsAttr floatv,
                       ElementsAttr sgn_cst) {
  if (!sgn_cst.isa<SplatElementsAttr>()) return false;
  auto sgn_cst_spl = sgn_cst.cast<SplatElementsAttr>().getSplatValue<APFloat>();
  if (floatv.isa<SplatElementsAttr>()) {
    auto floatv_spl = floatv.cast<SplatElementsAttr>().getSplatValue<APFloat>();
    return IsSign(floatv_spl, sgn_cst_spl);
  } else if (floatv.isa<DenseElementsAttr>()) {
    auto floatv_dns = floatv.cast<DenseFPElementsAttr>();
    return llvm::all_of(floatv_dns.getValues<APFloat>(), [&](APFloat value) {
      return IsSign(value, sgn_cst_spl);
    });
  }
  return false;
}

// Checks that `arr` is an R1 iota with integer element type, i.e. that it
// holds exactly `size` values forming the sequence 0, 1, ..., size - 1.
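// For example, IsIotaAttr({0, 1, 2}, 3) is true, while IsIotaAttr({0, 2}, 2)
// and IsIotaAttr({0, 1}, 3) are false.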
bool IsIotaAttr(ArrayRef<int64_t> arr, int64_t size) {
  if (arr.size() != size) return false;
  int64_t iota = 0;
  for (auto s : arr) {
    if (s != iota) return false;
    ++iota;
  }
  return true;
}

// Convert updates into canonical form as expected by tf.scatter ops.
//
// tf.scatter expects `update_window_dims` to be the trailing dimensions.
//
// To support scatter ops generated by numpy-like slice updates:
//   nd_array[:, [i,j]] = [i_values, j_values]
//
// `updates` must be transposed when the update_window_dims are the leading
// dimensions of `updates`.
//
// Other values of `update_window_dims` are left unsupported.
//
// Eg 1. An update in canonical form:
//  * indices shape(A,B,C)
//  * updates shape(A,B,D,E,F)
// Then:
//  * D,E,F are the update window dims [2,3,4]
//  * C is the index vector dimension
//  * A,B iterate over the updates and indices
//
// If `update_window_dims` are not the trailing dimensions then updates must be
// transposed.
//
// Eg 2. An update in non-canonical form:
//  * indices shape(a,b,c)
//  * updates shape(d,e,f,a,b)
// Then:
//  * d,e,f are the update window dims [0,1,2]
//  * c is the index vector dimension
//  * a,b iterate over the updates and indices
//
// The update needs permuting to be in the form (a,b,d,e,f) so that the update
// window dims are the trailing dimensions.
//
// To canonicalize the updates above, replace the updates with:
//   transpose(updates, permutation={3,4,0,1,2})
//
// Note: NormalizeIndexVector is assumed to have run on the indices already so
// that the index_vector_dim is the trailing dimension in `indices`.
LogicalResult CanonicalizeScatterUpdates(
    Operation *scatter_op, llvm::ArrayRef<int64_t> update_window_dims,
    const Value &indices, const ShapedType &indices_type, Value &updates,
    ShapedType &updates_type, ConversionPatternRewriter &rewriter) {
  auto canonical_update_window_dims = llvm::to_vector(
      llvm::seq<int64_t>(indices_type.getRank() - 1, updates_type.getRank()));

  if (canonical_update_window_dims == update_window_dims) return success();

  // Permute updates if `update_window_dims` are leading indices.
  // Other possibilities for `update_window_dims` are not supported yet.
  if (!IsIotaAttr(update_window_dims, update_window_dims.size()))
    return rewriter.notifyMatchFailure(
        scatter_op, "update_window_dims are not leading or trailing indices");

  SmallVector<int64_t, 4> permutation_array(updates_type.getRank());
  int64_t dim = 0;
  // Move leading indices to the back of the array.
  const auto permutation_array_size = permutation_array.size();
  for (int64_t i = update_window_dims.size(); i < permutation_array_size; ++i) {
    permutation_array[i] = dim;
    ++dim;
  }
  // Move trailing indices to the front of the array.
  for (int64_t i = 0; i < update_window_dims.size(); ++i) {
    permutation_array[i] = dim;
    ++dim;
  }

  auto permutation_and_shape = GetPermutationAndTransposedShape(
      permutation_array, updates_type, rewriter);

  auto transposed_updates = rewriter.create<mhlo::TransposeOp>(
      scatter_op->getLoc(), permutation_and_shape.shape, updates,
      permutation_and_shape.permutation);

  updates = transposed_updates;
  updates_type = permutation_and_shape.shape;
  return success();
}

// If index_vector_dim == indices.rank() then insert the implicit extra
// dimension into indices to normalize everything to index_vector_dim ==
// indices.rank() - 1.
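// For example (shapes illustrative): indices of type tensor<4x3xi32> with
// index_vector_dim == 2 are reshaped to tensor<4x3x1xi32>, so the index
// vector dimension becomes the last dimension.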
LogicalResult NormalizeIndexVector(Operation *parent_op, Value &indices,
                                   ShapedType &indices_type,
                                   int64_t index_vector_dim,
                                   ConversionPatternRewriter &rewriter) {
  if (index_vector_dim == indices_type.getRank()) {
    llvm::SmallVector<int64_t, 4> new_start_indices_shape(
        indices_type.getShape().begin(), indices_type.getShape().end());
    new_start_indices_shape.push_back(1);
    indices_type = RankedTensorType::get(new_start_indices_shape,
                                         indices_type.getElementType());
    indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(),
                                               indices_type, indices);
  } else if (index_vector_dim != indices_type.getRank() - 1) {
    // If index_vector_dim isn't the last dimension in indices then it isn't
    // supported yet.
    // TODO(tberghammer): Transpose indices to support this usecase.
    return rewriter.notifyMatchFailure(
        parent_op,
        "index vector dim isn't the last dimension in start indices");
  }
  return success();
}

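// Converts an mhlo.gather op with statically shaped operands into a
// tf.GatherNd op, transposing the operand to handle non-iota start index maps
// and transposing the result when the offset dims are not already the
// trailing dimensions of the output.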
class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  // Helper struct holding the parameters of the transpose from the
  // "canonicalized" output to the real output.
  struct TransposeParams {
    std::vector<int64_t> permutation;
    // The following are the "canonicalized" output shape with offset dims.
    std::vector<int64_t> canonicalized_output_shape;
    std::vector<int64_t> canonicalized_offset_dims;
  };

  LogicalResult matchAndRewrite(
      mhlo::GatherOp gather_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    Value operand = gather_op.operand();
    Value start_indices = gather_op.start_indices();

    // Can only convert gathers with static shapes.
    ShapedType operand_type = operand.getType().cast<ShapedType>();
    ShapedType start_indices_type = start_indices.getType().cast<ShapedType>();
    ShapedType result_type = gather_op.getResult().getType().cast<ShapedType>();
    if (!operand_type.hasStaticShape() ||
        !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) {
      return failure();
    }

    // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
    int64_t index_vector_dim =
        gather_op.dimension_numbers().getIndexVectorDim();
    if (failed(NormalizeIndexVector(gather_op, start_indices,
                                    start_indices_type, index_vector_dim,
                                    rewriter))) {
      return failure();
    }

    // Verify that start_index_map and collapsed_slice_dims contain the same
    // values.
    auto start_index_map = gather_op.dimension_numbers().getStartIndexMap();
    auto collapsed_slice_dims =
        gather_op.dimension_numbers().getCollapsedSliceDims();
    if (start_index_map.size() != collapsed_slice_dims.size()) {
      return rewriter.notifyMatchFailure(
          gather_op,
          "different size for start index map and collapsed slice dims");
    }
    for (auto c : collapsed_slice_dims) {
      if (llvm::count(start_index_map, c) == 0) {
        return rewriter.notifyMatchFailure(
            gather_op, "collapsed slice dim isn't present in start index map");
      }
    }

    // Verify that slice_sizes is 1 for the indexed dimensions and the full
    // shape for the rest of the dimensions.
    auto slice_sizes = gather_op.slice_sizes();
    int64_t index = 0;
    for (int64_t s : slice_sizes.getValues<int64_t>()) {
      if (llvm::count(start_index_map, index)) {
        if (s != 1) {
          return rewriter.notifyMatchFailure(gather_op,
                                             "unsupported slice sizes");
        }
      } else {
        if (s != operand_type.getShape()[index]) {
          return rewriter.notifyMatchFailure(gather_op,
                                             "unsupported slice sizes");
        }
      }
      ++index;
    }

    // Verify that offset_dims are the trailing dimensions in the output
    // tensor.
    auto offset_dims = gather_op.dimension_numbers().getOffsetDims();
    SmallVector<int64_t, 4> offset_dims_vector(offset_dims.begin(),
                                               offset_dims.end());
    const TransposeParams &transpose_params =
        CanonicalizeOffset(/*result_type=*/result_type,
                           /*original_offset_dims=*/offset_dims_vector);

    int64_t offset = start_indices_type.getRank() - 1;
    for (int64_t o : transpose_params.canonicalized_offset_dims) {
      if (o != offset) {
        return rewriter.notifyMatchFailure(gather_op,
                                           "unsupported offset dims");
      }
      ++offset;
    }

    // Transpose the operand to handle non-iota start index map.
    llvm::SmallVector<int64_t, 4> transpose_dimensions;
    llvm::SmallVector<int64_t, 4> transpose_shape;
    for (auto s : start_index_map) {
      transpose_dimensions.push_back(s);
      transpose_shape.push_back(operand_type.getShape()[s]);
    }
    for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
      if (llvm::count(start_index_map, i) == 0) {
        transpose_dimensions.push_back(i);
        transpose_shape.push_back(operand_type.getShape()[i]);
      }
    }
    operand_type =
        RankedTensorType::get(transpose_shape, operand_type.getElementType());
    operand = rewriter.create<mhlo::TransposeOp>(
        gather_op.getLoc(), operand_type, operand,
        rewriter.getI64TensorAttr(transpose_dimensions));

    // Check whether we need to append a transpose op after the gather nd.
    bool need_transpose_after = false;
    for (int i = 0; i < transpose_params.permutation.size(); ++i) {
      if (i != transpose_params.permutation[i]) {
        need_transpose_after = true;
        break;
      }
    }

    auto tf_gather_nd_result_type =
        RankedTensorType::get(transpose_params.canonicalized_output_shape,
                              result_type.getElementType());
    auto tf_gather_nd_op = rewriter.create<TF::GatherNdOp>(
        gather_op->getLoc(), tf_gather_nd_result_type, operand, start_indices);
    if (!need_transpose_after) {
      rewriter.replaceOp(gather_op, tf_gather_nd_op->getOpResults());
      return success();
    }

    // Insert the transpose op after the gather_nd.
    rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
        gather_op, result_type, tf_gather_nd_op,
        rewriter.getI64TensorAttr(transpose_params.permutation));

    return success();
  }

 private:
  // Canonicalize the offset dims to make sure the offset dims are the trailing
  // dimensions of the output tensor. Also returns the permutation for the
  // corresponding transpose op. Note, however, that canonicalizing the offset
  // dims does not guarantee that the result is always legalizable to TF.
  TransposeParams CanonicalizeOffset(
      ShapedType result_type, ArrayRef<int64_t> original_offset_dims) const {
    TransposeParams transpose_params;
    int output_rank = result_type.getRank();
    // The canonicalized offset dims should be the trailing dimensions of the
    // output.
    for (int start = output_rank - original_offset_dims.size();
         start < output_rank; ++start) {
      transpose_params.canonicalized_offset_dims.push_back(start);
    }

    // Dims NOT inside the original_offset_dims are considered "batch dims".
    std::vector<int64_t> batch_dims;
    // Offset dims are guaranteed to be sorted.
    int offset_index = 0;
    for (int64_t i = 0; i < output_rank; ++i) {
      if (offset_index >= original_offset_dims.size() ||
          original_offset_dims[offset_index] != i) {
        batch_dims.push_back(i);
      } else {
        ++offset_index;
      }
    }

    // Populate the transpose permutation params from a "canonicalized" output
    // to the real output.
    // The canonicalized layout would be batch_dims followed by sliced_dims.
    // The current layout is essentially a transpose after the canonicalized
    // layout.
    // Take the following as an example:
    // If we have:
    //   original_offset_dims like [1, 2, 4]
    //   batch_dims like [0, 3]
    // it's like performing a transpose on a "canonicalized"
    // [batch_dims, sliced_dims]: [B1, B2, O1, O2, O3]
    // into the current layout: [B1, O1, O2, B2, O3]
    // where the permutation is [0, 2, 3, 1, 4].
    int batch_idx = 0;
    int offset_idx = 0;
    int batch_dim_size = batch_dims.size();
    for (int i = 0; i < output_rank; ++i) {
      if (batch_idx >= batch_dims.size()) {
        transpose_params.permutation.push_back(batch_dim_size + offset_idx);
        ++offset_idx;
      } else if (offset_idx < original_offset_dims.size() &&
                 original_offset_dims[offset_idx] < batch_dims[batch_idx]) {
        transpose_params.permutation.push_back(batch_dim_size + offset_idx);
        ++offset_idx;
      } else {
        transpose_params.permutation.push_back(batch_idx++);
      }
    }

    // Finally, compute what the "canonicalized" output shape looks like.
    for (auto dim : batch_dims) {
      transpose_params.canonicalized_output_shape.push_back(
          result_type.getDimSize(dim));
    }
    for (auto dim : original_offset_dims) {
      transpose_params.canonicalized_output_shape.push_back(
          result_type.getDimSize(dim));
    }
    return transpose_params;
  }
};

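// Converts an mhlo.while op with non-tuple operands into a tf.WhileRegion op,
// moving the cond and body regions over and rewriting their mhlo.return
// terminators to tf.Yield.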
class ConvertWhileOp : public OpConversionPattern<mhlo::WhileOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::WhileOp while_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // HLO WhileOp should have two regions: cond and body.
    if (while_op->getNumRegions() != 2) return failure();

    // This rule doesn't support mhlo::WhileOp with tuple inputs.
    for (auto type : while_op->getOperandTypes()) {
      if (type.isa<TupleType>()) return failure();
    }

    // Creates a TF::WhileRegionOp to replace the mhlo::WhileOp. HLO WhileOp
    // currently doesn't support the stateless and shape invariant attributes,
    // so these parameters are set to the default values.
    auto new_while = rewriter.create<TF::WhileRegionOp>(
        while_op.getLoc(), while_op->getResultTypes(), while_op->getOperands(),
        /*parallel_iterations=*/10,
        /*is_stateless=*/false, /*shape_invariant=*/false);
    new_while.cond().takeBody(while_op.cond());
    new_while.body().takeBody(while_op.body());
    ReplaceReturnOp(new_while.cond(), rewriter);
    ReplaceReturnOp(new_while.body(), rewriter);
    rewriter.replaceOp(while_op, new_while.getResults());
    return success();
  }
};

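// Converts an mhlo.if op into a tf.IfRegion op, moving the true and false
// branches over and rewriting their mhlo.return terminators to tf.Yield.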
class ConvertIfOp : public OpConversionPattern<mhlo::IfOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::IfOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // HLO IfOp currently doesn't support the stateless attribute, so it is
    // set to the default value.
    auto new_op = rewriter.create<TF::IfRegionOp>(
        op.getLoc(), op->getResultTypes(), op.pred(),
        /*is_stateless=*/false, /*_then_func_name=*/nullptr,
        /*_else_func_name=*/nullptr);
    new_op.then_branch().takeBody(op.true_branch());
    new_op.else_branch().takeBody(op.false_branch());
    ReplaceReturnOp(new_op.then_branch(), rewriter);
    ReplaceReturnOp(new_op.else_branch(), rewriter);
    rewriter.replaceOp(op, new_op.getResults());
    return success();
  }
};

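// Converts an mhlo.scatter op whose update computation matches `BinaryOp`
// into the TF tensor-scatter op `TfOp`, normalizing the indices and
// canonicalizing the updates first, and transposing the operand when the
// scatter doesn't already target its leading dimensions.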
template <typename BinaryOp, typename TfOp>
class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp scatter_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    OperandRange operands = scatter_op.operands();
    Value indices = scatter_op.scatter_indices();
    OperandRange updates = scatter_op.updates();
    if (operands.size() != 1 || updates.size() != 1) return failure();

    ShapedType operand_type = operands[0].getType().cast<ShapedType>();
    ShapedType indices_type = indices.getType().cast<ShapedType>();
    ShapedType updates_type = updates[0].getType().cast<ShapedType>();

    Value new_updates = updates[0];

    // Can only convert with static shaped scatter.
    if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
        !updates_type.hasStaticShape()) {
      return failure();
    }

    // Match the scatter computation against computations supported by TF.
    if (failed(MatchBinaryReduceFunction<BinaryOp>(
            scatter_op.update_computation()))) {
      return failure();
    }

    auto scatter_dimension_numbers = scatter_op.scatter_dimension_numbers();

    // Normalize indices so index_vector_dim == indices.rank() - 1.
    int64_t index_vector_dim = scatter_dimension_numbers.getIndexVectorDim();
    if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
                                    index_vector_dim, rewriter))) {
      return failure();
    }

    // Transform updates so that update window dims are the trailing dimensions
    // in the update tensor.
    auto update_window_dims = scatter_dimension_numbers.getUpdateWindowDims();
    if (failed(CanonicalizeScatterUpdates(scatter_op, update_window_dims,
                                          indices, indices_type, new_updates,
                                          updates_type, rewriter))) {
      return failure();
    }

    auto inserted_window_dims =
        scatter_dimension_numbers.getInsertedWindowDims();
    auto scatter_dims_to_operand_dims =
        scatter_dimension_numbers.getScatterDimsToOperandDims();

    if (IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) &&
        IsIotaAttr(scatter_dims_to_operand_dims,
                   indices_type.getShape().back())) {
      rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
                                        scatter_op.getResult(0).getType(),
                                        operands[0], indices, new_updates);
      return success();
    }
    // Insert transposes to support scatter operations generated from
    // numpy-like slice operations:
    //   nd_array[:, [i,j]] = [i_values, j_values]
    if (scatter_dims_to_operand_dims != inserted_window_dims) {
      // Support only dimension numbers generated by numpy-like slice
      // operations.
      return rewriter.notifyMatchFailure(
          scatter_op, "unsupported scatter_dims_to_operand_dims");
    }

    // Transpose the operand so that the trailing dimensions of the
    // operand are being updated. Then apply a tf.scatter op and transpose
    // back the result to get the same shape as the original operand.
    SmallVector<int64_t, 4> permutation_array;
    for (int64_t i = 0; i < scatter_dims_to_operand_dims.size(); ++i) {
      permutation_array.push_back(scatter_dims_to_operand_dims[i]);
    }
    for (int64_t i = 0; i < operand_type.getRank(); ++i) {
      if (!llvm::is_contained(scatter_dims_to_operand_dims, i)) {
        permutation_array.push_back(i);
      }
    }
    auto permutation_and_shape = GetPermutationAndTransposedShape(
        permutation_array, operand_type, rewriter);

    Location loc = scatter_op.getLoc();
    auto transposed_operand = rewriter.create<mhlo::TransposeOp>(
        loc, permutation_and_shape.shape, operands[0],
        permutation_and_shape.permutation);

    // Apply TF scatter to update the trailing dimensions of the
    // transposed operand.
    auto tf_scatter_op =
        rewriter.create<TfOp>(loc, permutation_and_shape.shape,
                              transposed_operand, indices, new_updates);

    // Reverse the earlier transpose.
    auto inverse_permutation =
        GetInversePermutation(permutation_array, rewriter);
    rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
        scatter_op, scatter_op.getResult(0).getType(), tf_scatter_op,
        inverse_permutation);

    return success();
  }
};
using ConvertScatterAddOp =
    ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
using ConvertScatterMaxOp =
    ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
using ConvertScatterMinOp =
    ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
using ConvertScatterSubOp =
    ConvertScatterOp<mhlo::SubtractOp, TF::TensorScatterSubOp>;
using ConvertScatterUpdateOp =
    ConvertScatterOp<void, TF::TensorScatterUpdateOp>;

// Converts mhlo.pad to tf.PadV2.
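// The paddings operand fed to tf.PadV2 has shape [rank, 2], one (low, high)
// pair per dimension. For example (values illustrative), edge_padding_low =
// [0, 1] and edge_padding_high = [0, 2] become the constant [[0, 0], [1, 2]].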
Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
  auto pad_op = cast<mhlo::PadOp>(old_op);
  mlir::Location loc = pad_op.getLoc();

  llvm::SmallVector<APInt, 8> padding;
  for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
                          pad_op.edge_padding_high().getValues<APInt>())) {
    padding.push_back(std::get<0>(p));
    padding.push_back(std::get<1>(p));
  }
  auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
                                         rewriter.getI64Type());
  auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
  auto padding_op =
      rewriter.create<arith::ConstantOp>(loc, attr_type, padding_attr);
  return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
                                  padding_op, pad_op.padding_value());
}

// Returns true if broadcast_dimensions obey the TensorFlow convention, i.e.
// new dimensions are added as a prefix.
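// For example, broadcasting tensor<2x3xf32> to tensor<5x2x3xf32> with
// broadcast_dimensions = [1, 2] is TF-style, because the first broadcast
// dimension equals output_rank - input_rank == 1.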
bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
                        Value output) {
  // broadcast_dimensions is an increasing list by definition, thus it suffices
  // to check the first element.
  int64_t input_rank = broadcast_dimensions.getNumElements();
  int64_t output_rank = output.getType().cast<ShapedType>().getRank();
  return input_rank == 0 ||
         (broadcast_dimensions.getValues<APInt>()[0].getSExtValue() ==
          output_rank - input_rank);
}

// Returns the intermediate shape that the input tensor should be reshaped to
// during legalization of BroadcastInDimOp.
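// For example (shapes illustrative): an input of type tensor<3x4xf32> with
// broadcast_dimensions = [0, 2] and an output of rank 4 yields the constant
// shape [3, 1, 4, 1].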
arith::ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input,
                                DenseIntElementsAttr broadcast_dimensions,
                                Value output) {
  // Initialize expanded shape with output rank and dimensions of 1.
  SmallVector<Attribute, 4> expanded_shape(
      output.getType().cast<ShapedType>().getRank(),
      /*Value=*/rewriter.getI64IntegerAttr(1));

  // Set dimension sizes specified by broadcast_dimensions.
  ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
  for (auto x : llvm::enumerate(broadcast_dimensions)) {
    expanded_shape[x.value().getSExtValue()] =
        rewriter.getI64IntegerAttr(input_shape[x.index()]);
  }

  // Create the expanded type wrapped in an arith::ConstantOp.
  auto attr_type =
      RankedTensorType::get({static_cast<int64_t>(expanded_shape.size())},
                            rewriter.getIntegerType(64));
  auto attr = DenseElementsAttr::get(attr_type, expanded_shape);
  return rewriter.create<arith::ConstantOp>(output.getLoc(), attr_type, attr);
}

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"

/// Performs the legalization to the TF dialect.
void LegalizeHloToTf::runOnOperation() {
  MLIRContext &context = getContext();

  // Add legalization patterns to the list.
  RewritePatternSet patterns(&getContext());
  PopulateLegalizeHloToTfPatterns(&patterns, &context);

  ConversionTarget target(context);
  target.addLegalDialect<TensorFlowDialect>();
  target.addLegalOp<func::CallOp, func::ConstantOp, arith::ConstantOp>();
  target.addLegalOp<mhlo::TupleOp>();
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns)))) {
    getOperation().emitError("mhlo to TF legalization failed.");
    signalPassFailure();
  }
}

}  // end namespace

void PopulateLegalizeHloToTfPatterns(RewritePatternSet *patterns,
                                     MLIRContext *context) {
  patterns->add<
      ConvertAvgPoolOp, Convert2DConvOp, Convert1DConvOp,
      ConvertNonTrivialConvOp, ConvertDynamicSliceOp,
      ConvertDynamicUpdateSliceOp, ConvertGatherOp, ConvertIfOp,
      ConvertMaxPoolOp, ConvertScatterAddOp, ConvertScatterMaxOp,
      ConvertScatterMinOp, ConvertScatterSubOp, ConvertScatterUpdateOp,
      ConvertSliceOp, ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin,
      ConvertReduceOpToTfMax, ConvertReduceOpToTfMin, ConvertReduceOpToTfAll,
      ConvertReduceOpToTfAny, ConvertReduceOpToTfSum, ConvertSortToTfTopk,
      ConvertIotaOpToTfRange, ConvertWhileOp, ConvertLoweredCumSumOp,
      ConvertLoweredCumProdOp>(context);
  populateWithGenerated(*patterns);
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateLegalizeHloToTfPass() {
  return std::make_unique<LegalizeHloToTf>();
}

}  // end namespace TF
}  // end namespace mlir