xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for legalizing HLO to TensorFlow.
17 
18 #include <cstddef>
19 #include <cstdint>
20 #include <cstdlib>
21 #include <functional>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/None.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/STLForwardCompat.h"
34 #include "llvm/ADT/Sequence.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
40 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
41 #include "mlir/IR/Attributes.h"  // from @llvm-project
42 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
43 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
45 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
46 #include "mlir/IR/Location.h"  // from @llvm-project
47 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
48 #include "mlir/IR/Matchers.h"  // from @llvm-project
49 #include "mlir/IR/Operation.h"  // from @llvm-project
50 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
51 #include "mlir/IR/Region.h"  // from @llvm-project
52 #include "mlir/IR/Value.h"  // from @llvm-project
53 #include "mlir/Pass/Pass.h"  // from @llvm-project
54 #include "mlir/Support/LLVM.h"  // from @llvm-project
55 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
56 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
57 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
58 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
59 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
60 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
61 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
62 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/broadcast_utils.h"
63 #include "tensorflow/core/framework/kernel_shape_util.h"
64 #include "tensorflow/core/lib/math/math_util.h"
65 
66 namespace mlir {
67 namespace TF {
68 namespace {
69 
70 using mhlo::DotDimensionNumbersAttr;
71 
72 // Replaces `region`'s terminator to TF::Yield.
ReplaceReturnOp(Region & region,PatternRewriter & rewriter)73 void ReplaceReturnOp(Region &region, PatternRewriter &rewriter) {
74   OpBuilder::InsertionGuard guard(rewriter);
75 
76   for (auto &block : region.getBlocks()) {
77     Operation *terminator = block.getTerminator();
78     auto return_op = llvm::dyn_cast_or_null<mhlo::ReturnOp>(terminator);
79     if (return_op == nullptr) continue;
80 
81     rewriter.setInsertionPoint(return_op);
82     rewriter.replaceOpWithNewOp<TF::YieldOp>(return_op,
83                                              return_op->getOperands());
84   }
85 }
86 
87 // If `value` is a splat constant, returns a success and set `splat_value`
88 // to the splate constant value.
89 // `SplatValueType` can be `APInt` or `APFloat`.
90 template <typename SplatValueType>
GetConstantSplatValue(Value value,SplatValueType & splat_value)91 LogicalResult GetConstantSplatValue(Value value, SplatValueType &splat_value) {
92   DenseElementsAttr attr;
93   if (!matchPattern(value, m_Constant(&attr)) || !attr.isSplat()) {
94     return failure();
95   }
96 
97   splat_value = attr.getSplatValue<SplatValueType>();
98   return success();
99 }
100 
101 struct PermutationAndShape {
102   DenseIntElementsAttr permutation;
103   ShapedType shape;
104 };
105 
106 // Returns a DenseIntElementsAttr for a permutation and the shape after
107 // applying the permutation to a given shape through a transpose.
GetPermutationAndTransposedShape(llvm::ArrayRef<int64_t> permutation_array,ShapedType input_type,ConversionPatternRewriter & rewriter)108 PermutationAndShape GetPermutationAndTransposedShape(
109     llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
110     ConversionPatternRewriter &rewriter) {
111   assert(permutation_array.size() == input_type.getRank());
112   llvm::SmallVector<int64_t> transposed_shape(permutation_array.size());
113   for (int64_t i = 0; i < permutation_array.size(); ++i) {
114     transposed_shape[i] = input_type.getDimSize(permutation_array[i]);
115   }
116   auto transposed_type =
117       RankedTensorType::get(transposed_shape, input_type.getElementType());
118   DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
119       RankedTensorType::get(permutation_array.size(), rewriter.getI64Type()),
120       permutation_array);
121   return {permutation, transposed_type};
122 }
123 
124 // Returns the inverse permutation array for a permutation array.
GetInversePermutationArray(llvm::ArrayRef<int64_t> permutation_array)125 llvm::SmallVector<int64_t> GetInversePermutationArray(
126     llvm::ArrayRef<int64_t> permutation_array) {
127   llvm::SmallVector<int64_t> inverse_permutation_array(
128       permutation_array.size());
129   const auto permutation_array_size = permutation_array.size();
130   for (int64_t i = 0; i < permutation_array_size; ++i) {
131     inverse_permutation_array[permutation_array[i]] = i;
132   }
133   return inverse_permutation_array;
134 }
135 
136 // Returns the DenseIntElementsAttr for an inverse permutation given a
137 // permutation_array.
GetInversePermutation(llvm::ArrayRef<int64_t> permutation_array,ConversionPatternRewriter & rewriter)138 DenseIntElementsAttr GetInversePermutation(
139     llvm::ArrayRef<int64_t> permutation_array,
140     ConversionPatternRewriter &rewriter) {
141   SmallVector<int64_t, 4> inverse_permutation_array =
142       GetInversePermutationArray(permutation_array);
143   return DenseIntElementsAttr::get(
144       RankedTensorType::get(inverse_permutation_array.size(),
145                             rewriter.getI64Type()),
146       inverse_permutation_array);
147 }
148 
149 // Returns a DenseIntElementsAttr for an inverse permutation and the shape after
150 // applying the inverse permutation to a given shape through a transpose.
GetInversePermutationAndShape(llvm::ArrayRef<int64_t> permutation_array,ShapedType input_type,ConversionPatternRewriter & rewriter)151 PermutationAndShape GetInversePermutationAndShape(
152     llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
153     ConversionPatternRewriter &rewriter) {
154   SmallVector<int64_t, 4> inverse_permutation_array =
155       GetInversePermutationArray(permutation_array);
156   return GetPermutationAndTransposedShape(inverse_permutation_array, input_type,
157                                           rewriter);
158 }
159 
160 // Common functionality for ConvertConvOp classes.
161 template <int SupportedSpatialDims>
162 struct ConvertNdConvOp {
IsSupportedConvOpmlir::TF::__anonab43b57b0111::ConvertNdConvOp163   bool IsSupportedConvOp(mhlo::ConvolutionOp conv_op) const {
164     if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
165         !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
166         !conv_op.getType().cast<ShapedType>().hasStaticShape())
167       return false;
168 
169     // All ones in "lhs_dilation" means this "mhlo.conv" op should be
170     // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp".
171     if (conv_op.lhs_dilation().has_value()) {
172       auto lhs_dilation = conv_op.lhs_dilation().getValue();
173       if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue<int64_t>() != 1)
174         return false;
175     }
176 
177     if (!conv_op.window_strides().has_value() || conv_op.window_strides()
178                                                          .getValue()
179                                                          .getType()
180                                                          .cast<ShapedType>()
181                                                          .getRank() != 1)
182       return false;
183 
184     auto num_spatial_dims =
185         conv_op.dimension_numbers().getInputSpatialDimensions().size();
186     // TODO(b/158636600): Currently we don't support 3D Convolution.
187     if (num_spatial_dims != SupportedSpatialDims) return false;
188 
189     return true;
190   }
191 };
192 
193 // Convert a 1-D convolution into a 2-D convolution (which TF supports) so that
194 // it can be rewritten by the pattern `Convert2DConvOp`.
195 class Convert1DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
196                         ConvertNdConvOp<1> {
197  public:
198   using OpConversionPattern::OpConversionPattern;
199 
matchAndRewrite(mhlo::ConvolutionOp conv_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const200   LogicalResult matchAndRewrite(
201       mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
202       ConversionPatternRewriter &rewriter) const final {
203     //
204     // Check that input is a supported 1d convolution.
205     //
206 
207     if (!IsSupportedConvOp(conv_op) || conv_op->getNumResults() != 1)
208       return rewriter.notifyMatchFailure(conv_op, "unsupported conv op.");
209 
210     const mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
211 
212     // Group convolution is not supported yet.
213     const int64_t input_feature_dimension = dnums.getInputFeatureDimension();
214     const int64_t input_channels =
215         conv_op.lhs().getType().cast<ShapedType>().getDimSize(
216             input_feature_dimension);
217     const int64_t feature_group_count = conv_op.feature_group_count();
218     if (feature_group_count != 1 && feature_group_count != input_channels)
219       return rewriter.notifyMatchFailure(conv_op,
220                                          "Group convolution is not supported,");
221 
222     //
223     // Transpose and reshape the input and kernel
224     //
225 
226     // Reshape input image to add a new spatial dimension.
227     auto image_type = conv_op.lhs().getType().cast<ShapedType>();
228     SmallVector<int64_t, 4> image_2d_shape(image_type.getShape().begin(),
229                                            image_type.getShape().end());
230     image_2d_shape.push_back(1);
231     auto image_2d_type =
232         RankedTensorType::get(image_2d_shape, image_type.getElementType());
233     auto image_2d_op = rewriter.create<mhlo::ReshapeOp>(
234         conv_op.getLoc(), image_2d_type, conv_op.lhs());
235 
236     // Transpose image to get it into NWHC form (where H is the added dim).
237     SmallVector<int64_t, 4> image_permutation = {
238         dnums.getInputBatchDimension(), dnums.getInputSpatialDimensions()[0],
239         3,  // The trailing dim that we added.
240         dnums.getInputFeatureDimension()};
241     auto image_permutation_and_shape = GetPermutationAndTransposedShape(
242         image_permutation, image_2d_type, rewriter);
243     auto transposed_image_2d_op = rewriter.create<mhlo::TransposeOp>(
244         conv_op.getLoc(), image_permutation_and_shape.shape,
245         image_2d_op->getResult(0), image_permutation_and_shape.permutation);
246 
247     // Reshape kernel to add a new spatial dimension.
248     auto kernel_type = conv_op.rhs().getType().cast<ShapedType>();
249     SmallVector<int64_t, 4> kernel_2d_shape;
250     for (int64_t dim : kernel_type.getShape()) {
251       kernel_2d_shape.push_back(dim);
252     }
253     kernel_2d_shape.push_back(1);
254     auto kernel_2d_type =
255         RankedTensorType::get(kernel_2d_shape, kernel_type.getElementType());
256     auto kernel_2d_op = rewriter.create<mhlo::ReshapeOp>(
257         conv_op.getLoc(), kernel_2d_type, conv_op.rhs());
258 
259     // Transpose kernel to get it into WHIO form (where H is the added dim).
260     SmallVector<int64_t, 4> kernel_permutation = {
261         dnums.getKernelSpatialDimensions()[0],
262         3,  // The trailing dim that we added.
263         dnums.getKernelInputFeatureDimension(),
264         dnums.getKernelOutputFeatureDimension()};
265     auto kernel_permutation_and_shape = GetPermutationAndTransposedShape(
266         kernel_permutation, kernel_2d_type, rewriter);
267     auto transposed_kernel_2d_op = rewriter.create<mhlo::TransposeOp>(
268         conv_op.getLoc(), kernel_permutation_and_shape.shape,
269         kernel_2d_op->getResult(0), kernel_permutation_and_shape.permutation);
270 
271     //
272     // Create 2d equivalents for 1d convolution attributes.
273     //
274 
275     // Window Strides
276     SmallVector<int64_t, 2> window_strides_2d_array;
277     for (const auto v : conv_op.window_strides()->getValues<int64_t>()) {
278       window_strides_2d_array.emplace_back(v);
279     }
280     window_strides_2d_array.push_back(1);
281     auto window_strides_2d = DenseIntElementsAttr::get(
282         RankedTensorType::get({2}, rewriter.getI64Type()),
283         window_strides_2d_array);
284 
285     // Padding
286     SmallVector<int64_t, 4> padding_2d_array;
287     for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
288       padding_2d_array.emplace_back(v);
289     }
290     // The newly added spatial dimension requires zero left and right padding.
291     padding_2d_array.push_back(0);
292     padding_2d_array.push_back(0);
293     auto padding_2d = DenseIntElementsAttr::get(
294         RankedTensorType::get({2, 2}, rewriter.getI64Type()), padding_2d_array);
295 
296     // LHS dilation
297     SmallVector<int64_t, 4> lhs_dilation_array_2d;
298     for (const auto v :
299          conv_op.lhs_dilation().getValue().getValues<int64_t>()) {
300       lhs_dilation_array_2d.emplace_back(v);
301     }
302     lhs_dilation_array_2d.push_back(1);
303     auto lhs_dilation_2d = DenseIntElementsAttr::get(
304         RankedTensorType::get({2}, rewriter.getI64Type()),
305         lhs_dilation_array_2d);
306 
307     // RHS dilation
308     SmallVector<int64_t, 4> rhs_dilation_array_2d;
309     for (const auto v :
310          conv_op.rhs_dilation().getValue().getValues<int64_t>()) {
311       rhs_dilation_array_2d.emplace_back(v);
312     }
313     rhs_dilation_array_2d.push_back(1);
314     auto rhs_dilation_2d = DenseIntElementsAttr::get(
315         RankedTensorType::get({2}, rewriter.getI64Type()),
316         rhs_dilation_array_2d);
317 
318     // Window reversal is unsupported.
319     if (conv_op.window_reversal().has_value() &&
320         conv_op.window_reversal()->getValues<bool>()[0] == true)
321       return failure();
322     auto window_reversal_2d = DenseIntElementsAttr::get(
323         RankedTensorType::get({2}, rewriter.getI64Type()),
324         SmallVector<int64_t>({0, 0}));
325 
326     // Precision config
327     if (!conv_op.precision_config().has_value()) return failure();
328 
329     // Dimension numbers reflect the form of the 2d conv op NWHC * WHIO -> NWHC
330     auto dnums_2d =
331         mhlo::ConvDimensionNumbersAttr::get(rewriter.getContext(),
332                                             /*inputBatchDimension=*/0,
333                                             /*inputFeatureDimension=*/3,
334                                             /*inputSpatialDimensions=*/{1, 2},
335                                             /*kernelInputDimension=*/2,
336                                             /*kernelOutputDimension=*/3,
337                                             /*kernelSpatialDimensions=*/{0, 1},
338                                             /*outputBatchDimension=*/0,
339                                             /*outputFeatureDimension=*/3,
340                                             /*outputSpatialDimensions=*/{1, 2});
341     //
342     // Generate a 2-D convolution
343     //
344 
345     // Determine the 2-D convolution output shape.
346     auto output_type = conv_op->getResult(0).getType().cast<ShapedType>();
347     SmallVector<int64_t, 4> output_2d_shape;
348     for (int64_t dim : output_type.getShape()) {
349       output_2d_shape.push_back(dim);
350     }
351     output_2d_shape.push_back(1);
352     auto output_2d_type =
353         RankedTensorType::get(output_2d_shape, output_type.getElementType());
354     SmallVector<int64_t, 4> output_permutation = {
355         dnums.getOutputBatchDimension(), dnums.getOutputSpatialDimensions()[0],
356         3,  // The trailing dim that we added.
357         dnums.getOutputFeatureDimension()};
358     auto transposed_output_2d_shape =
359         GetPermutationAndTransposedShape(output_permutation, output_2d_type,
360                                          rewriter)
361             .shape;
362 
363     auto conv2d_op = rewriter.create<mhlo::ConvolutionOp>(
364         conv_op.getLoc(), transposed_output_2d_shape,
365         transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(),
366         window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d,
367         window_reversal_2d, dnums_2d, conv_op.feature_group_count(),
368         conv_op.batch_group_count(), *conv_op.precision_config());
369 
370     OpResult conv2d_output = conv2d_op->getResult(0);
371     auto conv2d_output_type = conv2d_output.getType().cast<ShapedType>();
372 
373     //
374     // Transpose and reshape the output
375     //
376 
377     // Since output is in NWHC form we need to undo the permutation we have
378     // affectively applied.
379     auto output_permutation_and_shape = GetInversePermutationAndShape(
380         output_permutation, conv2d_output_type, rewriter);
381     auto transposed_output_2d_op = rewriter.create<mhlo::TransposeOp>(
382         conv_op.getLoc(), output_permutation_and_shape.shape, conv2d_output,
383         output_permutation_and_shape.permutation);
384 
385     // Drop the trailing spatial dimension from the output.
386     rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
387         conv_op, output_type, transposed_output_2d_op.getResult());
388     return success();
389   }
390 };
391 
392 class Convert2DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
393                         ConvertNdConvOp<2> {
394  public:
395   using OpConversionPattern::OpConversionPattern;
396 
matchAndRewrite(mhlo::ConvolutionOp conv_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const397   LogicalResult matchAndRewrite(
398       mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
399       ConversionPatternRewriter &rewriter) const final {
400     if (!IsSupportedConvOp(conv_op)) {
401       return failure();
402     }
403 
404     // Constructs strides array.
405     // For example, [2, 3] -> [1, 2, 3, 1].
406     SmallVector<int64_t, 4> strides({1});
407     for (const auto v :
408          conv_op.window_strides().getValue().getValues<int64_t>()) {
409       strides.emplace_back(v);
410     }
411     strides.emplace_back(1);
412 
413     // Constructs dilation array.
414     SmallVector<int64_t, 4> dilation;
415     if (auto rhs_dilation = conv_op.rhs_dilation()) {
416       // For example, [2, 3] -> [1, 2, 3, 1].
417       dilation.emplace_back(1);
418       dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
419                       rhs_dilation.getValue().getValues<int64_t>().end());
420       dilation.emplace_back(1);
421     } else {
422       // Default value
423       dilation = {1, 1, 1, 1};
424     }
425 
426     mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
427     const int input_feature_dimension = dnums.getInputFeatureDimension();
428     const int input_channels =
429         conv_op.lhs().getType().cast<ShapedType>().getDimSize(
430             input_feature_dimension);
431     int feature_group_count = conv_op.feature_group_count();
432 
433     if (feature_group_count != 1 && feature_group_count != input_channels) {
434       // Group convolution is not supported yet.
435       return failure();
436     }
437 
438     const int num_spatial_dims = dnums.getInputSpatialDimensions().size();
439     const bool is_depthwise_conv = input_channels == feature_group_count;
440     std::string padding;
441     SmallVector<int64_t, 8> explicit_padding;
442     if (!conv_op.padding().has_value() ||
443         (conv_op.padding().getValue().isSplat() &&
444          conv_op.padding()->getSplatValue<int64_t>() == 0)) {
445       padding = "VALID";
446     } else {
447       SmallVector<int64_t, 4> padding_array;
448       for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
449         padding_array.emplace_back(v);
450       }
451 
452       if (IsSamePadding(conv_op, num_spatial_dims, strides, dilation,
453                         padding_array)) {
454         // Check if padding is "SAME".
455         padding = "SAME";
456       } else {
457         padding = "EXPLICIT";
458         explicit_padding.push_back(0);
459         explicit_padding.push_back(0);
460         explicit_padding.append(padding_array);
461         explicit_padding.push_back(0);
462         explicit_padding.push_back(0);
463       }
464     }
465 
466     CreateConvOp(conv_op, strides, padding, explicit_padding, dilation,
467                  is_depthwise_conv, input_channels, num_spatial_dims, rewriter);
468     return success();
469   };
470 
471  private:
IsSamePadding(mhlo::ConvolutionOp conv_op,int num_spatial_dims,ArrayRef<int64_t> strides,ArrayRef<int64_t> dilation,ArrayRef<int64_t> padding_array) const472   bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims,
473                      ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
474                      ArrayRef<int64_t> padding_array) const {
475     mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
476     auto input_spatial_dim = dnums.getInputSpatialDimensions();
477     auto kernel_spatial_dim = dnums.getKernelSpatialDimensions();
478     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
479       int dim = i + 1;
480       int64_t output_size;
481       int64_t pad_low_int64;
482       int64_t pad_high_int64;
483       tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
484           conv_op.lhs().getType().cast<ShapedType>().getDimSize(
485               input_spatial_dim[i]),
486           conv_op.rhs().getType().cast<ShapedType>().getDimSize(
487               kernel_spatial_dim[i]),
488           dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size,
489           &pad_low_int64, &pad_high_int64);
490       if (!status.ok()) return false;
491       if (padding_array[2 * i] != pad_low_int64 ||
492           padding_array[2 * i + 1] != pad_high_int64)
493         return false;
494     }
495 
496     return true;
497   }
498 
499   // Returns true if the op needs reformat.
NeedsReformatTypeAndPermutation(int batch_dim,int feature_dim,int spatial_dim_start,int default_batch_dim,int default_feature_dim,int default_spatial_dim_start) const500   bool NeedsReformatTypeAndPermutation(int batch_dim, int feature_dim,
501                                        int spatial_dim_start,
502                                        int default_batch_dim,
503                                        int default_feature_dim,
504                                        int default_spatial_dim_start) const {
505     return batch_dim != default_batch_dim ||
506            feature_dim != default_feature_dim ||
507            spatial_dim_start != default_spatial_dim_start;
508   }
509 
510   // Gets reformat type and permutation attribute. Call this function only if
511   // NeedsReformatTypeAndPermutation returns true.
512   std::pair<RankedTensorType, DenseIntElementsAttr>
GetReformatTypeAndPermutation(int batch_dim,int feature_dim,int spatial_dim_start,int default_batch_dim,int default_feature_dim,int default_spatial_dim_start,int num_spatial_dims,RankedTensorType type,ConversionPatternRewriter & rewriter) const513   GetReformatTypeAndPermutation(int batch_dim, int feature_dim,
514                                 int spatial_dim_start, int default_batch_dim,
515                                 int default_feature_dim,
516                                 int default_spatial_dim_start,
517                                 int num_spatial_dims, RankedTensorType type,
518                                 ConversionPatternRewriter &rewriter) const {
519     auto shape = type.getShape();
520     llvm::SmallVector<int64_t, 4> permutation_array(num_spatial_dims + 2);
521     permutation_array[default_batch_dim] = batch_dim;
522     permutation_array[default_feature_dim] = feature_dim;
523     llvm::SmallVector<int64_t, 4> transposed_shape(num_spatial_dims + 2);
524     transposed_shape[default_batch_dim] = shape[batch_dim];
525     transposed_shape[default_feature_dim] = shape[feature_dim];
526     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
527       permutation_array[default_spatial_dim_start + i] = spatial_dim_start + i;
528       transposed_shape[default_spatial_dim_start + i] =
529           shape[spatial_dim_start + i];
530     }
531     auto new_type =
532         RankedTensorType::get(transposed_shape, type.getElementType());
533     auto permutation = DenseIntElementsAttr::get(
534         RankedTensorType::get({type.getRank()}, rewriter.getI64Type()),
535         permutation_array);
536     return {new_type, permutation};
537   }
538 
FormatToNHWC(Value value,int batch_dim,int feature_dim,ArrayRef<int64_t> spatial_dimensions,int default_batch_dim,int default_feature_dim,int default_spatial_dim_start,int num_spatial_dims,ConversionPatternRewriter & rewriter) const539   Value FormatToNHWC(Value value, int batch_dim, int feature_dim,
540                      ArrayRef<int64_t> spatial_dimensions,
541                      int default_batch_dim, int default_feature_dim,
542                      int default_spatial_dim_start, int num_spatial_dims,
543                      ConversionPatternRewriter &rewriter) const {
544     auto type = value.getType().cast<RankedTensorType>();
545     DenseIntElementsAttr permutation;
546     const int spatial_dim_start = spatial_dimensions.front();
547     if (!NeedsReformatTypeAndPermutation(
548             batch_dim, feature_dim, spatial_dim_start, default_batch_dim,
549             default_feature_dim, default_spatial_dim_start)) {
550       // Transpose is not needed because the current format is "NHWC".
551       return value;
552     }
553     std::pair<RankedTensorType &, DenseIntElementsAttr &>(type, permutation) =
554         GetReformatTypeAndPermutation(batch_dim, feature_dim, spatial_dim_start,
555                                       default_batch_dim, default_feature_dim,
556                                       default_spatial_dim_start,
557                                       num_spatial_dims, type, rewriter);
558     return rewriter.create<mhlo::TransposeOp>(value.getLoc(), type, value,
559                                               permutation);
560   }
561 
562   // Slices the input `value` if there are negative padding values in
563   // `explicit_padding`.
SliceNegativePadding(Value value,ArrayRef<int64_t> explicit_padding,ConversionPatternRewriter & rewriter) const564   Value SliceNegativePadding(Value value, ArrayRef<int64_t> explicit_padding,
565                              ConversionPatternRewriter &rewriter) const {
566     // If no padding is negative return the input as is.
567     if (llvm::all_of(explicit_padding, [](int64_t pad) { return pad >= 0; })) {
568       return value;
569     }
570 
571     auto input_type = value.getType().cast<RankedTensorType>();
572     auto input_shape = input_type.getShape();
573 
574     llvm::SmallVector<int64_t, 4> start;
575     llvm::SmallVector<int64_t, 4> size;
576     start.reserve(explicit_padding.size() / 2);
577     size.reserve(explicit_padding.size() / 2);
578     for (int i = 0, e = explicit_padding.size() / 2; i < e; ++i) {
579       int64_t pre_padding = explicit_padding[2 * i];
580       int64_t post_padding = explicit_padding[2 * i + 1];
581       int64_t pre_slice = pre_padding < 0 ? -pre_padding : 0;
582       int64_t post_slice = post_padding < 0 ? -post_padding : 0;
583       start.push_back(pre_slice);
584       size.push_back(input_shape[i] - pre_slice - post_slice);
585     }
586 
587     auto start_attr = rewriter.create<ConstOp>(
588         value.getLoc(),
589         DenseIntElementsAttr::get(
590             RankedTensorType::get({static_cast<int64_t>(start.size())},
591                                   rewriter.getI64Type()),
592             start));
593     auto size_attr = rewriter.create<ConstOp>(
594         value.getLoc(),
595         DenseIntElementsAttr::get(
596             RankedTensorType::get({static_cast<int64_t>(size.size())},
597                                   rewriter.getI64Type()),
598             size));
599     auto output_type = RankedTensorType::get(size, input_type.getElementType());
600 
601     return rewriter.create<SliceOp>(value.getLoc(), output_type, value,
602                                     start_attr, size_attr);
603   }
604 
CreateConvOp(mhlo::ConvolutionOp conv_op,ArrayRef<int64_t> strides,StringRef padding,ArrayRef<int64_t> explicit_padding,ArrayRef<int64_t> dilation,bool is_depthwise_conv,int input_channels,int num_spatial_dims,ConversionPatternRewriter & rewriter) const605   void CreateConvOp(mhlo::ConvolutionOp conv_op, ArrayRef<int64_t> strides,
606                     StringRef padding, ArrayRef<int64_t> explicit_padding,
607                     ArrayRef<int64_t> dilation, bool is_depthwise_conv,
608                     int input_channels, int num_spatial_dims,
609                     ConversionPatternRewriter &rewriter) const {
610     mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
611     // Transposes lhs and rhs if their formats are not NHWC.
612     Value lhs = FormatToNHWC(
613         conv_op.lhs(), dnums.getInputBatchDimension(),
614         dnums.getInputFeatureDimension(), dnums.getInputSpatialDimensions(),
615         /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
616         /*default_spatial_dim_start=*/1, num_spatial_dims, rewriter);
617     Value rhs = FormatToNHWC(
618         conv_op.rhs(), dnums.getKernelInputFeatureDimension(),
619         dnums.getKernelOutputFeatureDimension(),
620         dnums.getKernelSpatialDimensions(),
621         /*default_batch_dim=*/num_spatial_dims,
622         /*default_feature_dim=*/num_spatial_dims + 1,
623         /*default_spatial_dim_start=*/0, num_spatial_dims, rewriter);
624 
625     // Emulate negative padding with a slice and remove negative values from the
626     // padding vector.
627     Value sliced_lhs = SliceNegativePadding(lhs, explicit_padding, rewriter);
628     auto new_padding = llvm::to_vector<4>(llvm::map_range(
629         explicit_padding, [](int64_t dim) { return dim > 0 ? dim : 0; }));
630 
631     auto conv_output_type = conv_op.getType().cast<RankedTensorType>();
632     DenseIntElementsAttr permutation;
633     const bool need_transpose_output = NeedsReformatTypeAndPermutation(
634         dnums.getOutputBatchDimension(), dnums.getOutputFeatureDimension(),
635         dnums.getOutputSpatialDimensions().front(),
636         /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
637         /*default_spatial_dim_start=*/1);
638     if (need_transpose_output) {
639       std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
640                                                             permutation) =
641           GetReformatTypeAndPermutation(
642               dnums.getOutputBatchDimension(),
643               dnums.getOutputFeatureDimension(),
644               dnums.getOutputSpatialDimensions().front(),
645               /*default_batch_dim=*/0,
646               /*default_feature_dim=*/num_spatial_dims + 1,
647               /*default_spatial_dim_start=*/1, num_spatial_dims,
648               conv_output_type, rewriter);
649     }
650     Value output;
651     if (is_depthwise_conv) {
652       // Reshapes filter format to [filter_height, filter_width, in_channels,
653       // channel_multiplier] from HLO's [filter_height, filter_width, 1,
654       // in_channels * channel_multiplier] format.
655       auto filter_type = rhs.getType().cast<ShapedType>();
656       llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
657       llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
658                                                     hlo_filter_shape.end());
659       tf_filter_shape[2] = input_channels;
660       tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
661       auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
662           rhs.getLoc(),
663           RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
664           rhs);
665 
666       output = rewriter.create<DepthwiseConv2dNativeOp>(
667           conv_op.getLoc(), conv_output_type, sliced_lhs, reshaped_filter,
668           rewriter.getI64ArrayAttr(strides),
669           /*padding=*/rewriter.getStringAttr(padding),
670           /*explicit_paddings=*/rewriter.getI64ArrayAttr(new_padding),
671           /*data_format=*/rewriter.getStringAttr("NHWC"),
672           /*dilations=*/rewriter.getI64ArrayAttr(dilation));
673     } else {
674       output = rewriter.create<Conv2DOp>(
675           conv_op.getLoc(), conv_output_type, sliced_lhs, rhs,
676           rewriter.getI64ArrayAttr(strides),
677           /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
678           /*padding=*/rewriter.getStringAttr(padding),
679           /*explicit_paddings=*/rewriter.getI64ArrayAttr(new_padding),
680           /*data_format=*/rewriter.getStringAttr("NHWC"),
681           /*dilations=*/rewriter.getI64ArrayAttr(dilation));
682     }
683 
684     if (need_transpose_output) {
685       // Converts from "NHWC" format back to the original output format.
686       std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
687                                                             permutation) =
688           GetReformatTypeAndPermutation(
689               /*batch_dim=*/0, /*feature_dim=*/num_spatial_dims + 1,
690               /*spatial_dim_start=*/1, dnums.getOutputBatchDimension(),
691               dnums.getOutputFeatureDimension(),
692               *dnums.getOutputSpatialDimensions().begin(), num_spatial_dims,
693               conv_output_type, rewriter);
694       output = rewriter.create<mhlo::TransposeOp>(
695           conv_op.getLoc(), conv_op.getType(), output, permutation);
696     }
697     rewriter.replaceOp(conv_op, {output});
698   }
699 };
700 
701 class ConvertNonTrivialConvOp
702     : public OpConversionPattern<mhlo::ConvolutionOp> {
703  public:
704   using OpConversionPattern::OpConversionPattern;
705 
matchAndRewrite(mhlo::ConvolutionOp conv_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const706   LogicalResult matchAndRewrite(
707       mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
708       ConversionPatternRewriter &rewriter) const final {
709     if (IsSupportedConvOp(conv_op, rewriter).failed()) {
710       return rewriter.notifyMatchFailure(
711           conv_op,
712           "doesn't support to convert to ConvBackpropInputOp or "
713           "ResizeBilinearOp");
714     }
715 
716     // tf.ResizeBilinearOp is perferred than tf.Conv2DBackpropInputOp since
717     // the former has better portability, especially in inference use cases.
718     bool align_corners;
719     llvm::SmallVector<int, 2> output_sizes;
720     if (MatchResizeOp(conv_op, align_corners, output_sizes, rewriter)
721             .succeeded()) {
722       CreateResizeBilinearOp(conv_op, output_sizes, align_corners, rewriter);
723       return success();
724     }
725 
726     // Constructs strides array from lhs_dilation.
727     // For example, [2, 3] -> [1, 2, 3, 1].
728     SmallVector<int64_t, 4> strides({1});
729     strides.append(
730         conv_op.lhs_dilation().getValue().getValues<int64_t>().begin(),
731         conv_op.lhs_dilation().getValue().getValues<int64_t>().end());
732     strides.emplace_back(1);
733 
734     // Constructs dilation array.
735     SmallVector<int64_t, 4> dilation;
736     if (auto rhs_dilation = conv_op.rhs_dilation()) {
737       // For example, [2, 3] -> [1, 2, 3, 1].
738       dilation.emplace_back(1);
739       dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
740                       rhs_dilation.getValue().getValues<int64_t>().end());
741       dilation.emplace_back(1);
742     } else {
743       // Default value
744       dilation = {1, 1, 1, 1};
745     }
746 
747     mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
748     std::string padding;
749     if (!conv_op.padding().has_value() ||
750         (conv_op.padding().getValue().isSplat() &&
751          conv_op.padding()->getSplatValue<int64_t>() == 0)) {
752       padding = "VALID";
753     } else {
754       auto spatial_dims = dnums.getInputSpatialDimensions();
755       int num_spatial_dims =
756           std::accumulate(spatial_dims.begin(), spatial_dims.end(), 1LL,
757                           std::multiplies<int64_t>{});
758       if (!IsSamePadding(conv_op, num_spatial_dims, strides)) {
759         return rewriter.notifyMatchFailure(
760             conv_op, "requires padding to be SAME or VALID");
761       }
762       padding = "SAME";
763     }
764 
765     // Converts int64_t to int32_t.
766     llvm::SmallVector<int, 4> input_shape;
767     for (int64_t dim : conv_op.getType().cast<RankedTensorType>().getShape()) {
768       input_shape.push_back(dim);
769     }
770     auto input_sizes = rewriter.create<ConstOp>(
771         conv_op.getLoc(),
772         DenseIntElementsAttr::get(
773             RankedTensorType::get({static_cast<int64_t>(input_shape.size())},
774                                   rewriter.getI32Type()),
775             input_shape));
776     // Mirror the filter in the spatial dimensions.
777     auto filter = rewriter.create<mhlo::ReverseOp>(
778         conv_op.getLoc(), conv_op.rhs(),
779         rewriter.getI64TensorAttr(dnums.getKernelSpatialDimensions()));
780     rewriter.replaceOpWithNewOp<Conv2DBackpropInputOp>(
781         conv_op, conv_op.getType(), input_sizes, filter, conv_op.lhs(),
782         rewriter.getI64ArrayAttr(strides),
783         /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
784         /*padding=*/rewriter.getStringAttr(padding),
785         /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
786         /*data_format=*/rewriter.getStringAttr("NHWC"),
787         /*dilations=*/rewriter.getI64ArrayAttr(dilation));
788     return success();
789   };
790 
791  private:
IsSamePadding(mhlo::ConvolutionOp conv_op,int num_spatial_dims,ArrayRef<int64_t> strides) const792   bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims,
793                      ArrayRef<int64_t> strides) const {
794     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
795       int dim = i + 1;
796       int stride = strides[dim];
797       int input_size = conv_op.getType().cast<ShapedType>().getDimSize(dim);
798       int output_size =
799           conv_op.lhs().getType().cast<ShapedType>().getDimSize(dim);
800       if (output_size != (input_size + stride - 1) / stride) {
801         return false;
802       }
803     }
804 
805     return true;
806   }
807 
IsSupportedConvOp(mhlo::ConvolutionOp conv_op,ConversionPatternRewriter & rewriter) const808   LogicalResult IsSupportedConvOp(mhlo::ConvolutionOp conv_op,
809                                   ConversionPatternRewriter &rewriter) const {
810     if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
811         !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
812         !conv_op.getType().cast<ShapedType>().hasStaticShape())
813       return rewriter.notifyMatchFailure(conv_op, "requires static shape");
814     mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
815     const int input_feature_dimension = dnums.getInputFeatureDimension();
816     const int input_channels =
817         conv_op.lhs().getType().cast<ShapedType>().getDimSize(
818             input_feature_dimension);
819     int feature_group_count = conv_op.feature_group_count();
820 
821     if (feature_group_count != 1 && feature_group_count != input_channels) {
822       // Group convolution is not supported yet.
823       return rewriter.notifyMatchFailure(conv_op,
824                                          "doesn't support group convolution");
825     }
826 
827     // Checks lhs_dilation is non-trivial.
828     if (!conv_op.lhs_dilation().has_value()) {
829       return rewriter.notifyMatchFailure(conv_op,
830                                          "requires lhs_dilation attribute");
831     }
832     auto lhs_dilation = conv_op.lhs_dilation().getValue();
833     if (lhs_dilation.isSplat() && lhs_dilation.getSplatValue<int64_t>() == 1)
834       return rewriter.notifyMatchFailure(conv_op,
835                                          "requires non-trivial lhs_dilation");
836 
837     if (!conv_op.window_strides().has_value() || conv_op.window_strides()
838                                                          .getValue()
839                                                          .getType()
840                                                          .cast<ShapedType>()
841                                                          .getRank() != 1)
842       return rewriter.notifyMatchFailure(
843           conv_op, "requires window_strides to equal to one");
844 
845     int num_spatial_dims = dnums.getInputSpatialDimensions().size();
846     // TODO(chhe): Currently we don't support 3D Convolution.
847     if (num_spatial_dims != 2)
848       return rewriter.notifyMatchFailure(conv_op,
849                                          "doesn't support more than 2D");
850 
851     // TODO(chhe): To support more data formats other than "NHWC".
852     // Checks format [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f].
853     if (dnums.getInputBatchDimension() != 0 ||
854         dnums.getInputFeatureDimension() != num_spatial_dims + 1)
855       return rewriter.notifyMatchFailure(conv_op,
856                                          "requires input format [b, 0, 1, f]");
857     auto input_spatial_dimensions = dnums.getInputSpatialDimensions();
858     for (auto p : llvm::enumerate(input_spatial_dimensions)) {
859       if (p.value() != p.index() + 1)
860         return rewriter.notifyMatchFailure(
861             conv_op, "requires input format [b, 0, 1, f]");
862     }
863 
864     // Checks output dimensions.
865     if (dnums.getOutputBatchDimension() != 0 ||
866         conv_op.dimension_numbers().getOutputFeatureDimension() !=
867             num_spatial_dims + 1)
868       return rewriter.notifyMatchFailure(conv_op,
869                                          "requires output format [b, 0, 1, f]");
870     auto output_spatial_dimensions = dnums.getOutputSpatialDimensions();
871     for (auto p : llvm::enumerate(output_spatial_dimensions)) {
872       if (p.value() != p.index() + 1)
873         return rewriter.notifyMatchFailure(
874             conv_op, "requires output format [b, 0, 1, f]");
875     }
876 
877     // Checks kernel dimensions.
878     if (dnums.getKernelInputFeatureDimension() != num_spatial_dims + 1 ||
879         dnums.getKernelOutputFeatureDimension() != num_spatial_dims)
880       return rewriter.notifyMatchFailure(conv_op,
881                                          "requires kernel format [b, 0, 1, f]");
882     auto kernel_spatial_dimensions = dnums.getKernelSpatialDimensions();
883     for (auto p : llvm::enumerate(kernel_spatial_dimensions)) {
884       if (p.value() != p.index())
885         return rewriter.notifyMatchFailure(
886             conv_op, "requires kernel format [0, 1, o, i]");
887     }
888 
889     return success();
890   }
891 
CreateResizeBilinearOp(mhlo::ConvolutionOp conv_op,llvm::ArrayRef<int32_t> output_sizes,bool align_corners,ConversionPatternRewriter & rewriter) const892   void CreateResizeBilinearOp(mhlo::ConvolutionOp conv_op,
893                               llvm::ArrayRef<int32_t> output_sizes,
894                               bool align_corners,
895                               ConversionPatternRewriter &rewriter) const {
896     Value output_sizes_attr = rewriter.create<ConstOp>(
897         conv_op.getLoc(),
898         DenseIntElementsAttr::get(
899             RankedTensorType::get({static_cast<int64_t>(output_sizes.size())},
900                                   rewriter.getI32Type()),
901             output_sizes));
902     // The value of half_pixel_centers couldn't be inferred from the IR and XLA
903     // only support half_pixel_centers=True as in 01/11/2022. Here
904     // half_pixel_centers=False is hardcoded.
905     Value output = rewriter.create<ResizeBilinearOp>(
906         conv_op.getLoc(), conv_op.getType(), conv_op.lhs(), output_sizes_attr,
907         /*align_corners=*/rewriter.getBoolAttr(align_corners),
908         /*half_pixel_centers=*/rewriter.getBoolAttr(false));
909     rewriter.replaceOp(conv_op, {output});
910   }
911 
MatchResizeOp(mhlo::ConvolutionOp conv_op,bool & align_corners,llvm::SmallVector<int,2> & output_sizes,ConversionPatternRewriter & rewriter) const912   LogicalResult MatchResizeOp(mhlo::ConvolutionOp conv_op, bool &align_corners,
913                               llvm::SmallVector<int, 2> &output_sizes,
914                               ConversionPatternRewriter &rewriter) const {
915     mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
916     auto input_spatial_dimensions = dnums.getInputSpatialDimensions();
917     auto kernel_spatial_dimensions = dnums.getKernelSpatialDimensions();
918     auto output_spatial_dimensions = dnums.getOutputSpatialDimensions();
919     if (input_spatial_dimensions.size() != 2 ||
920         output_spatial_dimensions.size() != 2 ||
921         kernel_spatial_dimensions.size() != 2 ||
922         input_spatial_dimensions[0] != output_spatial_dimensions[0] ||
923         input_spatial_dimensions[1] != output_spatial_dimensions[1])
924       return rewriter.notifyMatchFailure(
925           conv_op, "can only be converted to 2D resize op");
926 
927     // When "lhs_dilation" is 2D and contains at least "1", and "rhs_dilation"
928     // are all "1"s, this "mhlo.conv" op can potentially be converted to
929     // "tf.ResizeBilinearOp".
930     if (!conv_op.rhs_dilation().has_value() || !conv_op.padding().has_value())
931       return rewriter.notifyMatchFailure(
932           conv_op, "resize op requires rhs_dilation and padding");
933 
934     auto lhs_dilation = conv_op.lhs_dilation().getValue();
935     auto rhs_dilation = conv_op.rhs_dilation().getValue();
936     auto window_strides = conv_op.window_strides().getValue();
937     auto padding = conv_op.padding().getValue();
938     if (lhs_dilation.getNumElements() != 2 || !rhs_dilation.isSplat() ||
939         rhs_dilation.getSplatValue<int64_t>() != 1 ||
940         window_strides.getNumElements() != 2 || padding.getNumElements() != 4)
941       return rewriter.notifyMatchFailure(
942           conv_op, "resize op requires [2] dilations and [2,2] padding");
943     auto lhs_dilation_values = lhs_dilation.getValues<int64_t>();
944     auto window_strides_values = window_strides.getValues<int64_t>();
945     auto padding_values = padding.getValues<int64_t>();
946 
947     // Cast the dimension sizes to int.
948     auto lhs_type = conv_op.lhs().getType().cast<ShapedType>();
949     llvm::SmallVector<int> input_sizes = {
950         static_cast<int>(lhs_type.getDimSize(input_spatial_dimensions[0])),
951         static_cast<int>(lhs_type.getDimSize(input_spatial_dimensions[1]))};
952     output_sizes = {static_cast<int>(conv_op.getType().getDimSize(
953                         output_spatial_dimensions[0])),
954                     static_cast<int>(conv_op.getType().getDimSize(
955                         output_spatial_dimensions[1]))};
956 
957     // This is based on method in compiler/tf2xla/kernels/image_resize_ops.cc
958     auto can_convert_to_bilinear = [](bool align_corners, int64_t dilation,
959                                       int64_t padding, int64_t stride,
960                                       int64_t input_spatial,
961                                       int64_t output_spatial) {
962       int64_t input_spatial_size =
963           align_corners ? input_spatial - 1 : input_spatial;
964       int64_t output_spatial_size =
965           align_corners ? output_spatial - 1 : output_spatial;
966 
967       int64_t gcd =
968           tensorflow::MathUtil::GCD(static_cast<uint64_t>(input_spatial_size),
969                                     static_cast<uint64_t>(output_spatial_size));
970       if ((input_spatial_size % gcd != 0) ||
971           (input_spatial_size / gcd != stride) || (dilation - 1 != padding)) {
972         return false;
973       }
974 
975       return true;
976     };
977 
978     // Only of the lhs_dilation must be 1, then the non-1 dimension is the
979     // resize dimension.
980     if (lhs_dilation_values[0] != 1 && lhs_dilation_values[1] == 1) {
981       if (can_convert_to_bilinear(
982               /*align_corners=*/true, lhs_dilation_values[0], padding_values[0],
983               window_strides_values[0], input_sizes[0], output_sizes[0])) {
984         align_corners = true;
985         return success();
986       }
987       if (can_convert_to_bilinear(
988               /*align_corners=*/false, lhs_dilation_values[0],
989               padding_values[0], window_strides_values[0], input_sizes[0],
990               output_sizes[0])) {
991         align_corners = false;
992         return success();
993       }
994     }
995 
996     if (lhs_dilation_values[0] == 1 && lhs_dilation_values[1] != 1) {
997       if (can_convert_to_bilinear(
998               /*align_corners=*/true, lhs_dilation_values[1], padding_values[2],
999               window_strides_values[1], input_sizes[1], output_sizes[1])) {
1000         align_corners = true;
1001         return success();
1002       }
1003       if (can_convert_to_bilinear(
1004               /*align_corners=*/false, lhs_dilation_values[1],
1005               padding_values[2], window_strides_values[1], input_sizes[1],
1006               output_sizes[1])) {
1007         align_corners = false;
1008         return success();
1009       }
1010     }
1011 
1012     return rewriter.notifyMatchFailure(conv_op,
1013                                        "can not be converted to resize op");
1014   }
1015 };
1016 
1017 class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
1018  public:
1019   using OpConversionPattern::OpConversionPattern;
1020 
matchAndRewrite(mhlo::SliceOp slice_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1021   LogicalResult matchAndRewrite(
1022       mhlo::SliceOp slice_op, OpAdaptor adaptor,
1023       ConversionPatternRewriter &rewriter) const final {
1024     auto begin =
1025         rewriter.create<ConstOp>(slice_op.getLoc(), slice_op.start_indices());
1026     auto end =
1027         rewriter.create<ConstOp>(slice_op.getLoc(), slice_op.limit_indices());
1028     auto strides =
1029         rewriter.create<ConstOp>(slice_op.getLoc(), slice_op.strides());
1030     rewriter.replaceOpWithNewOp<StridedSliceOp>(
1031         slice_op, slice_op.getType(), slice_op.operand(), begin, end, strides);
1032     return success();
1033   }
1034 };
1035 
1036 class ConvertDynamicSliceOp : public OpConversionPattern<mhlo::DynamicSliceOp> {
1037  public:
1038   using OpConversionPattern::OpConversionPattern;
1039 
matchAndRewrite(mhlo::DynamicSliceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1040   LogicalResult matchAndRewrite(
1041       mhlo::DynamicSliceOp op, OpAdaptor adaptor,
1042       ConversionPatternRewriter &rewriter) const final {
1043     ShapedType input_type = op.operand().getType().cast<ShapedType>();
1044     if (!input_type.hasStaticShape()) return failure();
1045     Type start_indices_element_type = op.start_indices()
1046                                           .front()
1047                                           .getType()
1048                                           .cast<ShapedType>()
1049                                           .getElementType();
1050 
1051     // The mhlo dynamic_slice's start_indices can be either signed/unsigned
1052     // int32/int64. However, TF only takes in either i32 or i64 types for begin,
1053     // so we will always put a cast.
1054     Type signed_start_indices_element_type;
1055     if (start_indices_element_type.isInteger(32)) {
1056       signed_start_indices_element_type = rewriter.getI32Type();
1057     } else {
1058       signed_start_indices_element_type = rewriter.getI64Type();
1059     }
1060 
1061     // Clamp indices to [0, input_size - output_size]
1062     llvm::SmallVector<Value, 4> start_indices_vector;
1063     start_indices_vector.reserve(op.start_indices().size());
1064     Value clamp_min = rewriter.create<ConstOp>(
1065         op.getLoc(),
1066         rewriter.getIntegerAttr(signed_start_indices_element_type, 0));
1067     for (uint64_t i = 0, e = op.start_indices().size(); i < e; ++i) {
1068       // Always put a cast there.
1069       auto start = op.start_indices()[i];
1070       auto cast_type = start.getType().cast<ShapedType>().clone(
1071           signed_start_indices_element_type);
1072       auto cast_op = rewriter.create<CastOp>(op.getLoc(), cast_type, start);
1073       Value clamp_max = rewriter.create<ConstOp>(
1074           op.getLoc(), rewriter.getIntegerAttr(
1075                            signed_start_indices_element_type,
1076                            input_type.getShape()[i] -
1077                                op.slice_sizes().getValues<int64_t>()[i]));
1078       Value clamped_index = rewriter.create<mhlo::ClampOp>(
1079           op.getLoc(), cast_type, clamp_min, cast_op, clamp_max);
1080       start_indices_vector.push_back(clamped_index);
1081     }
1082 
1083     // Pack individual start indices to start indices tensor.
1084     Type start_indices_type = RankedTensorType::get(
1085         {static_cast<int64_t>(start_indices_vector.size())},
1086         signed_start_indices_element_type);
1087     Value start_indices_op = rewriter.create<PackOp>(
1088         op.getLoc(), start_indices_type, ValueRange(start_indices_vector));
1089 
1090     Value slice_sices_op =
1091         rewriter.create<ConstOp>(op.getLoc(), op.slice_sizes());
1092     rewriter.replaceOpWithNewOp<SliceOp>(op, op.getType(), op.operand(),
1093                                          start_indices_op, slice_sices_op);
1094     return success();
1095   };
1096 };
1097 
1098 // Appends all elements in `range` to `values`.
1099 template <typename ValueT, typename Range>
Append(llvm::SmallVectorImpl<ValueT> & values,Range && range)1100 void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
1101   values.insert(values.end(), range.begin(), range.end());
1102 }
1103 
1104 // Appends all elements in `range` to `values`.
1105 template <typename ValueT, typename Range, typename... RangeTs>
Append(llvm::SmallVectorImpl<ValueT> & values,Range && range,RangeTs &&...ranges)1106 void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
1107             RangeTs &&...ranges) {
1108   values.insert(values.end(), range.begin(), range.end());
1109   Append(values, ranges...);
1110 }
1111 
// Number of elements in a single `range`, as reported by its size() member.
template <typename Range>
size_t Size(Range &&range) {
  const size_t count = range.size();
  return count;
}
1117 
// Total element count over all given ranges: the size of the first range
// plus the combined size of the remaining ones.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&...ranges) {
  const size_t rest = Size(std::forward<RangeTs>(ranges)...);
  return range.size() + rest;
}
1123 
1124 // Concats all elements in `ranges` and returns a small vector as a result.
1125 template <typename ValueT, typename... RangeTs>
Concat(RangeTs &&...ranges)1126 llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
1127   llvm::SmallVector<int64_t, 4> results;
1128   results.reserve(Size(std::forward<RangeTs>(ranges)...));
1129   Append(results, std::forward<RangeTs>(ranges)...);
1130   return results;
1131 }
1132 
1133 // A struct to hold axes and sizes for a set of dimensions.
1134 struct DimensionVector {
AxesArraymlir::TF::__anonab43b57b0111::DimensionVector1135   llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
SizesArraymlir::TF::__anonab43b57b0111::DimensionVector1136   llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }
1137 
1138   llvm::SmallVector<int64_t, 4> axes;
1139   llvm::SmallVector<int64_t, 4> sizes;
1140 };
1141 
1142 // Create a single const integer.
BuildIntConstOp(ImplicitLocOpBuilder & builder,ConversionPatternRewriter & rewriter,int64_t const_value,Type type)1143 Value BuildIntConstOp(ImplicitLocOpBuilder &builder,
1144                       ConversionPatternRewriter &rewriter, int64_t const_value,
1145                       Type type) {
1146   Value result_const =
1147       builder.create<ConstOp>(rewriter.getIntegerAttr(type, const_value));
1148   return result_const;
1149 }
1150 // Create a const integer vector tensor (1-dim).
BuildIntArrayConstOp(ImplicitLocOpBuilder & builder,ConversionPatternRewriter & rewriter,ArrayRef<int64_t> const_value,Type type)1151 Value BuildIntArrayConstOp(ImplicitLocOpBuilder &builder,
1152                            ConversionPatternRewriter &rewriter,
1153                            ArrayRef<int64_t> const_value, Type type) {
1154   DenseIntElementsAttr const_value_raw;
1155   if (type == rewriter.getI64Type()) {
1156     const_value_raw = rewriter.getI64TensorAttr(const_value);
1157   } else {
1158     // Convert I64 const array to I32.
1159     llvm::SmallVector<int32_t> const_i32_vec;
1160     for (auto element : const_value) {
1161       const_i32_vec.push_back(static_cast<int32_t>(element));
1162     }
1163     const_value_raw = rewriter.getI32TensorAttr(const_i32_vec);
1164   }
1165   Value result_const = builder.create<ConstOp>(const_value_raw);
1166   return result_const;
1167 }
1168 
1169 // Create a tensor that is reshaped from input.
BuildReshapeOp(ImplicitLocOpBuilder & builder,ConversionPatternRewriter & rewriter,Value input,ArrayRef<int64_t> shape,Type idx_type,Type element_type)1170 Value BuildReshapeOp(ImplicitLocOpBuilder &builder,
1171                      ConversionPatternRewriter &rewriter, Value input,
1172                      ArrayRef<int64_t> shape, Type idx_type,
1173                      Type element_type) {
1174   Value shape_cst = BuildIntArrayConstOp(builder, rewriter, shape, idx_type);
1175   Value reshaped_input = builder.create<ReshapeOp>(
1176       RankedTensorType::get(shape, element_type), input, shape_cst);
1177   return reshaped_input;
1178 }
1179 
1180 // Create a tensor which is equal to input[begin: begin + size].
BuildSliceOp(ImplicitLocOpBuilder & builder,ConversionPatternRewriter & rewriter,Value input,Value begin,ArrayRef<int64_t> shape,Type idx_type,Type element_type)1181 Value BuildSliceOp(ImplicitLocOpBuilder &builder,
1182                    ConversionPatternRewriter &rewriter, Value input,
1183                    Value begin, ArrayRef<int64_t> shape, Type idx_type,
1184                    Type element_type) {
1185   Value shape_cst = BuildIntArrayConstOp(builder, rewriter, shape, idx_type);
1186   Value slice_result = builder.create<SliceOp>(
1187       RankedTensorType::get(shape, element_type), input, begin, shape_cst);
1188   return slice_result;
1189 }
1190 
1191 class ConvertDynamicUpdateSliceOp
1192     : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
1193  public:
1194   using OpConversionPattern::OpConversionPattern;
1195 
matchAndRewrite(mhlo::DynamicUpdateSliceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1196   LogicalResult matchAndRewrite(
1197       mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor,
1198       ConversionPatternRewriter &rewriter) const final {
1199     ShapedType operand_type = op.operand().getType().cast<ShapedType>();
1200     ShapedType update_type =
1201         op.update().getType().dyn_cast_or_null<ShapedType>();
1202     ShapedType start_indices_type =
1203         op.start_indices().front().getType().dyn_cast_or_null<ShapedType>();
1204     if (update_type == nullptr || start_indices_type == nullptr)
1205       return rewriter.notifyMatchFailure(
1206           op, "update and start_indices should have ShapedType");
1207     if (!operand_type.hasStaticShape() || !update_type.hasStaticShape())
1208       return rewriter.notifyMatchFailure(
1209           op, "shape of operand and update should be static");
1210 
1211     Type idx_type = start_indices_type.getElementType();
1212     int64_t shape_dim = operand_type.getRank();
1213     auto operand_shape = operand_type.getShape();
1214     auto update_shape = update_type.getShape();
1215 
1216     ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1217     Value zero_cst = BuildIntConstOp(builder, rewriter, 0, idx_type);
1218     Value one_cst = BuildIntConstOp(builder, rewriter, 1, idx_type);
1219     // Clamp start indices in [0, operand_size - update_size].
1220     llvm::SmallVector<Value> start_indices_vector;
1221     Append(start_indices_vector, op.start_indices());
1222     auto shape_tensor_type = RankedTensorType::get({shape_dim}, idx_type);
1223     Value start_indices_tensor =
1224         builder.create<PackOp>(shape_tensor_type, start_indices_vector);
1225     Value operand_shape_cst =
1226         BuildIntArrayConstOp(builder, rewriter, operand_shape, idx_type);
1227     Value update_shape_cst =
1228         BuildIntArrayConstOp(builder, rewriter, update_shape, idx_type);
1229     Value max_start_indices =
1230         builder.create<SubOp>(operand_shape_cst, update_shape_cst);
1231     Value start_indices_clip_max =
1232         builder.create<MinimumOp>(start_indices_tensor, max_start_indices);
1233     Value clamped_start_indices =
1234         builder.create<MaximumOp>(start_indices_clip_max, zero_cst);
1235 
1236     // Do dynamic_upate_slice on flattened operand and update with the aid of
1237     // tf.TensorScatterUpdate op. It takes in 3 parameters: flat_operand,
1238     // indices and flat_update. The indices are computed as follows:
1239     // 1. Construct a range (0, n_operand). It arranges a id number to each
1240     //    element position in operand.
1241     // 2. Reshape the range to the shape of operand.
1242     // 3. Compute the id numbers of update positions by choose a slice form
1243     //    clamped_start_indices to clamped_start_indices + update_size.
1244     // 4. Flatten the update id numbers and the indices is obtained.
1245     int64_t n_operand = operand_type.getNumElements();
1246     Value n_operand_cst =
1247         BuildIntConstOp(builder, rewriter, n_operand, idx_type);
1248     Value range_flat =
1249         builder.create<RangeOp>(zero_cst, n_operand_cst, one_cst);
1250     Value range = BuildReshapeOp(builder, rewriter, range_flat, operand_shape,
1251                                  idx_type, idx_type);
1252     Value update_indices_raw =
1253         BuildSliceOp(builder, rewriter, range, clamped_start_indices,
1254                      update_shape, idx_type, idx_type);
1255     int64_t n_update = update_type.getNumElements();
1256     Type element_type = operand_type.getElementType();
1257     Value update_indices = BuildReshapeOp(builder, rewriter, update_indices_raw,
1258                                           {n_update, 1}, idx_type, idx_type);
1259     Value operand_flat = BuildReshapeOp(builder, rewriter, op.operand(),
1260                                         {n_operand}, idx_type, element_type);
1261     Value update_flat = BuildReshapeOp(builder, rewriter, op.update(),
1262                                        {n_update}, idx_type, element_type);
1263     Value flat_result = builder.create<TensorScatterUpdateOp>(
1264         operand_flat, update_indices, update_flat);
1265 
1266     // Reshape back before return.
1267     rewriter.replaceOpWithNewOp<ReshapeOp>(op, operand_type, flat_result,
1268                                            operand_shape_cst);
1269     return success();
1270   };
1271 };
1272 
1273 // It returns "true" when Value $iota is obtained from the following mlir code:
1274 //
1275 // $iota = "mhlo.iota"(){iota_dimension = $dimensions[0]},
1276 //
1277 // where $dimensions must have size 1 and iota can have rank>=1.
1278 // It usually used for matching rank 1 iota since the iotaOp will be folded to
1279 // IotaOp + BroadCastInDimOp except for the case when result shape is rank 1.
MatchSingleIota(DenseIntElementsAttr dimensions,Value iota)1280 bool MatchSingleIota(DenseIntElementsAttr dimensions, Value iota) {
1281   auto iota_op = dyn_cast_or_null<mhlo::IotaOp>(iota.getDefiningOp());
1282   if (!iota_op || dimensions.getNumElements() != 1) return false;
1283   auto dim = *dimensions.value_begin<APInt>();
1284   return dim == iota_op.iota_dimension();
1285 }
1286 
1287 // It matches %iota generated from the following mlir codes:
1288 //
1289 // %iota_r1 = "mhlo.iota"(){iota_dimension = 0} :() -> tensor<Lxi32>
1290 // %iota = "mhlo.broadcast_in_dim(%iota_r1){
1291 //    broadcast_dimensions = dense<[$dimensions[0]]>},
1292 //
1293 // where %dimensions is of size 1. It ususally comes from an IotaOp that is
1294 // folded to IotaOp (rank1) + BroadCastInDimOp.
MatchIotaBroadCastInDim(DenseIntElementsAttr dimensions,Value iota)1295 bool MatchIotaBroadCastInDim(DenseIntElementsAttr dimensions, Value iota) {
1296   auto iota_broadcast =
1297       dyn_cast_or_null<mhlo::BroadcastInDimOp>(iota.getDefiningOp());
1298   if (!iota_broadcast || iota_broadcast.broadcast_dimensions() != dimensions)
1299     return false;
1300   if (!isa_and_nonnull<mhlo::IotaOp>(iota_broadcast.operand().getDefiningOp()))
1301     return false;
1302   return true;
1303 }
1304 
1305 // Matches %iota generated from the following code (rank 3 example):
1306 //
1307 // %iota_r1 = "mhlo.iota"(){iota_dimension = 0 : i32} : () -> tensor<44xi32>
1308 // %iota = "mhlo.reshape"(%iota_r1): (tensor<44xi32>) -> tensor<1x1x44xi32>
1309 //
1310 // Where $dimensions is of size 1 and $dimensions[0] = 2.
1311 //
1312 // In general matches a 1-D Iota with multiple dimensions of size 1 added
1313 // through a reshape.
MatchReshapedIota(DenseIntElementsAttr dimensions,Value iota)1314 bool MatchReshapedIota(DenseIntElementsAttr dimensions, Value iota) {
1315   if (dimensions.getNumElements() != 1) return false;
1316   auto reshape_op = dyn_cast_or_null<mhlo::ReshapeOp>(iota.getDefiningOp());
1317   if (!reshape_op) return false;
1318   auto operand_type =
1319       reshape_op.operand().getType().dyn_cast<RankedTensorType>();
1320   if (!operand_type || !operand_type.hasStaticShape()) return false;
1321   auto reshape_type = reshape_op.getType().cast<RankedTensorType>();
1322 
1323   // Reshape can take a 1-D iota input and add extra dims of size one.
1324   if (operand_type.getRank() != 1) return false;
1325   if (!dyn_cast_or_null<mhlo::IotaOp>(reshape_op.operand().getDefiningOp()))
1326     return false;
1327 
1328   int64_t iota_dim = (*dimensions.value_begin<APInt>()).getSExtValue();
1329   for (int64_t i = 0, e = reshape_type.getRank(); i < e; ++i) {
1330     if (i == iota_dim) {
1331       if (reshape_type.getDimSize(i) != operand_type.getDimSize(0))
1332         return false;
1333     } else if (reshape_type.getDimSize(i) != 1) {
1334       return false;
1335     }
1336   }
1337   return true;
1338 }
1339 
1340 // It matches %iota generated from the following mlir codes:
1341 //
1342 // %iota_r1 = mhlo.constant dense<[0, 1, 2, ..., L]>
1343 // %iota = "mhlo.broadcast_in_dim(%iota_r1){
1344 //    broadcast_dimensions = dense<[$dimensions[0]]>},
1345 //
1346 // where $dimensions is of size 1. It ususally comes from an IotaOp that is
1347 // folded to ConstOp (folded rank1 iota) + BroadCastInDimOp.
MatchConstIotaBroadCastInDim(DenseIntElementsAttr dimensions,Value iota)1348 bool MatchConstIotaBroadCastInDim(DenseIntElementsAttr dimensions, Value iota) {
1349   if (dimensions.getNumElements() != 1) return false;
1350   auto iota_broadcast =
1351       dyn_cast_or_null<mhlo::BroadcastInDimOp>(iota.getDefiningOp());
1352   if (!iota_broadcast || iota_broadcast.broadcast_dimensions() != dimensions)
1353     return false;
1354   DenseElementsAttr range_const;
1355   if (!matchPattern(iota_broadcast.operand(), m_Constant(&range_const)))
1356     return false;
1357   int index = 0;
1358   for (auto value : range_const.getValues<APInt>()) {
1359     if (value != index++) return false;
1360   }
1361   return true;
1362 }
1363 
1364 // Facilitate access to 1-d backing data for a tensor so that values in a 1-d
1365 // slice of the tensor can be accessed as if part of an ArrayView.
1366 class StridedArrayViewBase {
1367  protected:
StridedArrayViewBase(ArrayRef<int64_t> shape,ArrayRef<int64_t> index,int64_t axis)1368   StridedArrayViewBase(ArrayRef<int64_t> shape, ArrayRef<int64_t> index,
1369                        int64_t axis) {
1370     assert(shape.size() == index.size());
1371     assert(axis < shape.size());
1372     assert(axis >= 0);
1373     assert(index[axis] == 0);
1374     offset_ = IndexToOffset(shape, index);
1375     stride_ = StrideForAxis(shape, axis);
1376     size_ = shape[axis];
1377   }
1378 
1379   // Returns the size of the 1-d slice across the tensor.
size() const1380   int64_t size() const { return size_; }
1381 
1382   // Calculates the next index in a tensor excluding a specified axis.
1383   //
1384   // Returns the next index where one exists.
1385   // If there is no valid next index, returns `std::nullopt`.
1386   //
1387   // `index` should have the same size as `shape`.
1388   // Each value `dim` in `index` should be in [0, shape[dim]).
NextTensorIndex(SmallVector<int64_t> index,ArrayRef<int64_t> shape,int64_t fixed_axis)1389   static llvm::Optional<SmallVector<int64_t>> NextTensorIndex(
1390       SmallVector<int64_t> index, ArrayRef<int64_t> shape, int64_t fixed_axis) {
1391 #ifndef NDEBUG
1392     assert(shape.size() == index.size());
1393     assert(fixed_axis < shape.size());
1394     assert(fixed_axis >= 0);
1395     assert(index[fixed_axis] == 0);
1396     for (size_t i = 0; i < shape.size(); ++i) {
1397       assert(index[i] < shape[i]);
1398       assert(index[i] >= 0);
1399     }
1400 #endif  // NDEBUG
1401     for (int64_t dim = shape.size() - 1; dim >= 0; --dim) {
1402       if (dim == fixed_axis) continue;
1403       ++index[dim];
1404       if (index[dim] < shape[dim]) return std::move(index);
1405       index[dim] = 0;
1406     }
1407     return llvm::None;
1408   }
1409 
1410  protected:
1411   // Calculates how many values to skip across a 1-D contiguous array that holds
1412   // backing data for a given shape to access the value at a given index along a
1413   // StridedArrayView across a higher dimensional shape.
1414   //
1415   // The index `i` must be in [0, shape[axis])` where `shape` is the shape
1416   // of the tensor and `axis` is the axis along the tensor that the
1417   // StridedArrayView indexes along.
OffsetForIndex(int64_t i) const1418   int64_t OffsetForIndex(int64_t i) const { return offset_ + i * stride_; }
1419 
1420  private:
1421   // Calculates how many values to skip across a 1-D contiguous array that holds
1422   // backing data for a given shape to access the next value along a given axis.
1423   //
1424   // `axis` should be a valid dimension in `shape`.
StrideForAxis(ArrayRef<int64_t> shape,int64_t axis)1425   static int64_t StrideForAxis(ArrayRef<int64_t> shape, int64_t axis) {
1426     int64_t stride = 1;  // Start with the trailing dimension.
1427     for (int64_t dim = shape.size() - 1; dim > axis; --dim) {
1428       stride *= shape[dim];
1429     }
1430     return stride;
1431   }
1432 
1433   // Calculates how many values to skip across a 1-D contiguous array that holds
1434   // backing data for a given shape to access data at a specified index.
1435   //
1436   // `index` should have the same size as `shape`.
1437   // Each value `dim` in `index` should be in [0, shape[dim]).
IndexToOffset(ArrayRef<int64_t> shape,ArrayRef<int64_t> index)1438   static int64_t IndexToOffset(ArrayRef<int64_t> shape,
1439                                ArrayRef<int64_t> index) {
1440 #ifndef NDEBUG
1441     assert(shape.size() == index.size());
1442     for (size_t i = 0; i < shape.size(); ++i) {
1443       assert(index[i] < shape[i]);
1444       assert(index[i] >= 0);
1445     }
1446 #endif  // NDEBUG
1447     int64_t offset = 0;
1448     int64_t stride = 1;
1449     for (int64_t dim = shape.size() - 1; dim >= 0; --dim) {
1450       offset += index[dim] * stride;
1451       stride *= shape[dim];
1452     }
1453     return offset;
1454   }
1455 
1456   int64_t offset_;
1457   int64_t stride_;
1458   int64_t size_;
1459 };
1460 
1461 template <typename T>
1462 class StridedArrayView;  // Class requires specialization.
1463 
1464 // Wraps a DenseIntElementsAttr that holds backing data for a tensor so that
1465 // int64_t values in a 1-d slice of the tensor can be accessed as if part of an
1466 // ArrayView.
1467 template <>
1468 class StridedArrayView<DenseIntElementsAttr> : StridedArrayViewBase {
1469  public:
StridedArrayView(const DenseIntElementsAttr & data,ArrayRef<int64_t> shape,ArrayRef<int64_t> index,int64_t axis)1470   StridedArrayView(const DenseIntElementsAttr &data, ArrayRef<int64_t> shape,
1471                    ArrayRef<int64_t> index, int64_t axis)
1472       : StridedArrayViewBase(shape, index, axis), data_(data) {
1473     int64_t element_count = 1;
1474     for (int64_t i = 0, e = shape.size(); i < e; ++i) {
1475       element_count *= shape[i];
1476     }
1477     assert(data.getNumElements() == element_count);
1478   }
1479 
1480   using StridedArrayViewBase::NextTensorIndex;
1481   using StridedArrayViewBase::size;
1482 
operator [](int64_t i) const1483   int64_t operator[](int64_t i) const {
1484     return data_.getValues<APInt>()[OffsetForIndex(i)].getSExtValue();
1485   }
1486 
1487  private:
1488   const DenseIntElementsAttr &data_;
1489 };
1490 
1491 // Matches %iota generated from the following mlir codes (rank 2 example):
1492 //
1493 // %iota = mhlo.constant dense<[[0, 1, 2, ..., L],
1494 //                              [0, 1, 2, ..., L]
1495 //                              ...
1496 //                              [0, 1, 2, ..., L]]>,
1497 // where $dimensions is of size 1.
1498 //
1499 // StridedArrayViews are used to check the iota property across the constant
1500 // data so that the iota dimension does not need to be the (inner) z-dimension.
MatchIotaConst(DenseIntElementsAttr dimensions,Value iota)1501 bool MatchIotaConst(DenseIntElementsAttr dimensions, Value iota) {
1502   DenseIntElementsAttr iota_const_attr;
1503   if (!matchPattern(iota, m_Constant(&iota_const_attr))) return false;
1504 
1505   auto iota_type = iota_const_attr.getType();
1506   auto iota_shape = iota_type.getShape();
1507   auto reduce_dim = (*dimensions.value_begin<APInt>()).getSExtValue();
1508   if (reduce_dim < 0) reduce_dim += iota_type.getRank();
1509 
1510   auto index =
1511       llvm::Optional<SmallVector<int64_t>>(std::in_place, iota_type.getRank());
1512   while (index.has_value()) {
1513     StridedArrayView<DenseIntElementsAttr> array_view(
1514         iota_const_attr, iota_shape, *index, reduce_dim);
1515     for (int64_t i = 0; i < array_view.size(); ++i) {
1516       if (array_view[i] != i) return false;
1517     }
1518     index = StridedArrayView<DenseIntElementsAttr>::NextTensorIndex(
1519         std::move(*index), iota_shape, reduce_dim);
1520   }
1521 
1522   return true;
1523 }
1524 
1525 // The following 5 different forms of mhlo::iota will be matched:
1526 // 1. IotaOp.
1527 // 2. IotaOp + BroadCastInDim.
1528 // 3. IotaOp + Reshape.
1529 // 4. Constant (folded Iota) + BroadCastInDim.
1530 // 5. Constant (folded result).
1531 // Moreover, the dimensions has to match the iota_dimension.
MatchIota(DenseIntElementsAttr dimensions,Value iota)1532 bool MatchIota(DenseIntElementsAttr dimensions, Value iota) {
1533   return MatchSingleIota(dimensions, iota) ||
1534          MatchIotaBroadCastInDim(dimensions, iota) ||
1535          MatchReshapedIota(dimensions, iota) ||
1536          MatchConstIotaBroadCastInDim(dimensions, iota) ||
1537          MatchIotaConst(dimensions, iota);
1538 }
1539 
MatchTopKComparator(Region & comparator)1540 bool MatchTopKComparator(Region &comparator) {
1541   if (!comparator.hasOneBlock()) return false;
1542   Block &comparator_blk = comparator.front();
1543   using OpListType = llvm::iplist<Operation>;
1544   OpListType &operations = comparator_blk.getOperations();
1545   if (operations.size() != 2) return false;
1546   auto compare_op = dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
1547   auto return_op = dyn_cast_or_null<mhlo::ReturnOp>(&operations.back());
1548   if (!compare_op || !return_op) return false;
1549   // TODO(xuanyuanluo): Support mhlo::ComparisonDirection::LT direction.
1550   if (compare_op.comparison_direction() != mhlo::ComparisonDirection::GT)
1551     return false;
1552   if (compare_op.lhs() != comparator_blk.getArgument(0) ||
1553       compare_op.rhs() != comparator_blk.getArgument(1))
1554     return false;
1555   return return_op.getOperands().front() == compare_op.getResult();
1556 }
1557 
1558 // In general, we convert the following form of sort to tf.TopK:
1559 //
1560 // %result = "mhlo.sort" (%keys, %indices) ({
1561 //  ^bb0(%key_0, %key_1, %index_0, %index_1):
1562 //     %1 = "mhlo.compare"(%key_0, %key_1) {mhlo::ComparisonDirection::GT}
1563 //     -> tensor<i1>
1564 //  }),
1565 //
1566 // where the indices is obtained by an IotaOp (maybe folded).
1567 class ConvertSortToTfTopk : public OpConversionPattern<mhlo::SortOp> {
1568  public:
1569   using OpConversionPattern::OpConversionPattern;
1570 
matchAndRewrite(mhlo::SortOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1571   LogicalResult matchAndRewrite(
1572       mhlo::SortOp op, OpAdaptor adaptor,
1573       ConversionPatternRewriter &rewriter) const final {
1574     if (op->getOperands().size() != 2)
1575       return rewriter.notifyMatchFailure(
1576           op, "only match for the case where operands is of size 2");
1577     auto keys = op.operands()[0];
1578     auto indices = op.operands()[1];
1579     auto keys_ty = keys.getType().dyn_cast_or_null<ShapedType>();
1580     auto indices_ty = indices.getType().dyn_cast_or_null<ShapedType>();
1581     if (!keys_ty || !keys_ty.hasStaticShape() ||
1582         !keys_ty.getElementType().isIntOrFloat())
1583       return rewriter.notifyMatchFailure(
1584           op,
1585           "only match for the case where the first operand has a static "
1586           "int/float shapeType");
1587     if (!indices_ty || !indices_ty.hasStaticShape() ||
1588         !indices_ty.getElementType().isInteger(32))
1589       return rewriter.notifyMatchFailure(
1590           op,
1591           "only match for the case where the second operand an I32 shapeType");
1592     auto sort_dim = op.dimension();
1593     auto k = indices_ty.getDimSize(sort_dim);
1594     auto rank = keys_ty.getRank();
1595     if (sort_dim != rank - 1 || k < 1)
1596       return rewriter.notifyMatchFailure(
1597           op, "only match for sort dim = rank - 1 and DimSize >= 1");
1598 
1599     // In the following, we'll check indices is obtained by a iota.
1600     auto sort_dim_attr = DenseIntElementsAttr::get(
1601         RankedTensorType::get({1}, rewriter.getI64Type()), {sort_dim});
1602     if (!MatchIota(sort_dim_attr, indices))
1603       return rewriter.notifyMatchFailure(
1604           op, "the second operand is supposed to be obtained from IOTA");
1605     if (!MatchTopKComparator(op.comparator()))
1606       return rewriter.notifyMatchFailure(op, "only match for GT comparator");
1607     ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1608     Value k_cst = BuildIntConstOp(builder, rewriter, k, rewriter.getI32Type());
1609     rewriter.replaceOpWithNewOp<TopKV2Op>(op, keys.getType(), indices.getType(),
1610                                           keys, k_cst);
1611     return success();
1612   };
1613 };
1614 
1615 // A struct to hold information about dimensions of dot_general operands.
1616 class DotDimensionsInfo {
1617  public:
DotDimensionsInfo(ShapedType type,ArrayRef<int64_t> batch_dimensions,ArrayRef<int64_t> contracting_dimensions)1618   DotDimensionsInfo(ShapedType type, ArrayRef<int64_t> batch_dimensions,
1619                     ArrayRef<int64_t> contracting_dimensions) {
1620     const int rank = type.getRank();
1621     for (const int dim : batch_dimensions) {
1622       batch_dimensions_.axes.push_back(dim);
1623       batch_dimensions_.sizes.push_back(type.getDimSize(dim));
1624     }
1625 
1626     for (const int dim : contracting_dimensions) {
1627       contracting_dimensions_.axes.push_back(dim);
1628       contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
1629     }
1630 
1631     for (int dim = 0; dim < rank; ++dim) {
1632       if (llvm::count(contracting_dimensions_.axes, dim) > 0 ||
1633           llvm::count(batch_dimensions_.axes, dim) > 0) {
1634         continue;
1635       }
1636       out_dimensions_.axes.push_back(dim);
1637       out_dimensions_.sizes.push_back(type.getDimSize(dim));
1638     }
1639   }
1640 
batch_dimensions() const1641   const DimensionVector &batch_dimensions() const { return batch_dimensions_; }
contracting_dimensions() const1642   const DimensionVector &contracting_dimensions() const {
1643     return contracting_dimensions_;
1644   }
1645   // Out dimensions are any dimensions that are neither batch nor contracting
1646   // dimensions, hence will be propagated to output shape.
out_dimensions() const1647   const DimensionVector &out_dimensions() const { return out_dimensions_; }
1648 
1649   // Returns the total dimension size after flattening all contracting
1650   // dimensions.
FlattenedContractingDimensionSize() const1651   int FlattenedContractingDimensionSize() const {
1652     return std::accumulate(contracting_dimensions_.sizes.begin(),
1653                            contracting_dimensions_.sizes.end(), 1,
1654                            std::multiplies<int64_t>());
1655   }
1656 
1657   // Returns the total dimension size after flattening all out dimensions.
FlattenedOutDimensionSize() const1658   int FlattenedOutDimensionSize() const {
1659     return std::accumulate(out_dimensions_.sizes.begin(),
1660                            out_dimensions_.sizes.end(), 1,
1661                            std::multiplies<int64_t>());
1662   }
1663 
1664  private:
1665   DimensionVector batch_dimensions_;
1666   DimensionVector contracting_dimensions_;
1667   // Out dimensions are any dimensions that are neither batch nor contracting
1668   // dimensions, hence will be propagated to output shape.
1669   DimensionVector out_dimensions_;
1670 };
1671 
ConvertDot(PatternRewriter & rewriter,Value lhs,Value rhs,DotDimensionNumbersAttr dot_dimension_numbers,ShapedType result_type,mlir::Location loc)1672 Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs,
1673                  DotDimensionNumbersAttr dot_dimension_numbers,
1674                  ShapedType result_type, mlir::Location loc) {
1675   auto lhs_type = lhs.getType().cast<ShapedType>();
1676   auto rhs_type = rhs.getType().cast<ShapedType>();
1677   const int lhs_rank = lhs_type.getRank();
1678   const int rhs_rank = rhs_type.getRank();
1679 
1680   // Collects lhs and rhs dimensions information.
1681   DotDimensionsInfo lhs_dot_dimensions_info(
1682       lhs_type, dot_dimension_numbers.getLhsBatchingDimensions(),
1683       dot_dimension_numbers.getLhsContractingDimensions());
1684   DotDimensionsInfo rhs_dot_dimensions_info(
1685       rhs_type, dot_dimension_numbers.getRhsBatchingDimensions(),
1686       dot_dimension_numbers.getRhsContractingDimensions());
1687 
1688   // Transposes lhs shape to be in the order of {batch_dimensions,
1689   // out_dimensions, contracting dimensions}.
1690   llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
1691       lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
1692       lhs_dot_dimensions_info.out_dimensions().AxesArray(),
1693       lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
1694   llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
1695       lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
1696       lhs_dot_dimensions_info.out_dimensions().SizesArray(),
1697       lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
1698   auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
1699       loc,
1700       RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
1701       lhs,
1702       DenseIntElementsAttr::get(
1703           RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
1704           lhs_permutation));
1705 
1706   // Transposes rhs shape to be in the order of {batch_dimensions, contracting
1707   // dimensions, out_dimensions}.
1708   llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
1709       rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
1710       rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
1711       rhs_dot_dimensions_info.out_dimensions().AxesArray());
1712   llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
1713       rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
1714       rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
1715       rhs_dot_dimensions_info.out_dimensions().SizesArray());
1716   auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
1717       loc,
1718       RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
1719       rhs,
1720       DenseIntElementsAttr::get(
1721           RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
1722           rhs_permutation));
1723 
1724   // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
1725   llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
1726       lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
1727       llvm::ArrayRef<int64_t>{
1728           lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
1729       llvm::ArrayRef<int64_t>{
1730           lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
1731   auto lhs_flattend = rewriter.create<mhlo::ReshapeOp>(
1732       loc,
1733       RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
1734       lhs_transposed.getResult());
1735 
1736   // Reshapes rhs to flatten out_dimensions and contracting_dimensions.
1737   llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
1738       rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
1739       llvm::ArrayRef<int64_t>{
1740           rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
1741       llvm::ArrayRef<int64_t>{
1742           rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
1743   auto rhs_flattend = rewriter.create<mhlo::ReshapeOp>(
1744       loc,
1745       RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
1746       rhs_transposed.getResult());
1747 
1748   // Creates matmul op of `lhs_flattend` and `rhs_flattend`.
1749   llvm::SmallVector<int64_t, 4> matmul_shape =
1750       Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
1751                       llvm::ArrayRef<int64_t>{
1752                           lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
1753                       llvm::ArrayRef<int64_t>{
1754                           rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
1755   auto matmul = rewriter.create<TF::BatchMatMulV3Op>(
1756       loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
1757       lhs_flattend.getResult(), rhs_flattend.getResult());
1758   auto reshaped =
1759       rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
1760   return reshaped.getResult();
1761 }
1762 
1763 // Converts mhlo.dot to tf.MatMul. Reshape ops will be inserted when
1764 // necessary.
ConvertDotOp(PatternRewriter & rewriter,Operation * old_op)1765 Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
1766   auto dot_op = cast<mhlo::DotOp>(old_op);
1767   auto lhs_rank = dot_op.lhs().getType().cast<ShapedType>().getRank();
1768   auto dot_dimension_numbers =
1769       DotDimensionNumbersAttr::get(rewriter.getContext(),
1770                                    /*lhs_batching_dimensions=*/{},
1771                                    /*rhs_batching_dimensions=*/{},
1772                                    /*lhs_contracting_dimensions=*/
1773                                    {lhs_rank == 1 ? 0 : 1},
1774                                    /*rhs_contracting_dimensions=*/{0});
1775   return ConvertDot(rewriter, dot_op.lhs(), dot_op.rhs(), dot_dimension_numbers,
1776                     dot_op.getResult().getType().cast<ShapedType>(),
1777                     dot_op.getLoc());
1778 }
1779 
1780 // Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be
1781 // inserted to convert to well-formed matrix multiply.
ConvertDotGeneralOp(PatternRewriter & rewriter,Operation * old_op)1782 Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
1783   auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
1784   return ConvertDot(rewriter, dot_general_op.lhs(), dot_general_op.rhs(),
1785                     dot_general_op.dot_dimension_numbers(),
1786                     dot_general_op.getResult().getType().cast<ShapedType>(),
1787                     dot_general_op.getLoc());
1788 }
1789 
1790 // Checks if the specified region is a binary reduction function that takes 2
1791 // inputs, passes it to an instance of the specifiied reduction op and then
1792 // returns the result.
1793 template <typename ReductionOp>
MatchBinaryReduceFunction(mlir::Region & function)1794 LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
1795   Block &body = function.front();
1796   if (body.getNumArguments() != 2) return failure();
1797 
1798   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
1799   if (!return_op) return failure();
1800   if (return_op.getNumOperands() != 1) return failure();
1801 
1802   ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
1803       return_op.getOperands().front().getDefiningOp());
1804   if (!reduce_op) return failure();
1805   if (reduce_op.lhs() != body.getArgument(0) ||
1806       reduce_op.rhs() != body.getArgument(1))
1807     return failure();
1808 
1809   return success();
1810 }
1811 
1812 // Check if the specified region is a binary reduction function that takes 2
1813 // inputs and returns the second input. Functions like this are used by update
1814 // scatter like ops.
1815 template <>
MatchBinaryReduceFunction(mlir::Region & function)1816 LogicalResult MatchBinaryReduceFunction<void>(mlir::Region &function) {
1817   Block &body = function.front();
1818   if (body.getNumArguments() != 2) return failure();
1819 
1820   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
1821   if (!return_op) return failure();
1822   if (return_op.getNumOperands() != 1) return failure();
1823   if (return_op.getOperands().front() != body.getArgument(1)) return failure();
1824   return success();
1825 }
1826 
1827 // Replace BinaryOp with a combination of TfBinaryOp and TfReduceOp if the
1828 // init value doesn't match the expection of TfReduceOp.
1829 template <typename TfReduceOp, typename TfBinOp>
rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op,Value input,ConstOp reduction_indices,ConversionPatternRewriter & rewriter)1830 LogicalResult rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op, Value input,
1831                                        ConstOp reduction_indices,
1832                                        ConversionPatternRewriter &rewriter) {
1833   Value reduce_result = rewriter.create<TfReduceOp>(
1834       reduce_op.getLoc(), reduce_op.getType(0), input, reduction_indices,
1835       /*keep_dim=*/rewriter.getBoolAttr(false));
1836   rewriter.replaceOpWithNewOp<TfBinOp>(reduce_op, reduce_op.getType(0),
1837                                        reduce_result,
1838                                        reduce_op.init_values()[0]);
1839   return success();
1840 }
1841 
1842 // Cannot replace BinaryOp if the init value doesn't match the expection of
1843 // TfReduceOp and there is no corresponding TfBinaryOp.
1844 template <>
rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op,Value input,ConstOp reduction_indices,ConversionPatternRewriter & rewriter)1845 LogicalResult rewriteNonMatchInitValue<TF::MaxOp, void>(
1846     mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices,
1847     ConversionPatternRewriter &rewriter) {
1848   return failure();
1849 }
1850 
1851 template <>
rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op,Value input,ConstOp reduction_indices,ConversionPatternRewriter & rewriter)1852 LogicalResult rewriteNonMatchInitValue<TF::MinOp, void>(
1853     mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices,
1854     ConversionPatternRewriter &rewriter) {
1855   return failure();
1856 }
1857 
1858 // Converts an mhlo.reduce op with an mhlo binary operation into a TensorFlow
1859 // reduction operation. If the initial value can be ignored, then convert it
1860 // into a single TfReduceOp. Otherwise, convert it into a TfReduceOp followed by
1861 // a TfBinaryOp.
1862 // For example:
1863 //   1) An mhlo::ReduceOp on value `x` with an mhlo::AndOp and a constant
1864 // initial value `true` is converted to a TF::All on value `x`.
1865 //   2) An mhlo::ReduceOp on value `x` with an mhlo::AndOp with a non-constant
1866 // initial value `y` is converted to a TF::All on value `x`, followed by a
1867 // TF::LogicalAnd with initial value `y`.
1868 template <typename BinaryOp, typename TfReduceOp, typename TfBinaryOp = void>
1869 class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
1870  public:
1871   using OpConversionPattern::OpConversionPattern;
1872 
matchAndRewrite(mhlo::ReduceOp reduce_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1873   LogicalResult matchAndRewrite(
1874       mhlo::ReduceOp reduce_op, OpAdaptor adaptor,
1875       ConversionPatternRewriter &rewriter) const final {
1876     if (failed(MatchReduceOpOperand(reduce_op))) return failure();
1877 
1878     if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
1879       return failure();
1880 
1881     auto operand = reduce_op.operands()[0];
1882 
1883     // Get reduction dimension.
1884     DenseIntElementsAttr dimension = reduce_op.dimensions();
1885     SmallVector<int64_t, 4> reduce_dims;
1886     for (const int64_t &dim : dimension.getValues<int64_t>()) {
1887       reduce_dims.emplace_back(dim);
1888     }
1889     auto dim_type = RankedTensorType::get(
1890         {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
1891     auto reduction_indices = rewriter.create<ConstOp>(
1892         reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
1893 
1894     // In `MatchReduceOpOperand` function, we already match that the
1895     // "mhlo::ReduceOp" only has one operand, one init_value and one result.
1896 
1897     // If the init value matches with the init value expected for the target
1898     // TfReduceOp, then replace the BinaryOp with a TfReduceOp. Otherwise,
1899     // replace the BinaryOp with a TfBinaryOp and a TfReduceOp.
1900     if (succeeded(MatchInitValue(reduce_op.init_values()[0]))) {
1901       rewriter.replaceOpWithNewOp<TfReduceOp>(
1902           reduce_op, reduce_op.getType(0), operand, reduction_indices,
1903           /*keep_dim=*/rewriter.getBoolAttr(false));
1904       return success();
1905     }
1906     return rewriteNonMatchInitValue<TfReduceOp, TfBinaryOp>(
1907         reduce_op, operand, reduction_indices, rewriter);
1908   }
1909 
1910  private:
1911   // Checks that the init value matches with the init value expected for the
1912   // target TfReduceOp.
1913   virtual LogicalResult MatchInitValue(Value init_value) const = 0;
1914 
1915   // This function tries to match that the "mhlo::ReduceOp" only has one
1916   // operand, one init_value and one result.
MatchReduceOpOperand(mhlo::ReduceOp reduce_op) const1917   LogicalResult MatchReduceOpOperand(mhlo::ReduceOp reduce_op) const {
1918     if (reduce_op.operands().size() != 1 ||
1919         reduce_op.init_values().size() != 1 ||
1920         reduce_op.getResults().size() != 1)
1921       return failure();
1922 
1923     if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
1924       return failure();
1925     if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
1926     return success();
1927   }
1928 };
1929 
1930 class ConvertReduceOpToTfSum
1931     : public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp, TF::AddOp> {
1932  public:
1933   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
1934 
MatchInitValue(Value init_value) const1935   LogicalResult MatchInitValue(Value init_value) const override {
1936     auto type = init_value.getType().cast<ShapedType>().getElementType();
1937     if (type.isa<FloatType>()) {
1938       APFloat const_value(.0);
1939       if (failed(GetConstantSplatValue(init_value, const_value)) ||
1940           !const_value.isZero())
1941         return failure();
1942     } else if (type.isa<IntegerType>() && type.isSignlessInteger()) {
1943       APInt const_value;
1944       if (failed(GetConstantSplatValue(init_value, const_value)) ||
1945           !const_value.isZero())
1946         return failure();
1947     } else {
1948       return failure();
1949     }
1950 
1951     return success();
1952   }
1953 };
1954 
1955 class ConvertReduceOpToTfMax
1956     : public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
1957  public:
1958   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
1959 
MatchInitValue(Value init_value) const1960   LogicalResult MatchInitValue(Value init_value) const override {
1961     auto type = init_value.getType().cast<ShapedType>().getElementType();
1962     if (type.isa<FloatType>()) {
1963       APFloat const_value(.0);
1964       if (failed(GetConstantSplatValue(init_value, const_value)) ||
1965           !const_value.isInfinity() || !const_value.isNegative())
1966         return failure();
1967     } else if (type.isa<IntegerType>() && type.isSignlessInteger()) {
1968       APInt const_value;
1969       if (failed(GetConstantSplatValue(init_value, const_value)) ||
1970           !const_value.isMinSignedValue())
1971         return failure();
1972     } else {
1973       return failure();
1974     }
1975     return success();
1976   }
1977 };
1978 
1979 class ConvertReduceOpToTfMin
1980     : public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
1981  public:
1982   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
1983 
MatchInitValue(Value init_value) const1984   LogicalResult MatchInitValue(Value init_value) const override {
1985     auto type = init_value.getType().cast<ShapedType>().getElementType();
1986 
1987     if (type.isa<FloatType>()) {
1988       APFloat const_value(.0);
1989       if (failed(GetConstantSplatValue(init_value, const_value)) ||
1990           !const_value.isInfinity() || const_value.isNegative())
1991         return failure();
1992     } else if (type.isa<IntegerType>() && type.isSignlessInteger()) {
1993       APInt const_value;
1994       if (failed(GetConstantSplatValue(init_value, const_value)) ||
1995           !const_value.isMaxSignedValue())
1996         return failure();
1997     } else {
1998       return failure();
1999     }
2000     return success();
2001   }
2002 };
2003 
2004 class ConvertReduceOpToTfAll
2005     : public ConvertReduceOpToTfOp<mhlo::AndOp, TF::AllOp, TF::LogicalAndOp> {
2006  public:
2007   using ConvertReduceOpToTfOp<mhlo::AndOp, TF::AllOp,
2008                               TF::LogicalAndOp>::ConvertReduceOpToTfOp;
2009 
MatchInitValue(Value init_value) const2010   LogicalResult MatchInitValue(Value init_value) const override {
2011     DenseIntElementsAttr init_attr;
2012     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
2013         !init_attr.getType().getElementType().isInteger(1) ||
2014         !init_attr.isSplat() || !init_attr.getSplatValue<BoolAttr>().getValue())
2015       return failure();
2016     return success();
2017   }
2018 };
2019 
2020 class ConvertReduceOpToTfAny
2021     : public ConvertReduceOpToTfOp<mhlo::OrOp, TF::AnyOp, TF::LogicalOrOp> {
2022  public:
2023   using ConvertReduceOpToTfOp<mhlo::OrOp, TF::AnyOp,
2024                               TF::LogicalOrOp>::ConvertReduceOpToTfOp;
2025 
MatchInitValue(Value init_value) const2026   LogicalResult MatchInitValue(Value init_value) const override {
2027     DenseIntElementsAttr init_attr;
2028     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
2029         !init_attr.getType().getElementType().isInteger(1) ||
2030         !init_attr.isSplat() || init_attr.getSplatValue<BoolAttr>().getValue())
2031       return failure();
2032     return success();
2033   }
2034 };
2035 
2036 template <typename TfReduce, typename TfArgReduce>
2037 class ConvertReduceOpToTfArgMinMax
2038     : public OpConversionPattern<mhlo::ReduceOp> {
2039  public:
2040   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(mhlo::ReduceOp reduce_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const2041   LogicalResult matchAndRewrite(
2042       mhlo::ReduceOp reduce_op, OpAdaptor adaptor,
2043       ConversionPatternRewriter &rewriter) const final {
2044     if (reduce_op.operands().size() != 2) return failure();
2045     if (reduce_op.dimensions().getNumElements() != 1) return failure();
2046 
2047     // Check that the operand init is the expected value.
2048     DenseElementsAttr operand_init;
2049     if (!matchPattern(reduce_op.init_values().front(),
2050                       m_Constant(&operand_init)))
2051       return failure();
2052     if (!IsValueInitValue(operand_init)) return failure();
2053 
2054     // Check that the iota init is zero.
2055     DenseElementsAttr iota_init;
2056     if (!matchPattern(reduce_op.init_values().back(), m_Constant(&iota_init)))
2057       return failure();
2058     if (iota_init.getValues<APInt>()[0] != 0) return failure();
2059 
2060     // Verify that the second argument is an Iota op along the same dimension
2061     // as the reduction.
2062     Value iota = reduce_op.operands().back();
2063     if (!MatchIota(reduce_op.dimensions(), iota)) return failure();
2064 
2065     // Match the reduction computation.
2066     const bool is_float = operand_init.getElementType().isa<FloatType>();
2067     if (failed(matchReduceComputation(reduce_op.body(), is_float)))
2068       return failure();
2069 
2070     Value operand = reduce_op.operands().front();
2071     int64_t axis = reduce_op.dimensions().getValues<int64_t>()[0];
2072 
2073     auto dim_type = RankedTensorType::get({1}, rewriter.getI64Type());
2074     auto reduction_indices = rewriter.create<ConstOp>(
2075         reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr({axis}));
2076 
2077     // Generate a Max and an ArgMax of as the mhlo op returns both while in TF
2078     // we have separate ops for them. If only one of them is used then the other
2079     // one will be garbage collected later.
2080     auto tf_reduce_op = rewriter.create<TfReduce>(
2081         reduce_op.getLoc(), reduce_op->getResult(0).getType(), operand,
2082         reduction_indices,
2083         /*keep_dim=*/rewriter.getBoolAttr(false));
2084     auto tf_argreduce_op = rewriter.create<TfArgReduce>(
2085         reduce_op.getLoc(), reduce_op->getResult(1).getType(), operand,
2086         reduction_indices);
2087 
2088     rewriter.replaceOp(reduce_op, {tf_reduce_op, tf_argreduce_op});
2089     return success();
2090   }
2091 
2092   // Pattern matches the following reduction function for ArgMax/ArgMin:
2093   // %0 = compare{GT}(%lhs_value, %rhs_value)
2094   // %1 = compare{NE}(%lhs_value, %lhs_value)
2095   // %2 = or(%0, %1)
2096   // %3 = select(%2, %lhs_value, %rhs_value)
2097   // %4 = compare{EQ}(%lhs_value, %rhs_value)
2098   // %5 = compare{LT}(%lhs_index, %rhs_index)
2099   // %6 = and(%4, %5)
2100   // %7 = or(%2, %6)
2101   // %8 = select(%7, %lhs_index, %rhs_index)
2102   // return %3, %8
2103   // Also note that %1 may be folded if %lhs_value is of integer types.
matchReduceComputation(Region & computation,bool is_float) const2104   LogicalResult matchReduceComputation(Region &computation,
2105                                        bool is_float) const {
2106     Block &body = computation.front();
2107     if (body.getNumArguments() != 4) return failure();
2108 
2109     mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
2110     if (!return_op || return_op.getNumOperands() != 2) return failure();
2111 
2112     mhlo::SelectOp value_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
2113         return_op.getOperand(0).getDefiningOp());
2114     if (!value_select || value_select.on_true() != body.getArgument(0) ||
2115         value_select.on_false() != body.getArgument(2))
2116       return failure();
2117 
2118     if (is_float) {
2119       mhlo::OrOp value_or = llvm::dyn_cast_or_null<mhlo::OrOp>(
2120           value_select.getOperand(0).getDefiningOp());
2121       if (!value_or) return failure();
2122 
2123       mhlo::CompareOp value_gt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
2124           value_or.lhs().getDefiningOp());
2125       if (!value_gt || value_gt.comparison_direction() != CompareDirection() ||
2126           value_gt.lhs() != body.getArgument(0) ||
2127           value_gt.rhs() != body.getArgument(2))
2128         return failure();
2129 
2130       mhlo::CompareOp value_ne = llvm::dyn_cast_or_null<mhlo::CompareOp>(
2131           value_or.rhs().getDefiningOp());
2132       if (!value_ne ||
2133           value_ne.comparison_direction() != mhlo::ComparisonDirection::NE ||
2134           value_ne.lhs() != body.getArgument(0) ||
2135           value_ne.rhs() != body.getArgument(0))
2136         return failure();
2137     } else {
2138       mhlo::CompareOp value_gt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
2139           value_select.getOperand(0).getDefiningOp());
2140       if (!value_gt || value_gt.comparison_direction() != CompareDirection() ||
2141           value_gt.lhs() != body.getArgument(0) ||
2142           value_gt.rhs() != body.getArgument(2))
2143         return failure();
2144     }
2145 
2146     mhlo::SelectOp index_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
2147         return_op.getOperand(1).getDefiningOp());
2148     if (!index_select || index_select.on_true() != body.getArgument(1) ||
2149         index_select.on_false() != body.getArgument(3))
2150       return failure();
2151 
2152     mhlo::OrOp index_or =
2153         llvm::dyn_cast_or_null<mhlo::OrOp>(index_select.pred().getDefiningOp());
2154 
2155     if (!index_or || index_or.lhs() != value_select.pred()) return failure();
2156 
2157     mhlo::AndOp index_and =
2158         llvm::dyn_cast_or_null<mhlo::AndOp>(index_or.rhs().getDefiningOp());
2159     if (!index_and) return failure();
2160 
2161     mhlo::CompareOp value_eq = llvm::dyn_cast_or_null<mhlo::CompareOp>(
2162         index_and.lhs().getDefiningOp());
2163     if (!value_eq ||
2164         value_eq.comparison_direction() != mhlo::ComparisonDirection::EQ ||
2165         value_eq.lhs() != body.getArgument(0) ||
2166         value_eq.rhs() != body.getArgument(2))
2167       return failure();
2168 
2169     mhlo::CompareOp index_lt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
2170         index_and.rhs().getDefiningOp());
2171     if (!index_lt ||
2172         index_lt.comparison_direction() != mhlo::ComparisonDirection::LT ||
2173         index_lt.lhs() != body.getArgument(1) ||
2174         index_lt.rhs() != body.getArgument(3))
2175       return failure();
2176 
2177     return success();
2178   }
2179 
2180   virtual mhlo::ComparisonDirection CompareDirection() const = 0;
2181 
2182   virtual bool IsValueInitValue(const DenseElementsAttr &attr) const = 0;
2183 };
2184 
2185 class ConvertReduceOpToTfArgmax
2186     : public ConvertReduceOpToTfArgMinMax<TF::MaxOp, TF::ArgMaxOp> {
2187  public:
2188   using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;
2189 
CompareDirection() const2190   mhlo::ComparisonDirection CompareDirection() const override {
2191     return mhlo::ComparisonDirection::GT;
2192   }
IsValueInitValue(const DenseElementsAttr & attr) const2193   bool IsValueInitValue(const DenseElementsAttr &attr) const override {
2194     auto element_type = attr.getType().getElementType();
2195     if (attr.getNumElements() != 1 || !element_type.isIntOrFloat() ||
2196         element_type.isInteger(1))
2197       return false;
2198     if (element_type.isa<FloatType>()) {
2199       auto value = *attr.value_begin<APFloat>();
2200       return value.isNegative() && value.isInfinity();
2201     } else {
2202       auto value = *attr.value_begin<APInt>();
2203       return element_type.isUnsignedInteger() ? value.isMinValue()
2204                                               : value.isMinSignedValue();
2205     }
2206   }
2207 };
2208 
2209 class ConvertReduceOpToTfArgmin
2210     : public ConvertReduceOpToTfArgMinMax<TF::MinOp, TF::ArgMinOp> {
2211  public:
2212   using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;
2213 
CompareDirection() const2214   mhlo::ComparisonDirection CompareDirection() const override {
2215     return mhlo::ComparisonDirection::LT;
2216   }
IsValueInitValue(const DenseElementsAttr & attr) const2217   bool IsValueInitValue(const DenseElementsAttr &attr) const override {
2218     auto element_type = attr.getType().getElementType();
2219     if (attr.getNumElements() != 1 || !element_type.isIntOrFloat() ||
2220         element_type.isInteger(1))
2221       return false;
2222     if (element_type.isa<FloatType>()) {
2223       auto value = *attr.value_begin<APFloat>();
2224       return !value.isNegative() && value.isInfinity();
2225     } else {
2226       auto value = *attr.value_begin<APInt>();
2227       return element_type.isUnsignedInteger() ? value.isMaxValue()
2228                                               : value.isMaxSignedValue();
2229     }
2230   }
2231 };
2232 
2233 class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
2234  public:
2235   using OpConversionPattern::OpConversionPattern;
2236 
matchAndRewrite(mhlo::IotaOp iota_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const2237   LogicalResult matchAndRewrite(
2238       mhlo::IotaOp iota_op, OpAdaptor adaptor,
2239       ConversionPatternRewriter &rewriter) const final {
2240     RankedTensorType type =
2241         iota_op.getType().dyn_cast_or_null<RankedTensorType>();
2242     // TF::RangeOp doesn't support UI16.
2243     if (!type || type.getElementType().isUnsignedInteger(16)) return failure();
2244 
2245     const uint64_t dimension = iota_op.iota_dimension();
2246     Type element_type = type.getElementType();
2247     Attribute start, limit, delta;
2248     if (element_type.isa<FloatType>()) {
2249       start = rewriter.getFloatAttr(element_type, 0.0);
2250       limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]);
2251       delta = rewriter.getFloatAttr(element_type, 1.0);
2252     } else if (element_type.isa<IntegerType>()) {
2253       start = rewriter.getIntegerAttr(element_type, 0);
2254       limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]);
2255       delta = rewriter.getIntegerAttr(element_type, 1);
2256     } else {
2257       return failure();
2258     }
2259 
2260     auto range_type =
2261         RankedTensorType::get({type.getShape()[dimension]}, element_type);
2262     Value start_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), start);
2263     Value limit_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), limit);
2264     Value delta_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), delta);
2265     Value result = rewriter.create<TF::RangeOp>(iota_op.getLoc(), range_type,
2266                                                 start_op, limit_op, delta_op);
2267 
2268     if (type.getRank() > 1) {
2269       std::vector<int64_t> reshape_shape(type.getRank(), 1);
2270       reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension];
2271       auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
2272       Value reshape_shape_op = rewriter.create<TF::ConstOp>(
2273           iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape));
2274       result = rewriter.create<TF::ReshapeOp>(iota_op.getLoc(), reshape_type,
2275                                               result, reshape_shape_op);
2276 
2277       Value broadcast_shape_op = rewriter.create<TF::ConstOp>(
2278           iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape()));
2279       result = rewriter.create<TF::BroadcastToOp>(iota_op.getLoc(), type,
2280                                                   result, broadcast_shape_op);
2281     }
2282 
2283     rewriter.replaceOp(iota_op, result);
2284     return success();
2285   }
2286 };
2287 
2288 // A helper function for ConvertMaxPoolOp and ConvertAvgMaxPoolOp. Returns true
2289 // if the given ReduceWindowOp is a spatial pooling without dilation. If returns
2290 // true, also outputs the window strides and the TF padding mode ("VALID" or
2291 // "SAME").
IsSpatialPoolingWithoutDilation(mhlo::ReduceWindowOp rw,llvm::SmallVectorImpl<int64_t> * window_strides,std::string * padding_mode)2292 bool IsSpatialPoolingWithoutDilation(
2293     mhlo::ReduceWindowOp rw, llvm::SmallVectorImpl<int64_t> *window_strides,
2294     std::string *padding_mode) {
2295   // tf.max_pool or tf.avg_pool need at least 3 dimensions (batch, spatial,
2296   // channel).
2297   const uint64_t rank = rw.window_dimensions().size();
2298   if (rank <= 2) return false;
2299 
2300   if (rw.window_strides().has_value()) {
2301     window_strides->insert(window_strides->end(),
2302                            rw.window_strides()->getValues<int64_t>().begin(),
2303                            rw.window_strides()->getValues<int64_t>().end());
2304   } else {
2305     window_strides->resize(rank, 1);
2306   }
2307 
2308   llvm::SmallVector<int64_t, 10> padding;
2309   if (rw.padding().has_value()) {
2310     padding.insert(padding.begin(), rw.padding()->getValues<int64_t>().begin(),
2311                    rw.padding()->getValues<int64_t>().end());
2312   } else {
2313     padding.resize(2 * rank, 0);
2314   }
2315 
2316   // Check that we don't do any reduction along the batch (first) and channel
2317   // (last) dimensions.
2318   const uint64_t batch_dim = 0;
2319   const uint64_t channel_dim = rank - 1;
2320   if (rw.window_dimensions().getValues<int64_t>()[batch_dim] != 1 ||
2321       rw.window_dimensions().getValues<int64_t>()[channel_dim] != 1 ||
2322       (*window_strides)[batch_dim] != 1 ||
2323       (*window_strides)[channel_dim] != 1 || padding[2 * batch_dim] != 0 ||
2324       padding[2 * batch_dim + 1] != 0 || padding[2 * channel_dim] != 0 ||
2325       padding[2 * channel_dim + 1] != 0)
2326     return false;
2327 
2328   if (rw.window_dilations().has_value() &&
2329       !(rw.window_dilations()->isSplat() &&
2330         rw.window_dilations()->getSplatValue<APInt>() == 1))
2331     return false;
2332 
2333   if (rw.base_dilations().has_value() &&
2334       !(rw.base_dilations()->isSplat() &&
2335         rw.base_dilations()->getSplatValue<APInt>() == 1))
2336     return false;
2337 
2338   if (llvm::all_of(padding, [](int64_t i) { return i == 0; })) {
2339     *padding_mode = "VALID";
2340     return true;
2341   }
2342 
2343   // Check that the individual padding values are corresponding to SAME
2344   // padding from TensorFlow.
2345   auto operand_type = rw.operands()[0].getType().dyn_cast<RankedTensorType>();
2346   RankedTensorType output_type =
2347       rw.getResult(0).getType().dyn_cast<RankedTensorType>();
2348   if (!operand_type || !output_type) return false;
2349 
2350   for (uint64_t i = 1; i < rank - 1; ++i) {
2351     int64_t padding_size =
2352         (output_type.getShape()[i] - 1) * (*window_strides)[i] +
2353         rw.window_dimensions().getValues<int64_t>()[i] -
2354         operand_type.getShape()[i];
2355     if (padding[2 * i] != tensorflow::MathUtil::FloorOfRatio(
2356                               padding_size, static_cast<int64_t>(2)) ||
2357         padding[2 * i + 1] != tensorflow::MathUtil::CeilOfRatio(
2358                                   padding_size, static_cast<int64_t>(2)))
2359       return false;
2360   }
2361 
2362   *padding_mode = "SAME";
2363   return true;
2364 }
2365 
2366 // Convert a reduce_window operation into a cumulative operation where possible
2367 // for a given binary operation.
2368 template <class BinaryOp, class TfCumOp>
2369 class ConvertLoweredCumOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
2370  public:
2371   using OpConversionPattern::OpConversionPattern;
2372 
2373   virtual bool IsInitValue(const DenseElementsAttr &attr) const = 0;
2374 
matchAndRewrite(mhlo::ReduceWindowOp rw,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const2375   LogicalResult matchAndRewrite(
2376       mhlo::ReduceWindowOp rw, OpAdaptor adaptor,
2377       ConversionPatternRewriter &rewriter) const final {
2378     if (rw.getNumResults() != 1 || rw.operands().size() != 1 ||
2379         rw.init_values().size() != 1)
2380       return failure();
2381 
2382     if (failed(MatchBinaryReduceFunction<BinaryOp>(rw.body())))
2383       return failure();
2384 
2385     // Ensure that initial_values are as expected.
2386     auto const_op = llvm::dyn_cast_or_null<mhlo::ConstantOp>(
2387         rw.init_values()[0].getDefiningOp());
2388     if (!const_op) return failure();
2389     auto const_op_dense_value = const_op.value().cast<DenseElementsAttr>();
2390     if (!const_op_dense_value || !IsInitValue(const_op_dense_value)) {
2391       return failure();
2392     }
2393 
2394     auto operand_type = rw.operands()[0].getType().cast<ShapedType>();
2395 
2396     // For a cumulative op, require a tensor of 1s for each dimension in
2397     // operand.
2398     auto is_splat_int64_ones =
2399         [&rewriter,
2400          &operand_type](const ::llvm::Optional<DenseIntElementsAttr> &attr) {
2401           // According to the definition, the default value of these attributes
2402           // are all ones when unspecified.
2403           if (!attr.has_value()) return true;
2404           if (attr->getType().getShape()[0] != operand_type.getRank())
2405             return false;
2406           if (!attr->isSplat()) return false;
2407           if (attr->getElementType() != rewriter.getIntegerType(64))
2408             return false;
2409           if (attr->getSplatValue<APInt>().getSExtValue() != 1) return false;
2410           return true;
2411         };
2412     if (!is_splat_int64_ones(rw.base_dilations()) ||
2413         !is_splat_int64_ones(rw.window_dilations()) ||
2414         !is_splat_int64_ones(rw.window_strides()))
2415       return failure();
2416 
2417     // Determine which axis is being used for the cumulative operation.
2418     //
2419     // For a cumulative op, window_dimensions should be of the form:
2420     //  dense<[1, 1, N, 1]>
2421     // where N is the same as the size of the corresponding input dimension
2422     // and there is a 1-entry for each input dimension not being operated
2423     // over.
2424     const auto &window_dimensions = rw.window_dimensions();
2425     if (window_dimensions.size() != operand_type.getRank()) return failure();
2426     int64_t cumulative_axis = -1;
2427     for (int64_t i = 0, e = window_dimensions.size(); i < e; ++i) {
2428       int64_t window_dimension = window_dimensions.getValues<int64_t>()[i];
2429       if (window_dimension == 1) continue;
2430       // Cumulative axis already set.
2431       if (cumulative_axis != -1) return failure();
2432       // Potential cumulative axis is not the right size.
2433       if (window_dimension != operand_type.getShape()[i]) return failure();
2434       cumulative_axis = i;
2435     }
2436 
2437     if (cumulative_axis == -1) {
2438       return rewriter.notifyMatchFailure(rw, "no reduced dimension is found.");
2439     }
2440 
2441     // For a cumulative op, padding (expressed as a list of left-padding and
2442     // right-padding pairs) should be of the form:
2443     //  dense<[[0, 0], [0, 0], [N-1, 0], [0, 0]]>
2444     // where N is the size of the input dimension being operated over.
2445     if (!rw.padding()) return failure();
2446     const auto &padding = rw.padding()->getValues<int64_t>();
2447     if (padding.size() != operand_type.getRank() * 2) return failure();
2448     int64_t padding_value = operand_type.getShape()[cumulative_axis] - 1;
2449     for (int64_t dim = 0; dim < operand_type.getRank(); ++dim) {
2450       int64_t left_padding = padding[2 * dim];
2451       int64_t right_padding = padding[2 * dim + 1];
2452       if (dim == cumulative_axis) {
2453         if (left_padding != padding_value) return failure();
2454       } else {
2455         if (left_padding != 0) return failure();
2456       }
2457       if (right_padding != 0) return failure();
2458     }
2459 
2460     auto axis = rewriter.create<TF::ConstOp>(
2461         rw->getLoc(),
2462         rewriter.getIntegerAttr(rewriter.getIntegerType(64), cumulative_axis));
2463 
2464     rewriter.replaceOpWithNewOp<TfCumOp>(rw, rw.getType(0), rw.operands()[0],
2465                                          axis, /* exclusive */ false,
2466                                          /* reverse */ false);
2467     return success();
2468   }
2469 };
2470 
2471 class ConvertLoweredCumSumOp
2472     : public ConvertLoweredCumOp<mhlo::AddOp, TF::CumsumOp> {
2473   using ConvertLoweredCumOp::ConvertLoweredCumOp;
IsInitValue(const DenseElementsAttr & attr) const2474   bool IsInitValue(const DenseElementsAttr &attr) const override {
2475     auto element_type = attr.getType().getElementType();
2476     if (attr.getNumElements() != 1 || !element_type.isIntOrFloat())
2477       return false;
2478     if (element_type.isa<FloatType>()) {
2479       auto value = *attr.value_begin<APFloat>();
2480       return value.isZero();
2481     }
2482     auto value = *attr.value_begin<APInt>();
2483     return value.isZero();
2484   }
2485 };
2486 
2487 class ConvertLoweredCumProdOp
2488     : public ConvertLoweredCumOp<mhlo::MulOp, TF::CumprodOp> {
2489   using ConvertLoweredCumOp::ConvertLoweredCumOp;
IsInitValue(const DenseElementsAttr & attr) const2490   bool IsInitValue(const DenseElementsAttr &attr) const override {
2491     auto element_type = attr.getType().getElementType();
2492     if (attr.getNumElements() != 1 || !element_type.isIntOrFloat())
2493       return false;
2494     if (element_type.isa<FloatType>()) {
2495       auto value = *attr.value_begin<APFloat>();
2496       return value.isExactlyValue(1.0);
2497     }
2498     auto value = *attr.value_begin<APInt>();
2499     return value.getSExtValue() == 1;
2500   }
2501 };
2502 
2503 // Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
2504 // operation when they cleanly map to 2D or 3D average pool with VALID or SAME
2505 // padding:
2506 // * div(reduce_sum_window(x), constant(sizeof(window)))
2507 // * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
2508 class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
2509  public:
ConvertAvgPoolOp(MLIRContext * context)2510   explicit ConvertAvgPoolOp(MLIRContext *context)
2511       : OpConversionPattern(context, /*benefit=*/10) {}
2512 
matchAndRewrite(mhlo::DivOp div_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const2513   LogicalResult matchAndRewrite(
2514       mhlo::DivOp div_op, OpAdaptor adaptor,
2515       ConversionPatternRewriter &rewriter) const final {
2516     auto rw =
2517         dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
2518     if (!rw || rw->getNumResults() != 1) return failure();
2519 
2520     // Check that the reduce-window is a sum-reduce-window.
2521     if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
2522       return failure();
2523 
2524     // Check that this is a floating point reduce window with a rank of 4 or 5.
2525     const RankedTensorType rw_type =
2526         rw.getResult(0).getType().dyn_cast<RankedTensorType>();
2527     if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
2528         rw_type.getRank() <= 3 || rw_type.getRank() > 5)
2529       return failure();
2530 
2531     // Check that the Div op doesn't do broadcasting on the output of the reduce
2532     // window.
2533     if (div_op.getType() != rw_type) return failure();
2534 
2535     // If the init value isn't zero then it can't be an average pool.
2536     if (!isFloatZero(rw.init_values()[0])) return failure();
2537 
2538     llvm::SmallVector<int64_t, 5> window_strides;
2539     std::string padding_mode;
2540     if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
2541       return rewriter.notifyMatchFailure(
2542           div_op, "not the root of spatial pooling without dilation");
2543     }
2544 
2545     DenseFPElementsAttr divisor;
2546     if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
2547       // If the divisor is a constant then check that it matches with the number
2548       // of elements inside the window what is required for a VALID AvgPool.
2549       if (!divisor.isSplat()) return failure();
2550       int64_t window_size = 1;
2551       for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
2552         window_size *= w;
2553       }
2554       if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
2555         return failure();
2556 
2557       if (padding_mode != "VALID") {
2558         return failure();
2559       }
2560 
2561       return replaceWithAvgPool(
2562           div_op, rw.operands()[0],
2563           llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
2564           window_strides, "VALID", rewriter);
2565     }
2566 
2567     auto rw_rhs =
2568         dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
2569     if (rw_rhs && rw_rhs.getNumResults() == 1) {
2570       // Check that RHS is a sum-reduce-window.
2571       if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
2572         return failure();
2573 
2574       // Check that the RHS is a reduce_window over a constant 1 operand with 0
2575       // as the init value.
2576       DenseFPElementsAttr rhs_operand;
2577       if (!isFloatZero(rw_rhs.init_values()[0]) ||
2578           !matchPattern(rw_rhs.operands()[0], m_Constant(&rhs_operand)) ||
2579           !rhs_operand.isSplat() ||
2580           !rhs_operand.getSplatValue<APFloat>().isExactlyValue(1.0))
2581         return failure();
2582 
2583       // Check that the two reduce window have the same window configuration.
2584       if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
2585           rw.window_strides() != rw_rhs.window_strides() ||
2586           rw.window_dilations() != rw_rhs.window_dilations() ||
2587           rw.base_dilations() != rw_rhs.base_dilations() ||
2588           rw.padding() != rw_rhs.padding())
2589         return failure();
2590 
2591       return replaceWithAvgPool(
2592           div_op, rw.operands()[0],
2593           llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
2594           window_strides, padding_mode, rewriter);
2595     }
2596 
2597     return failure();
2598   }
2599 
2600  private:
isFloatZero(Value value) const2601   bool isFloatZero(Value value) const {
2602     DenseFPElementsAttr initial_value;
2603     return matchPattern(value, m_Constant(&initial_value)) &&
2604            initial_value.getNumElements() == 1 &&
2605            initial_value.getValues<APFloat>()[0].isZero();
2606   }
2607 
replaceWithAvgPool(mhlo::DivOp op,Value input,llvm::ArrayRef<int64_t> ksizes,llvm::ArrayRef<int64_t> kstrides,llvm::StringRef padding,ConversionPatternRewriter & rewriter) const2608   LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
2609                                    llvm::ArrayRef<int64_t> ksizes,
2610                                    llvm::ArrayRef<int64_t> kstrides,
2611                                    llvm::StringRef padding,
2612                                    ConversionPatternRewriter &rewriter) const {
2613     if (ksizes.size() == 4) {
2614       rewriter.replaceOpWithNewOp<AvgPoolOp>(
2615           op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
2616           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
2617           rewriter.getStringAttr("NHWC"));
2618       return success();
2619     } else if (ksizes.size() == 5) {
2620       rewriter.replaceOpWithNewOp<AvgPool3DOp>(
2621           op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
2622           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
2623           rewriter.getStringAttr("NDHWC"));
2624       return success();
2625     }
2626     return failure();
2627   }
2628 };
2629 
2630 class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
2631  public:
2632   using OpConversionPattern::OpConversionPattern;
2633 
matchAndRewrite(mhlo::ReduceWindowOp rw,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const2634   LogicalResult matchAndRewrite(
2635       mhlo::ReduceWindowOp rw, OpAdaptor adaptor,
2636       ConversionPatternRewriter &rewriter) const final {
2637     // Check that the reduce-window is a max-reduce-window.
2638     if (failed(MatchBinaryReduceFunction<mhlo::MaxOp>(rw.body())))
2639       return failure();
2640 
2641     // Check that this is a floating point reduce window with a rank of 4 or 5.
2642     const RankedTensorType rw_type =
2643         rw.getResult(0).getType().dyn_cast<RankedTensorType>();
2644     if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
2645         rw_type.getRank() <= 3 || rw_type.getRank() > 5)
2646       return failure();
2647 
2648     if (!isFloatMinusInfinity(rw.init_values()[0])) {
2649       return failure();
2650     }
2651 
2652     llvm::SmallVector<int64_t, 5> window_strides;
2653     std::string padding_mode;
2654     if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
2655       return rewriter.notifyMatchFailure(
2656           rw, "not the root of spatial pooling without dilation");
2657     }
2658 
2659     return replaceWithMaxPool(
2660         rw, rw.operands()[0],
2661         llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
2662         window_strides, padding_mode, rewriter);
2663   }
2664 
2665  private:
isFloatMinusInfinity(Value value) const2666   bool isFloatMinusInfinity(Value value) const {
2667     DenseFPElementsAttr float_value;
2668     if (!matchPattern(value, m_Constant(&float_value))) {
2669       return false;
2670     }
2671 
2672     if (float_value.getNumElements() != 1) {
2673       return false;
2674     }
2675 
2676     APFloat element = float_value.getValues<APFloat>()[0];
2677     if (!element.isInfinity()) {
2678       return false;
2679     }
2680     if (!element.isNegative()) {
2681       return false;
2682     }
2683 
2684     return true;
2685   }
2686 
replaceWithMaxPool(mhlo::ReduceWindowOp op,Value input,llvm::ArrayRef<int64_t> ksizes,llvm::ArrayRef<int64_t> kstrides,llvm::StringRef padding,ConversionPatternRewriter & rewriter) const2687   LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
2688                                    llvm::ArrayRef<int64_t> ksizes,
2689                                    llvm::ArrayRef<int64_t> kstrides,
2690                                    llvm::StringRef padding,
2691                                    ConversionPatternRewriter &rewriter) const {
2692     if (ksizes.size() == 4) {
2693       rewriter.replaceOpWithNewOp<MaxPoolOp>(
2694           op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
2695           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
2696           /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
2697           rewriter.getStringAttr("NHWC"));
2698       return success();
2699     } else if (ksizes.size() == 5) {
2700       rewriter.replaceOpWithNewOp<MaxPool3DOp>(
2701           op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
2702           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
2703           rewriter.getStringAttr("NDHWC"));
2704       return success();
2705     }
2706     return failure();
2707   }
2708 };
2709 
2710 class LegalizeHloToTf : public TF::LegalizeHloToTfPassBase<LegalizeHloToTf> {
2711   /// Performs the legalization to the TF dialect.
2712   void runOnOperation() override;
2713 };
2714 
2715 // Returns the shape of the given value in a Constant Op.
ShapeToConst(PatternRewriter & rewriter,Value value)2716 arith::ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
2717   ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape();
2718   auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
2719                                          rewriter.getIntegerType(64));
2720   auto attr = DenseElementsAttr::get(attr_type, shape);
2721   return rewriter.create<arith::ConstantOp>(value.getLoc(), attr_type, attr);
2722 }
2723 
IsSign(APFloat a,APFloat sign)2724 bool IsSign(APFloat a, APFloat sign) {
2725   if (a.isNaN() || a.isZero()) return a == sign;
2726   if (a.isNegative()) return sign.isExactlyValue(-1.0);
2727   return sign.isExactlyValue(1.0);
2728 }
2729 
2730 // Returns whether the splat constant is the sign of the FloatTensor
FloatTensorIsSign(PatternRewriter & rewriter,ElementsAttr floatv,ElementsAttr sgn_cst)2731 bool FloatTensorIsSign(PatternRewriter &rewriter, ElementsAttr floatv,
2732                        ElementsAttr sgn_cst) {
2733   if (!sgn_cst.isa<SplatElementsAttr>()) return false;
2734   auto sgn_cst_spl = sgn_cst.cast<SplatElementsAttr>().getSplatValue<APFloat>();
2735   if (floatv.isa<SplatElementsAttr>()) {
2736     auto floatv_spl = floatv.cast<SplatElementsAttr>().getSplatValue<APFloat>();
2737     return IsSign(floatv_spl, sgn_cst_spl);
2738   } else if (floatv.isa<DenseElementsAttr>()) {
2739     auto floatv_dns = floatv.cast<DenseFPElementsAttr>();
2740     return llvm::all_of(floatv_dns.getValues<APFloat>(), [&](APFloat value) {
2741       return IsSign(value, sgn_cst_spl);
2742     });
2743   }
2744   return false;
2745 }
2746 
2747 // Check that `arr` is an R1 iota with integer element type starting from `0`
2748 // with `size` number of values.
IsIotaAttr(ArrayRef<int64_t> arr,int64_t size)2749 bool IsIotaAttr(ArrayRef<int64_t> arr, int64_t size) {
2750   if (arr.size() != size) return false;
2751   int64_t iota = 0;
2752   for (auto s : arr) {
2753     if (s != iota) return false;
2754     ++iota;
2755   }
2756   return true;
2757 }
2758 
2759 // Convert updates into canonical form as expected by tf.scatter ops.
2760 //
2761 // tf.scatter expects `update_window_dims` to be the trailing dimensions.
2762 //
2763 // To support scatter ops generated by numpy-like slice updates:
2764 //   nd_array[:, [i,j]] = [i_values, j_values]
2765 //
2766 // `updates` must be transposed when the update_window_dims are the leading
2767 // dimensions of `updates`.
2768 //
2769 // Other values of `update_window_dims` are left unsupported.
2770 //
2771 // Eg 1. An update in canonical form:
2772 //  * indices shape(A,B,C)
2773 //  * updates shape(A,B,D,E,F)
2774 // Then:
2775 //  * D,E,F are the update window dims [2,3,4]
2776 //  * C is the index vector dimension
2777 //  * A,B iterate over the updates and indices
2778 //
2779 // If `update_window_dims` are not the trailing dimensions then updates must be
2780 // transposed.
2781 //
2782 // Eg 2. An update in non-canonical form:
2783 //  * indices shape(a,b,c)
2784 //  * updates shape(d,e,f,a,b)
2785 // Then:
2786 //  * d,e,f are the update window dims [0,1,2]
2787 //  * c is the index vector dimension
2788 //  * a,b iterate over the updates and indices
2789 //
2790 //  The update needs permuting to be in the form (a,b,d,e,f) so that the update
2791 //  window dims are the trailing dimensions.
2792 //
2793 // To canonicalize the updates above, replace the updates with:
2794 //   transpose(updates, permutation={3,4,0,1,2})
2795 //
2796 // Note: NormalizeIndexVector is assumed to have run on the indices already so
2797 // that the index_vector_dim is the trailing dimension in `indices`.
CanonicalizeScatterUpdates(Operation * scatter_op,llvm::ArrayRef<int64_t> update_window_dims,const Value & indices,const ShapedType & indices_type,Value & updates,ShapedType & updates_type,ConversionPatternRewriter & rewriter)2798 LogicalResult CanonicalizeScatterUpdates(
2799     Operation *scatter_op, llvm::ArrayRef<int64_t> update_window_dims,
2800     const Value &indices, const ShapedType &indices_type, Value &updates,
2801     ShapedType &updates_type, ConversionPatternRewriter &rewriter) {
2802   auto canonical_update_window_dims = llvm::to_vector(
2803       llvm::seq<int64_t>(indices_type.getRank() - 1, updates_type.getRank()));
2804 
2805   if (canonical_update_window_dims == update_window_dims) return success();
2806 
2807   // Permute updates if `update_window_dims` are leading indices.
2808   // Other possibilities for `update_window_dims` are not supported yet.
2809   if (!IsIotaAttr(update_window_dims, update_window_dims.size()))
2810     return rewriter.notifyMatchFailure(
2811         scatter_op, "update_window_dims are not leading or trailing indices");
2812 
2813   SmallVector<int64_t, 4> permutation_array(updates_type.getRank());
2814   int64_t dim = 0;
2815   // Move leading indices to the back of the array.
2816   const auto permutation_array_size = permutation_array.size();
2817   for (int64_t i = update_window_dims.size(); i < permutation_array_size; ++i) {
2818     permutation_array[i] = dim;
2819     ++dim;
2820   }
2821   // Move trailing indices to the front of the array.
2822   for (int64_t i = 0; i < update_window_dims.size(); ++i) {
2823     permutation_array[i] = dim;
2824     ++dim;
2825   }
2826 
2827   auto permutation_and_shape = GetPermutationAndTransposedShape(
2828       permutation_array, updates_type, rewriter);
2829 
2830   auto transposed_updates = rewriter.create<mhlo::TransposeOp>(
2831       scatter_op->getLoc(), permutation_and_shape.shape, updates,
2832       permutation_and_shape.permutation);
2833 
2834   updates = transposed_updates;
2835   updates_type = permutation_and_shape.shape;
2836   return success();
2837 }
2838 
2839 // If index_vector_dim == indices.rank() then insert the implicit extra
2840 // dimension into indices to normalize everything to index_vector_dim ==
2841 // indices.rank() - 1.
NormalizeIndexVector(Operation * parent_op,Value & indices,ShapedType & indices_type,int64_t index_vector_dim,ConversionPatternRewriter & rewriter)2842 LogicalResult NormalizeIndexVector(Operation *parent_op, Value &indices,
2843                                    ShapedType &indices_type,
2844                                    int64_t index_vector_dim,
2845                                    ConversionPatternRewriter &rewriter) {
2846   if (index_vector_dim == indices_type.getRank()) {
2847     llvm::SmallVector<int64_t, 4> new_start_indices_shape(
2848         indices_type.getShape().begin(), indices_type.getShape().end());
2849     new_start_indices_shape.push_back(1);
2850     indices_type = RankedTensorType::get(new_start_indices_shape,
2851                                          indices_type.getElementType());
2852     indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(),
2853                                                indices_type, indices);
2854   } else if (index_vector_dim != indices_type.getRank() - 1) {
2855     // If index_vector_dim isn't the last dimension in indices then it isn't
2856     // supported yet.
2857     // TODO(tberghammer): Transpose indices to support this usecase.
2858     return rewriter.notifyMatchFailure(
2859         parent_op,
2860         "index vector dim isn't the last dimension in start indices");
2861   }
2862   return success();
2863 }
2864 
2865 class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
2866  public:
2867   using OpConversionPattern::OpConversionPattern;
2868 
2869   // Helper params for representing the transpose params for the "canonicalized"
2870   // output to the real output.
2871   struct TransposeParams {
2872     std::vector<int64_t> permutation;
2873     // The following are the "canonicalized" output shape with offset dims.
2874     std::vector<int64_t> canonicalized_output_shape;
2875     std::vector<int64_t> canonicalized_offset_dims;
2876   };
2877 
matchAndRewrite(mhlo::GatherOp gather_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const2878   LogicalResult matchAndRewrite(
2879       mhlo::GatherOp gather_op, OpAdaptor adaptor,
2880       ConversionPatternRewriter &rewriter) const final {
2881     Value operand = gather_op.operand();
2882     Value start_indices = gather_op.start_indices();
2883 
2884     // Can only convert with static shaped gather.
2885     ShapedType operand_type = operand.getType().cast<ShapedType>();
2886     ShapedType start_indices_type = start_indices.getType().cast<ShapedType>();
2887     ShapedType result_type = gather_op.getResult().getType().cast<ShapedType>();
2888     if (!operand_type.hasStaticShape() ||
2889         !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) {
2890       return failure();
2891     }
2892 
2893     // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
2894     int64_t index_vector_dim =
2895         gather_op.dimension_numbers().getIndexVectorDim();
2896     if (failed(NormalizeIndexVector(gather_op, start_indices,
2897                                     start_indices_type, index_vector_dim,
2898                                     rewriter))) {
2899       return failure();
2900     }
2901 
2902     // Verify that start_index_map and collapsed_slice_dims contains the same
2903     // values.
2904     auto start_index_map = gather_op.dimension_numbers().getStartIndexMap();
2905     auto collapsed_slice_dims =
2906         gather_op.dimension_numbers().getCollapsedSliceDims();
2907     if (start_index_map.size() != collapsed_slice_dims.size()) {
2908       return rewriter.notifyMatchFailure(
2909           gather_op,
2910           "different size for start index map and collapsed slice dims");
2911     }
2912     for (auto c : collapsed_slice_dims) {
2913       if (llvm::count(start_index_map, c) == 0) {
2914         return rewriter.notifyMatchFailure(
2915             gather_op, "collapsed slice dim isn't present in start index map");
2916       }
2917     }
2918 
2919     // Verify that slice_sizes is 1 for the indexed dimensions and the full
2920     // shape for the rest of the dimensions.
2921     auto slice_sizes = gather_op.slice_sizes();
2922     int64_t index = 0;
2923     for (int64_t s : slice_sizes.getValues<int64_t>()) {
2924       if (llvm::count(start_index_map, index)) {
2925         if (s != 1) {
2926           return rewriter.notifyMatchFailure(gather_op,
2927                                              "unsupported slice sizes");
2928         }
2929       } else {
2930         if (s != operand_type.getShape()[index]) {
2931           return rewriter.notifyMatchFailure(gather_op,
2932                                              "unsupported slice sizes");
2933         }
2934       }
2935       ++index;
2936     }
2937 
2938     // Verify that offset_dims are the tailing dimensions in the output tensor.
2939     auto offset_dims = gather_op.dimension_numbers().getOffsetDims();
2940     SmallVector<int64_t, 4> offset_dims_vector(offset_dims.begin(),
2941                                                offset_dims.end());
2942     const TransposeParams &transpose_params =
2943         CanonicalizeOffset(/*result_type=*/result_type,
2944                            /*original_offset_dims=*/offset_dims_vector);
2945 
2946     int64_t offset = start_indices_type.getRank() - 1;
2947     for (int64_t o : transpose_params.canonicalized_offset_dims) {
2948       if (o != offset) {
2949         return rewriter.notifyMatchFailure(gather_op,
2950                                            "unsupported offset dims");
2951       }
2952       ++offset;
2953     }
2954 
2955     // Transpose the operand to handle non-iota start index map.
2956     llvm::SmallVector<int64_t, 4> transpose_dimensions;
2957     llvm::SmallVector<int64_t, 4> transpose_shape;
2958     for (auto s : start_index_map) {
2959       transpose_dimensions.push_back(s);
2960       transpose_shape.push_back(operand_type.getShape()[s]);
2961     }
2962     for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
2963       if (llvm::count(start_index_map, i) == 0) {
2964         transpose_dimensions.push_back(i);
2965         transpose_shape.push_back(operand_type.getShape()[i]);
2966       }
2967     }
2968     operand_type =
2969         RankedTensorType::get(transpose_shape, operand_type.getElementType());
2970     operand = rewriter.create<mhlo::TransposeOp>(
2971         gather_op.getLoc(), operand_type, operand,
2972         rewriter.getI64TensorAttr(transpose_dimensions));
2973 
2974     // Check whether we need to append a transpose op after the gather nd.
2975     bool need_transpose_after = false;
2976     for (int i = 0; i < transpose_params.permutation.size(); ++i) {
2977       if (i != transpose_params.permutation[i]) {
2978         need_transpose_after = true;
2979         break;
2980       }
2981     }
2982 
2983     auto tf_gather_nd_result_type =
2984         RankedTensorType::get(transpose_params.canonicalized_output_shape,
2985                               result_type.getElementType());
2986     auto tf_gather_nd_op = rewriter.create<TF::GatherNdOp>(
2987         gather_op->getLoc(), tf_gather_nd_result_type, operand, start_indices);
2988     if (!need_transpose_after) {
2989       rewriter.replaceOp(gather_op, tf_gather_nd_op->getOpResults());
2990       return success();
2991     }
2992 
2993     // Insert the transpose op after the gather_nd.
2994     rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
2995         gather_op, result_type, tf_gather_nd_op,
2996         rewriter.getI64TensorAttr(transpose_params.permutation));
2997 
2998     return success();
2999   }
3000 
3001  private:
3002   // Canonicalize the offset dims to make sure the offset dims are the trailing
3003   // dimensions of the output tensor.
3004   // We will also return the permutation for (the transpose op).
3005   // However, it's not guaranteed the canonicalized offset dims can make it
3006   // always legalizable to tf.
CanonicalizeOffset(ShapedType result_type,ArrayRef<int64_t> original_offset_dims) const3007   TransposeParams CanonicalizeOffset(
3008       ShapedType result_type, ArrayRef<int64_t> original_offset_dims) const {
3009     TransposeParams transpose_params;
3010     int output_rank = result_type.getRank();
3011     // The canonicalized offset should be the trailing of the output rank.
3012     for (int start = output_rank - original_offset_dims.size();
3013          start < output_rank; ++start) {
3014       transpose_params.canonicalized_offset_dims.push_back(start);
3015     }
3016 
3017     // For those dims NOT inside the original_offset_dims are considered "batch
3018     // dims".
3019     std::vector<int64_t> batch_dims;
3020     // Offset dims are guaranteed to be sorted.
3021     int offset_index = 0;
3022     for (int64_t i = 0; i < output_rank; ++i) {
3023       if (offset_index >= original_offset_dims.size() ||
3024           original_offset_dims[offset_index] != i) {
3025         batch_dims.push_back(i);
3026       } else {
3027         ++offset_index;
3028       }
3029     }
3030 
3031     // Populate the trnaspose permutation params from a "canonicalized" output
3032     // to the real output.
3033     // The canonicalized layout would be batch_dims followed by sliced_dims.
3034     // The current layout is essentially a transpose after the canonicalized
3035     // layout.
3036     // Take the following as an example:
3037     // If we have the:
3038     // original_offset_dims like [1, 2, 4]
3039     // batch_dims like [0, 3]
3040     // It's like performing transpose on a "canonicalized"
3041     // [batch_dims, sliced_dims]: [B1, B2, O1, O2, O3]
3042     // into the current layout: [B1, O1, O2, B2, O3]
3043     // where the permutation is [0, 2, 3, 1, 4]
3044     int batch_idx = 0;
3045     int offset_idx = 0;
3046     int batch_dim_size = batch_dims.size();
3047     for (int i = 0; i < output_rank; ++i) {
3048       if (batch_idx >= batch_dims.size()) {
3049         transpose_params.permutation.push_back(batch_dim_size + offset_idx);
3050         ++offset_idx;
3051       } else if (offset_idx < original_offset_dims.size() &&
3052                  original_offset_dims[offset_idx] < batch_dims[batch_idx]) {
3053         transpose_params.permutation.push_back(batch_dim_size + offset_idx);
3054         ++offset_idx;
3055       } else {
3056         transpose_params.permutation.push_back(batch_idx++);
3057       }
3058     }
3059 
3060     // Finally, let's find out what are the "canonicalized" output shape looks
3061     // like.
3062     for (auto dim : batch_dims) {
3063       transpose_params.canonicalized_output_shape.push_back(
3064           result_type.getDimSize(dim));
3065     }
3066     for (auto dim : original_offset_dims) {
3067       transpose_params.canonicalized_output_shape.push_back(
3068           result_type.getDimSize(dim));
3069     }
3070     return transpose_params;
3071   }
3072 };
3073 
3074 class ConvertWhileOp : public OpConversionPattern<mhlo::WhileOp> {
3075  public:
3076   using OpConversionPattern::OpConversionPattern;
3077 
matchAndRewrite(mhlo::WhileOp while_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3078   LogicalResult matchAndRewrite(
3079       mhlo::WhileOp while_op, OpAdaptor adaptor,
3080       ConversionPatternRewriter &rewriter) const final {
3081     // HLO WhileOp should have two regions: cond and body.
3082     if (while_op->getNumRegions() != 2) return failure();
3083 
3084     // This rule doesn't support mhlo::WhileOp with tuple inputs.
3085     for (auto type : while_op->getOperandTypes()) {
3086       if (type.isa<TupleType>()) return failure();
3087     }
3088 
3089     // Creates a TF::WhileRegionOp to replace the mhlo::WhileOp. HLO WhileOp
3090     // currently doesn't support stateless and shape invariant, so these
3091     // parameters are set to the default values.
3092     auto new_while = rewriter.create<TF::WhileRegionOp>(
3093         while_op.getLoc(), while_op->getResultTypes(), while_op->getOperands(),
3094         /*parallel_iterations=*/10,
3095         /*is_stateless=*/false, /*shape_invariant=*/false);
3096     new_while.cond().takeBody(while_op.cond());
3097     new_while.body().takeBody(while_op.body());
3098     ReplaceReturnOp(new_while.cond(), rewriter);
3099     ReplaceReturnOp(new_while.body(), rewriter);
3100     rewriter.replaceOp(while_op, new_while.getResults());
3101     return success();
3102   }
3103 };
3104 
3105 class ConvertIfOp : public OpConversionPattern<mhlo::IfOp> {
3106  public:
3107   using OpConversionPattern::OpConversionPattern;
3108 
matchAndRewrite(mhlo::IfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3109   LogicalResult matchAndRewrite(
3110       mhlo::IfOp op, OpAdaptor adaptor,
3111       ConversionPatternRewriter &rewriter) const final {
3112     // HLO IfOp currently doesn't support stateless
3113     auto new_op = rewriter.create<TF::IfRegionOp>(
3114         op.getLoc(), op->getResultTypes(), op.pred(),
3115         /*is_stateless=*/false, /*_then_func_name=*/nullptr,
3116         /*_else_func_name=*/nullptr);
3117     new_op.then_branch().takeBody(op.true_branch());
3118     new_op.else_branch().takeBody(op.false_branch());
3119     ReplaceReturnOp(new_op.then_branch(), rewriter);
3120     ReplaceReturnOp(new_op.else_branch(), rewriter);
3121     rewriter.replaceOp(op, new_op.getResults());
3122     return success();
3123   }
3124 };
3125 
3126 template <typename BinaryOp, typename TfOp>
3127 class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
3128  public:
3129   using OpConversionPattern::OpConversionPattern;
3130 
matchAndRewrite(mhlo::ScatterOp scatter_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3131   LogicalResult matchAndRewrite(
3132       mhlo::ScatterOp scatter_op, OpAdaptor adaptor,
3133       ConversionPatternRewriter &rewriter) const final {
3134     OperandRange operands = scatter_op.operands();
3135     Value indices = scatter_op.scatter_indices();
3136     OperandRange updates = scatter_op.updates();
3137     if (operands.size() != 1 || updates.size() != 1) return failure();
3138 
3139     ShapedType operand_type = operands[0].getType().cast<ShapedType>();
3140     ShapedType indices_type = indices.getType().cast<ShapedType>();
3141     ShapedType updates_type = updates[0].getType().cast<ShapedType>();
3142 
3143     Value new_updates = updates[0];
3144 
3145     // Can only convert with static shaped scatter.
3146     if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
3147         !updates_type.hasStaticShape()) {
3148       return failure();
3149     }
3150 
3151     // Match the scatter computation against computations supported by TF.
3152     if (failed(MatchBinaryReduceFunction<BinaryOp>(
3153             scatter_op.update_computation()))) {
3154       return failure();
3155     }
3156 
3157     auto scatter_dimension_numbers = scatter_op.scatter_dimension_numbers();
3158 
3159     // Normalize indices so index_vector_dim == indices.rank() - 1.
3160     int64_t index_vector_dim = scatter_dimension_numbers.getIndexVectorDim();
3161     if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
3162                                     index_vector_dim, rewriter))) {
3163       return failure();
3164     }
3165 
3166     // Transform updates so that update window dims are the trailing dimensions
3167     // in the update tensor.
3168     auto update_window_dims = scatter_dimension_numbers.getUpdateWindowDims();
3169     if (failed(CanonicalizeScatterUpdates(scatter_op, update_window_dims,
3170                                           indices, indices_type, new_updates,
3171                                           updates_type, rewriter))) {
3172       return failure();
3173     }
3174 
3175     auto inserted_window_dims =
3176         scatter_dimension_numbers.getInsertedWindowDims();
3177     auto scatter_dims_to_operand_dims =
3178         scatter_dimension_numbers.getScatterDimsToOperandDims();
3179 
3180     if (IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) &&
3181         IsIotaAttr(scatter_dims_to_operand_dims,
3182                    indices_type.getShape().back())) {
3183       rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
3184                                         scatter_op.getResult(0).getType(),
3185                                         operands[0], indices, new_updates);
3186       return success();
3187     }
3188     // Insert tranposes to support scatter operations generated from
3189     // numpy-like slice operations:
3190     //   nd_array[:, [i,j]] = [i_values, j_values]
3191     //
3192     if (scatter_dims_to_operand_dims != inserted_window_dims) {
3193       // Support only dimension numbers generated by numpy-like slice
3194       // operations.
3195       return rewriter.notifyMatchFailure(
3196           scatter_op, "unsupported scatter_dims_to_operand_dims");
3197     }
3198 
3199     // Transpose the operand and so that the trailing dimensions of the
3200     // operand are being updated. Then apply a tf.scatter op and transpose
3201     // back the result to get the same shape as the original operand.
3202 
3203     SmallVector<int64_t, 4> permutation_array;
3204     for (int64_t i = 0; i < scatter_dims_to_operand_dims.size(); ++i) {
3205       permutation_array.push_back(scatter_dims_to_operand_dims[i]);
3206     }
3207     for (int64_t i = 0; i < operand_type.getRank(); ++i) {
3208       if (!llvm::is_contained(scatter_dims_to_operand_dims, i)) {
3209         permutation_array.push_back(i);
3210       }
3211     }
3212     auto permutation_and_shape = GetPermutationAndTransposedShape(
3213         permutation_array, operand_type, rewriter);
3214 
3215     Location loc = scatter_op.getLoc();
3216     auto transposed_operand = rewriter.create<mhlo::TransposeOp>(
3217         loc, permutation_and_shape.shape, operands[0],
3218         permutation_and_shape.permutation);
3219 
3220     // Apply TF scatter to update the trailing dimensions of the
3221     // transposed operand.
3222     auto tf_scatter_op =
3223         rewriter.create<TfOp>(loc, permutation_and_shape.shape,
3224                               transposed_operand, indices, new_updates);
3225 
3226     // Reverse the earlier transpose.
3227     auto inverse_permutation =
3228         GetInversePermutation(permutation_array, rewriter);
3229     rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
3230         scatter_op, scatter_op.getResult(0).getType(), tf_scatter_op,
3231         inverse_permutation);
3232 
3233     return success();
3234   }
3235 };
3236 using ConvertScatterAddOp =
3237     ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
3238 using ConvertScatterMaxOp =
3239     ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
3240 using ConvertScatterMinOp =
3241     ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
3242 using ConvertScatterSubOp =
3243     ConvertScatterOp<mhlo::SubtractOp, TF::TensorScatterSubOp>;
3244 using ConvertScatterUpdateOp =
3245     ConvertScatterOp<void, TF::TensorScatterUpdateOp>;
3246 
3247 // Converts mhlo.pad to tf.PadV2
ConvertPadOp(PatternRewriter & rewriter,Operation * old_op)3248 Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
3249   auto pad_op = cast<mhlo::PadOp>(old_op);
3250   mlir::Location loc = pad_op.getLoc();
3251 
3252   llvm::SmallVector<APInt, 8> padding;
3253   for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
3254                           pad_op.edge_padding_high().getValues<APInt>())) {
3255     padding.push_back(std::get<0>(p));
3256     padding.push_back(std::get<1>(p));
3257   }
3258   auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
3259                                          rewriter.getI64Type());
3260   auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
3261   auto padding_op =
3262       rewriter.create<arith::ConstantOp>(loc, attr_type, padding_attr);
3263   return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
3264                                   padding_op, pad_op.padding_value());
3265 }
3266 
3267 // Returns true if broadcast_dimensions obey Tensorflow convention, as in new
3268 // dimensions are added as prefix.
IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,Value output)3269 bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
3270                         Value output) {
3271   // broadcast_dimensions is an increasing list by definition, thus it suffices
3272   // to check the first element.
3273   int64_t input_rank = broadcast_dimensions.getNumElements();
3274   int64_t output_rank = output.getType().cast<ShapedType>().getRank();
3275   return input_rank == 0 ||
3276          (broadcast_dimensions.getValues<APInt>()[0].getSExtValue() ==
3277           output_rank - input_rank);
3278 }
3279 
3280 // Returns the intermediate shape that input tensor should be reshaped to during
3281 // legalization of BroadcastInDimOp.
ExpandedShape(PatternRewriter & rewriter,Value input,DenseIntElementsAttr broadcast_dimensions,Value output)3282 arith::ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input,
3283                                 DenseIntElementsAttr broadcast_dimensions,
3284                                 Value output) {
3285   // Initialize expanded shape with output rank and dimensions of 1.
3286   SmallVector<Attribute, 4> expanded_shape(
3287       output.getType().cast<ShapedType>().getRank(),
3288       /*Value=*/rewriter.getI64IntegerAttr(1));
3289 
3290   // Set dimension sizes specified by broadcast_dimensions.
3291   ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
3292   for (auto x : llvm::enumerate(broadcast_dimensions)) {
3293     expanded_shape[x.value().getSExtValue()] =
3294         rewriter.getI64IntegerAttr(input_shape[x.index()]);
3295   }
3296 
3297   // Create the expanded type wrapped in a arith::ConstantOp.
3298   auto attr_type =
3299       RankedTensorType::get({static_cast<int64_t>(expanded_shape.size())},
3300                             rewriter.getIntegerType(64));
3301   auto attr = DenseElementsAttr::get(attr_type, expanded_shape);
3302   return rewriter.create<arith::ConstantOp>(output.getLoc(), attr_type, attr);
3303 }
3304 
3305 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
3306 
3307 /// Performs the lowering to XLA dialect.
runOnOperation()3308 void LegalizeHloToTf::runOnOperation() {
3309   MLIRContext &context = getContext();
3310 
3311   // Add legalization patterns to the list.
3312   RewritePatternSet patterns(&getContext());
3313   PopulateLegalizeHloToTfPatterns(&patterns, &context);
3314 
3315   ConversionTarget target(context);
3316   target.addLegalDialect<TensorFlowDialect>();
3317   target.addLegalOp<func::CallOp, func::ConstantOp, arith::ConstantOp>();
3318   target.addLegalOp<mhlo::TupleOp>();
3319   if (failed(applyPartialConversion(getOperation(), target,
3320                                     std::move(patterns)))) {
3321     getOperation().emitError("mhlo to TF legalization failed.");
3322     signalPassFailure();
3323   }
3324 }
3325 
3326 }  // end namespace
3327 
PopulateLegalizeHloToTfPatterns(RewritePatternSet * patterns,MLIRContext * context)3328 void PopulateLegalizeHloToTfPatterns(RewritePatternSet *patterns,
3329                                      MLIRContext *context) {
3330   patterns->add<
3331       ConvertAvgPoolOp, Convert2DConvOp, Convert1DConvOp,
3332       ConvertNonTrivialConvOp, ConvertDynamicSliceOp,
3333       ConvertDynamicUpdateSliceOp, ConvertGatherOp, ConvertIfOp,
3334       ConvertMaxPoolOp, ConvertScatterAddOp, ConvertScatterMaxOp,
3335       ConvertScatterMinOp, ConvertScatterSubOp, ConvertScatterUpdateOp,
3336       ConvertSliceOp, ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin,
3337       ConvertReduceOpToTfMax, ConvertReduceOpToTfMin, ConvertReduceOpToTfAll,
3338       ConvertReduceOpToTfAny, ConvertReduceOpToTfSum, ConvertSortToTfTopk,
3339       ConvertIotaOpToTfRange, ConvertWhileOp, ConvertLoweredCumSumOp,
3340       ConvertLoweredCumProdOp>(context);
3341   populateWithGenerated(*patterns);
3342 }
3343 
CreateLegalizeHloToTfPass()3344 std::unique_ptr<OperationPass<func::FuncOp>> CreateLegalizeHloToTfPass() {
3345   return std::make_unique<LegalizeHloToTf>();
3346 }
3347 
3348 }  // end namespace TF
3349 }  // end namespace mlir
3350