// RUN: xla-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" -split-input-file %s | FILECHECK_OPTS="" FileCheck %s
// RUN: xla-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all
// This test runs twice:
//   1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled,
//      since verifying the emitted chlo ops produces more useful tests.
//   2. With chlo legalization enabled, verifying diagnostics to pick up any
//      issues with the full lowering (this can catch some broadcasting corner
//      cases which are only reported as warnings).

//===----------------------------------------------------------------------===//
// BatchNorm op legalizations.
//===----------------------------------------------------------------------===//

// -----

// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
// code), so only do a couple of basic checks.

// CHECK-LABEL: fusedBatchNormV2_noTraining
func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV2_training
func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
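  // Note: the constant above is the N/(N-1) variance bias-correction factor
  // (here N = 8*8*8 = 512 reduced elements): mhlo.batch_norm_training returns
  // the biased batch variance, while TF reports the unbiased estimate.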
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_noTraining
func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
// CHECK-SAME:  ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>)
func.func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) {
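  // The computation is done in f32 and converted back to bf16. In inference
  // mode batch_mean/batch_variance simply forward the incoming mean and
  // variance, and the unranked reserve-space result is satisfied with a
  // dummy zero-element constant cast to tensor<*xf32>: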
  // CHECK: [[CONVERT_X:%.*]] = mhlo.convert([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
  // CHECK: [[Y_CONVERT:%.*]] = mhlo.convert([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32>
  // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
  func.return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_training
func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance
func.func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> {
  // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: return %[[VAR]]
  func.return %0#4 : tensor<8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor
func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
  // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
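  // With exponential_avg_factor = 0.8 the running statistics are blended as
  // 0.2 * old + 0.8 * batch, after rescaling the batch variance by the
  // N/(N-1) = 512/511 ~= 1.00195694 bias-correction factor: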
  // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694>
  // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[VAR]], %[[FACTOR]]

  // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988>
  // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01>

  // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3
  // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[MEAN]]
  // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]

  // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4
  // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
  // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]

  // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[MEAN]], %[[VAR]]
  func.return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision
func.func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK: mhlo.convert(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: mhlo.convert({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  func.return %0#0 : tensor<8x8x8x8xbf16>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_NCHW
func.func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
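  // For NCHW the feature dimension is dimension 1 (and dimension 4 for the
  // NDHWC case below), which becomes the feature_index attribute: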
  // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_NDHWC
func.func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>) {
  // CHECK: feature_index = 4 : i64
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NDHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported
func.func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  func.return %0#0 : tensor<?x?x?x?xf32>
}

// -----

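// The training lowering materializes constants that depend on the static
// reduction size (e.g. the N/(N-1) variance correction), so dynamic
// batch/spatial dimensions presumably cannot be legalized; the ops in the
// two cases below are left untouched.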
// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1
func.func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  func.return %0#0 : tensor<?x?x?x?xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2
func.func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor<?x6x?x?xf32>, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor<?x6x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>)
  func.return %0#0 : tensor<?x6x?x?xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGrad_noTraining
func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
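  // With is_training = false the gradient is expanded inline rather than
  // lowered to mhlo.batch_norm_grad:
  //   scratch1        = rsqrt(variance + epsilon)
  //   scratch2        = sum(grad * (x - mean))
  //   x_backprop      = grad * (scale * scratch1)
  //   scale_backprop  = scratch1 * scratch2
  //   offset_backprop = sum(grad)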
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGrad_Training
func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV2_noTraining
func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV2_Training
func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision
func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = mhlo.convert({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xbf16>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision
func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert(%[[grad_operand]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xbf16>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV3_noTraining
func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV3_Training
func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32>
  // CHECK: return %[[x_backprop]]
  // CHECK-SAME: tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>)
  func.return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision
func.func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = mhlo.convert({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xbf16>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision
func.func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert(%[[grad_operand]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xbf16>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW
func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 2, 3] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 2, 3] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

// -----

// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW
func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  func.return %0#0 : tensor<8x8x8x8xf32>
}

//===----------------------------------------------------------------------===//
// Bias op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @biasAdd_default
func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
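  // BiasAdd broadcasts the bias along the feature dimension: dimension 3 for
  // NHWC (the default data format) and dimension 1 for NCHW below.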
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  func.return %0 : tensor<1x32x10x32xi32>
}

// -----

// CHECK-LABEL: func @biasAdd_NHWC
func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  func.return %0 : tensor<1x32x10x32xi32>
}

// -----

// CHECK-LABEL: func @biasAdd_NCHW
func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  func.return %0 : tensor<1x32x10x32xi32>
}

// -----

// CHECK-LABEL: func @biasAdd_dynamic
func.func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
  func.return %0 : tensor<?x?x?x?xi32>
}

// -----

// CHECK-LABEL: func @biasAdd_partial_dynamic
func.func @biasAdd_partial_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<512xi32>) -> tensor<?x?x?x512xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  // CHECK: %[[CAST:.+]] = tensor.cast %[[RESULT]] : tensor<?x?x?x?xi32> to tensor<?x?x?x512xi32>
  // CHECK: return %[[CAST]] : tensor<?x?x?x512xi32>
  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<?x?x?x?xi32>, tensor<512xi32>) -> tensor<?x?x?x512xi32>
  func.return %0 : tensor<?x?x?x512xi32>
}


//===----------------------------------------------------------------------===//
// ClipByValue
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @clip
func.func @clip(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
  // CHECK: [[VAL:%.+]] = mhlo.clamp %arg1, %arg0, %arg2

  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  // CHECK: return [[VAL]]
  func.return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: @clip_dynamic
func.func @clip_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp %arg1, %arg0, %arg2
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: @clip_static_broadcast
func.func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<5xf32> {
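  // Scalar clip bounds are broadcast to the operand shape before clamping: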
  // CHECK-DAG: [[SHPIDX:%.+]] = mhlo.constant dense<5>
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]]
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>

  // CHECK: return [[CLAMP]]
  func.return %0 : tensor<5xf32>
}


// CHECK-LABEL: @clip_dynamic_broadcast
func.func @clip_dynamic_broadcast(%arg0 : tensor<?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?xf32> {
  // CHECK: [[SHP:%.+]] = shape.shape_of %arg0
  // CHECK: [[SHPIDX:%.+]] = arith.index_cast [[SHP]] : tensor<1xindex> to tensor<1xi32>
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]]
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  func.return %0 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// DiagPart
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32>
func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> {
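  // DiagPart reshapes the 4x3x4x3 input into a square 12x12 matrix, masks
  // out everything but the diagonal by comparing two iotas for equality,
  // sum-reduces along dimension 0 to extract the diagonal, and reshapes the
  // 12-element result back to 4x3: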
  // CHECK: %[[RS:.*]] = mhlo.reshape %[[ARG]] : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[COMP:.*]] = mhlo.compare EQ, %[[IOTA0]], %[[IOTA1]], NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1>
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor<f32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[RED:.*]] = mhlo.reduce(%[[SEL]] init: %[[ZERO]]) applies mhlo.add across dimensions = [0] : (tensor<12x12xf32>, tensor<f32>) -> tensor<12xf32>
  // CHECK-DAG:  %[[RES:.*]] = mhlo.reshape %[[RED]] : (tensor<12xf32>) -> tensor<4x3xf32>
  // CHECK-DAG:  return %[[RES]] : tensor<4x3xf32>
  %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32>
  func.return %0: tensor<4x3xf32>
}

//===----------------------------------------------------------------------===//
// MatrixDiagPart
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @matrix_diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32>
func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
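  // k = [-10, 11] requests 22 diagonals (11 - (-10) + 1). Each diagonal is
  // padded to the row length of 128 with the padding value 42, and all
  // diagonals are fetched with a single mhlo.gather: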
  // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32>
  // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32>
  // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor<i32>
  // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor<i32>
  // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V19:.*]] = mhlo.negate %[[V18]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V21:.*]] = mhlo.add %[[V13]], %[[V20]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V22:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V23:.*]] = mhlo.subtract %[[V15]], %[[V22]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V24:.*]] = mhlo.minimum %[[V21]], %[[V23]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V25:.*]] = chlo.broadcast_compare %[[V18]], %[[V5]] {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V26:.*]] = mhlo.subtract %[[V17]], %[[V24]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V27:.*]] = "mhlo.select"(%[[V25]], %[[V26]], %[[V5]]) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V28:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V29:.*]] = mhlo.subtract %[[V28]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V30:.*]] = mhlo.maximum %[[V19]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V31:.*]] = mhlo.subtract %[[V30]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V32:.*]] = mhlo.add %[[V3]], %[[V29]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V33:.*]] = mhlo.add %[[V3]], %[[V31]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V34:.*]] = chlo.broadcast_compare %[[V32]], %[[V5]] {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V35:.*]] = chlo.broadcast_compare %[[V32]], %[[V15]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V36:.*]] = mhlo.and %[[V34]], %[[V35]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V37:.*]] = chlo.broadcast_compare %[[V33]], %[[V5]] {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V38:.*]] = chlo.broadcast_compare %[[V33]], %[[V13]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V41:.*]] = mhlo.reshape %[[V40]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1>
  // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32>
  // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = #mhlo.gather<offset_dims = [0], collapsed_slice_dims = [1, 2], start_index_map = [1, 2]>, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1>
  // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[V46:.*]] = "mhlo.select"(%[[V44]], %[[V43]], %[[V45]]) : (tensor<7x22x128xi1>, tensor<7x22x128xi32>, tensor<7x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK: return %[[V46]] : tensor<7x22x128xi32>
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  func.return %2: tensor<7x22x128xi32>
}

// -----

// CHECK-LABEL: func @matrix_diag_part_single_diagonal
func.func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<0> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x128xi32>
  // CHECK: %[[result:.*]] = mhlo.reshape {{.*}} : (tensor<7x1x128xi32>) -> tensor<7x128xi32>
  // CHECK: return %[[result]] : tensor<7x128xi32>
  func.return %2: tensor<7x128xi32>
}

// -----

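// The `align` attribute is "<superdiagonal alignment>_<subdiagonal alignment>".
// The lowering selects between right- and left-aligned per-diagonal offsets:
// RIGHT_LEFT compares d >= 0 (GE), LEFT_RIGHT compares d <= 0 (LE), and the
// uniform modes fold the select condition to a broadcast constant
// (false for LEFT_LEFT, true for RIGHT_RIGHT), as the four tests below check.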
654// CHECK-LABEL: func @matrix_diag_part_align_ll
655func.func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
656  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
657  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
658  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
659      T = i32, align = "LEFT_LEFT"
660  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
661  // CHECK: %[[false:.*]] = mhlo.constant dense<false> : tensor<i1>
662  // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
663  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_false]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
664  func.return %2: tensor<7x22x128xi32>
665}
666
667// -----
668
669// CHECK-LABEL: func @matrix_diag_part_align_lr
670func.func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
671  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
672  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
673  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
674      T = i32, align = "LEFT_RIGHT"
675  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
676  // CHECK: %[[le:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
677  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[le]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
678  func.return %2: tensor<7x22x128xi32>
679}
680
681// -----
682
683// CHECK-LABEL: func @matrix_diag_part_align_rl
684func.func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
685  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
686  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
687  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
688      T = i32, align = "RIGHT_LEFT"
689  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
690  // CHECK: %[[ge:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
691  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[ge]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
692  func.return %2: tensor<7x22x128xi32>
693}
694
695// -----
696
697// CHECK-LABEL: func @matrix_diag_part_align_rr
698func.func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
699  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
700  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
701  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
702      T = i32, align = "RIGHT_RIGHT"
703  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
704  // CHECK: %[[true:.*]] = mhlo.constant dense<true> : tensor<i1>
705  // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
706  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_true]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
707  func.return %2: tensor<7x22x128xi32>
708}
709
710// -----
711
712// CHECK-LABEL: func @matrix_diag_part_align_7d
713// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32>
714func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> {
715  %0 = mhlo.constant dense<-1.> : tensor<f32>  // padding value
716  %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32>  // k
717  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
718      T = f32, align = "LEFT_RIGHT"
719  } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor<f32>) -> tensor<3x5x7x9x11x4x10xf32>
720  func.return %2: tensor<3x5x7x9x11x4x10xf32>
721}
722
723//===----------------------------------------------------------------------===//
724// Erf
725//===----------------------------------------------------------------------===//
726
727// -----
728
729// CHECK-LABEL: func @erf
730func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
731  // CHECK: chlo.erf %arg0 : tensor<2x3xf32>
732  %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
733  func.return %0 : tensor<2x3xf32>
734}
735
736//===----------------------------------------------------------------------===//
737// Erfc
738//===----------------------------------------------------------------------===//
739
740// -----
741
742// CHECK-LABEL: func @erfc
743func.func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
744  // CHECK: chlo.erfc %arg0 : tensor<2x3xf32>
745  %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
746  func.return %0 : tensor<2x3xf32>
747}
748
749//===----------------------------------------------------------------------===//
750// Einsum.
751//===----------------------------------------------------------------------===//
752
753// -----
754
755// CHECK-LABEL: func @einsum
756func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
757  // CHECK:  mhlo.einsum
758  %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
759  func.return %0: tensor<2x4xf32>
760}
761
762// -----
763
764// CHECK-LABEL: func @unary_einsum
765func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
766  // CHECK:  mhlo.unary_einsum
767  %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
768  func.return %0: tensor<2x2xf32>
769}
770
771//===----------------------------------------------------------------------===//
772// FloorDiv and FloorMod.
773//===----------------------------------------------------------------------===//
774
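// For signed integers, tf.FloorDiv lowers to a truncating divide whose result
// is decremented by one when the remainder is nonzero and the operand signs
// differ, i.e. floor_div(x, y) = select(x % y != 0 && (x < 0) != (y < 0),
// x / y - 1, x / y), as the checks below verify.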
775// -----
776
777// CHECK-LABEL: func @floordiv_broadcast_i32
778func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
779  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
780  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
781  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #mhlo<comparison_direction NE>}
782  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
783  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>}
784  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
785  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>}
786  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>}
787  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
788  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
789  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
790  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
791  // CHECK: return [[SELECT]]
792  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
793  func.return %0: tensor<2x3xi32>
794}
795
796// -----
797
798// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
799func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
800  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
801  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]]
802  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>}
803  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
804  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>}
805  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
806  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>}
807  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>}
808  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
809  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
810  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
811  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
812  // CHECK: return [[SELECT]]
813  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
814  func.return %0: tensor<2x3xi32>
815}
816
817// -----
818
819// CHECK-LABEL: func @floordiv_f32
820func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
821  // CHECK-NEXT:  %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
822  // CHECK-NEXT:  %[[FLOOR:.*]] = mhlo.floor %[[DIV]]
823  // CHECK-NEXT:  return %[[FLOOR]] : tensor<2xf32>
824  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
825  func.return %0: tensor<2xf32>
826}
827
828// -----
829
830// CHECK-LABEL: func @floordiv_bf16
831func.func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
832  // CHECK-NEXT:  mhlo.convert
833  // CHECK-NEXT:  mhlo.convert
834  // CHECK-NEXT:  chlo.broadcast_divide
835  // CHECK-NEXT:  mhlo.floor
836  // CHECK-NEXT:  mhlo.convert
837  // CHECK-NEXT:  return
838  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
839  func.return %0: tensor<2xbf16>
840}
841
842// -----
843
844// CHECK-LABEL: func @floordiv_f16_broadcast
845func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
846  // CHECK-NEXT:  chlo.broadcast_divide
847  // CHECK-NEXT:  mhlo.floor
848  // CHECK-NEXT:  return
849  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
850  func.return %0: tensor<2x3xf16>
851}
852
853// -----
854
855// CHECK-LABEL: func @floordiv_dynamic
856func.func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
857  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
858  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
859  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #mhlo<comparison_direction NE>}
860  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
861  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>}
862  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
863  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>}
864  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>}
865  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
866  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
867  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
868  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
869  // CHECK: return [[SELECT]]
870  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
871  func.return %0: tensor<?x?xi32>
872}
873
874// -----
875
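// Unsigned operands are never negative, so FloorDiv reduces to a plain divide.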
876// CHECK-LABEL: func @floordiv_unsigned
877func.func @floordiv_unsigned(%arg0: tensor<?x?xui32>, %arg1: tensor<?xui32>) -> tensor<?x?xui32> {
878  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
879  // CHECK: return [[DIV]]
880  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xui32>, tensor<?xui32>) -> tensor<?x?xui32>
881  func.return %0: tensor<?x?xui32>
882}
883
884// -----
885
886// CHECK-LABEL: func @floordiv_unranked
887func.func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
888  // CHECK-NOT: tf.FloorDiv
889  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
890  func.return %0: tensor<*xf32>
891}
892
893// -----
894
895// CHECK-LABEL: func @floordiv_int
896func.func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
897  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
898  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
899  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
900  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> : tensor<i32>
901  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
902  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> : tensor<i32>
903  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
904  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #mhlo<comparison_direction NE>}
905  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
906  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> : tensor<i32>
907  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
908  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
909  // CHECK: return [[SELECT]]
910  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
911  func.return %0: tensor<*xi32>
912}
913
914// -----
915
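// tf.FloorMod lowers to a remainder that is shifted by the divisor when it is
// nonzero and its sign differs from the divisor's, i.e.
// floor_mod(x, y) = select(rem != 0 && (rem < 0) != (y < 0), rem + y, rem)
// with rem = x % y, as the checks below verify.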
916// CHECK-LABEL: func @floormod_broadcast_numerator
917func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
918  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
919  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
920  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>}
921  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
922  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>}
923  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>}
924  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #mhlo<comparison_direction NE>}
925  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
926  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
927  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
928  // CHECK-NEXT: return [[SELECT]]
929  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
930  func.return %0: tensor<2x3xi32>
931}
932
933// -----
934
935// CHECK-LABEL: func @floormod_broadcast_denominator
936func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
937  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
938  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
939  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>}
940  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
941  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>}
942  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>}
943  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>}
944  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
945  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
946  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
947  // CHECK-NEXT: return [[SELECT]]
948  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
949  func.return %0: tensor<2x3xi32>
950}
951
952// -----
953
954// CHECK-LABEL: func @floormod_unsigned_broadcast_denominator
955func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> {
956  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
957  // CHECK-NEXT: return [[REM]]
958  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32>
959  func.return %0: tensor<2x3xui32>
960}
961
962// -----
963
964// CHECK-LABEL: func @floormod_dynamic_broadcast_numerator
965func.func @floormod_dynamic_broadcast_numerator(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
966  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
967  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
968  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>}
969  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
970  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>}
971  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>}
972  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>}
973  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
974  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
975  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
976  // CHECK-NEXT: return [[SELECT]]
977  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
978  func.return %0: tensor<?x?xi32>
979}
980
981// -----
982
983// CHECK-LABEL: func @floormod_dynamic_broadcast_denominator
984func.func @floormod_dynamic_broadcast_denominator(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
985  // CHECK-NOT: tf.FloorMod
986  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
987  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
988  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xi1>
989  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
990  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xi1>
991  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xi1>
992  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?x?xi1>, tensor<?x?x?xi1>) -> tensor<?x?x?xi1>
993  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] : (tensor<?x?x?xi1>, tensor<?x?x?xi1>) -> tensor<?x?x?xi1>
994  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
995  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
996  // CHECK-NEXT: return [[SELECT]] : tensor<?x?x?xf32>
997  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
998  func.return %0: tensor<?x?x?xf32>
999}
1000
1001// -----
1002
1003// CHECK-LABEL: func @floormod_unranked
1004func.func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
1005  // CHECK-NOT: tf.FloorMod
1006  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
1007  func.return %0: tensor<*xi32>
1008}
1009
1010//===----------------------------------------------------------------------===//
1011// OnesLike
1012//===----------------------------------------------------------------------===//
1013
1014// -----
1015
1016// CHECK-LABEL: @ones_like
1017// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x?xf32>)
1018func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
1019  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 1.0{{.*}}}
1020  // CHECK: return %[[RES]]
1021  %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
1022  func.return %0 : tensor<2x?xf32>
1023}
1024
1025//===----------------------------------------------------------------------===//
1026// ZerosLike
1027//===----------------------------------------------------------------------===//
1028
1029// -----
1030
1031// CHECK-LABEL: @zeros_like
1032// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x?xf32>)
1033func.func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
1034  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 0.0{{.*}}}
1035  // CHECK: return %[[RES]]
1036  %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
1037  func.return %0 : tensor<2x?xf32>
1038}
1039
1040//===----------------------------------------------------------------------===//
1041// BroadcastTo.
1042//===----------------------------------------------------------------------===//
1043
1044// -----
1045
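// BroadcastTo lowers to mhlo.dynamic_broadcast_in_dim, with
// broadcast_dimensions mapping the operand's dimensions onto the trailing
// dimensions of the result (dimension 3 here for the 1-D operand).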
1046// CHECK-LABEL: func @broadcast_to
1047func.func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
1048  %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>
1049
1050  // CHECK: [[CST:%.+]] = mhlo.constant
1051  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, [[CST]])
1052  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
1053  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
1054  func.return %0 : tensor<16x16x16x16xf32>
1055}
1056
1057// -----
1058
1059// CHECK-LABEL: func @broadcast_scalar_to_unranked
1060// CHECK: (%[[ARG0:.*]]: tensor<f32>, %[[SHAPE:.*]]: tensor<?xi32>)
1061func.func @broadcast_scalar_to_unranked(%arg0: tensor<f32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
1062  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]])
1063  // CHECK-SAME: {broadcast_dimensions = dense<> : tensor<0xi64>}
1064  %0 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<f32>, tensor<?xi32>) -> tensor<*xf32>
1065  func.return %0 : tensor<*xf32>
1066}
1067
1068//===----------------------------------------------------------------------===//
1069// Complex op legalizations.
1070//===----------------------------------------------------------------------===//
1071
1072// -----
1073
1074// CHECK-LABEL: func @complex
1075func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
1076  // CHECK: chlo.broadcast_complex
1077  %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
1078  func.return %1 : tensor<3xcomplex<f32>>
1079}
1080
1081// -----
1082
1083// CHECK-LABEL: func @imag
1084func.func @imag(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
1085  // CHECK: mhlo.imag
1086  %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
1087  func.return %1 : tensor<3xf32>
1088}
1089
1090// -----
1091
1092// CHECK-LABEL: func @real
1093func.func @real(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
1094  // CHECK: mhlo.real
1095  %1 = "tf.Real"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
1096  func.return %1 : tensor<3xf32>
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// Concat op legalizations.
1101//===----------------------------------------------------------------------===//
1102
1103// -----
1104
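// ConcatV2 with a constant axis lowers to mhlo.concatenate (negative axes are
// normalized first); non-constant axes and unranked operands are left
// untouched.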
1105// CHECK-LABEL: func @concat_v2
1106func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
1107  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
1108  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
1109  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
1110  func.return %1 : tensor<6x3xf32>
1111}
1112
1113// -----
1114
1115// CHECK-LABEL: func @concat_v2_neg_axis
1116func.func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
1117  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
1118
1119  %axis = "tf.Const"() { value = dense<-2> : tensor<i64> } : () -> tensor<i64>
1120  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
1121  func.return %1 : tensor<6x3xf32>
1122}
1123
1124// -----
1125
1126// CHECK-LABEL: func @concat_v2_1d_axis
1127func.func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
1128  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
1129
1130  %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64>
1131  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32>
1132  func.return %1 : tensor<3x6xf32>
1133}
1134
1135// -----
1136
1137// CHECK-LABEL: func @concat_v2_non_const_axis
1138func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor<i64>) -> tensor<3x6xf32> {
1139  // CHECK: "tf.ConcatV2"
1140  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32>
1141  func.return %1 : tensor<3x6xf32>
1142}
1143
1144// -----
1145
1146// CHECK-LABEL: func @concat_v2_unranked
1147func.func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
1148  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
1149  // CHECK: "tf.ConcatV2"
1150  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<*xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
1151  func.return %1 : tensor<*xf32>
1152}
1153
1154//===----------------------------------------------------------------------===//
1155// Pad op legalizations.
1156//===----------------------------------------------------------------------===//
1157
1158// -----
1159
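// The columns of PadV2's paddings matrix become the edge_padding_low/high
// attributes of mhlo.pad (interior_padding is always zero); non-constant
// paddings take the mhlo.dynamic_pad path instead.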
1160// CHECK-LABEL: func @padv2_1D
1161func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -> tensor<6xf32> {
1162  %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
1163  // CHECK: "mhlo.pad"(%arg0, %arg1) {
1164  // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>,
1165  // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>,
1166  // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64>
1167  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<6xf32>
1168  func.return %1 : tensor<6xf32>
1169}
1170
1171// -----
1172
1173// CHECK-LABEL: func @padv2_2D
1174func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
1175  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64>
1176  // CHECK: "mhlo.pad"(%arg0, %arg1) {
1177  // CHECK-SAME:    edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
1178  // CHECK-SAME:    edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
1179  // CHECK-SAME:    interior_padding = dense<0> : tensor<2xi64>
1180  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<6x9xf32>
1181  func.return %1 : tensor<6x9xf32>
1182}
1183
1184// -----
1185
1186// CHECK-LABEL: func @padv2_i32_paddings
1187func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
1188  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
1189  // CHECK: "mhlo.pad"(%arg0, %arg1) {
1190  // CHECK-SAME:    edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
1191  // CHECK-SAME:    edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
1192  // CHECK-SAME:    interior_padding = dense<0> : tensor<2xi64>
1193  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
1194  func.return %1 : tensor<6x9xf32>
1195}
1196
1197// -----
1198
1199// CHECK-LABEL: func @padv2_dynamic
1200func.func @padv2_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<1x2xi64>) -> tensor<?xf32> {
1201  // CHECK: "mhlo.transpose"({{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x2xi64>) -> tensor<2x1xi64>
1202  // CHECK: mhlo.reshape {{.*}} : (tensor<2x1xi64>) -> tensor<2xi64>
1203  // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
1204  // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
1205  // CHECK: mhlo.dynamic_pad {{.*}} : (tensor<?xf32>, tensor<f32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<?xf32>
1206  %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor<?xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<?xf32>
1207  func.return %1 : tensor<?xf32>
1208}
1209
1210//===----------------------------------------------------------------------===//
1211// Identity op legalizations.
1212//===----------------------------------------------------------------------===//
1213
1214// -----
1215
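// Identity-like ops (Identity, IdentityN, StopGradient, PreventGradient,
// CheckNumerics) carry no computation and fold away to their operands.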
1216// CHECK-LABEL: func @identity
1217func.func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
1218  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
1219  %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
1220  func.return %0: tensor<1xi32>
1221}
1222
1223// -----
1224
1225// CHECK-LABEL: func @identityN
1226func.func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
1227  // CHECK-NEXT:  return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
1228  %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
1229  func.return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
1230}
1231
1232// -----
1233
1234// CHECK-LABEL: func @stopgradient
1235func.func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
1236  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
1237  %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
1238  func.return %0: tensor<1xi32>
1239}
1240
1241// -----
1242
1243// CHECK-LABEL: func @preventgradient
1244func.func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
1245  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
1246  %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32>
1247  func.return %0: tensor<1xi32>
1248}
1249
1250// -----
1251
1252// CHECK-LABEL: func @checkNumerics
1253func.func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> {
1254  // CHECK-NEXT:  return %arg0 : tensor<1xf32>
1255  %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32>
1256  func.return %0: tensor<1xf32>
1257}
1258
1259//===----------------------------------------------------------------------===//
1260// InfeedDequeueTuple legalization
1261//===----------------------------------------------------------------------===//
1262
1263// -----
1264
1265// CHECK-LABEL: func @infeed_dequeue_tuple
1266func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) {
1267// CHECK: [[TOKEN:%.*]] = mhlo.create_token  : !mhlo.token
1268// CHECK: [[INFEED:%.*]]:3 = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""{{.*}}} : (!mhlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !mhlo.token)
1269// CHECK: return [[INFEED]]#0, [[INFEED]]#1
1270  %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>)
1271  func.return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32>
1272}
1273
1274// -----
1275
1276// CHECK-LABEL: func @infeed_dequeue_tuple_dynamic_error
1277func.func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) {
1278  // We expect legalization to fail for dynamic shapes:
1279  // CHECK: [[INFEED:%.*]] = "tf.InfeedDequeueTuple"{{.*}}
1280  %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>)
1281  func.return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32>
1282}
1283
1284// The following op sharding is used:
1285// Proto debug string:
1286//   type: TUPLE
1287//   tuple_shardings {
1288//     type: MAXIMAL
1289//     tile_assignment_dimensions: 1
1290//     tile_assignment_devices: 0
1291//   }
1292// Serialized string:
1293//   "\08\02*\08\08\01\1A\01\01\22\01\00"
1294
1295// CHECK-LABEL: infeed_dequeue_tuple_sharding
1296func.func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
1297  // CHECK: "mhlo.infeed"
1298  // An additional sharding is added at the end to account for the token result.
1299  // Proto debug string:
1300  //   type: TUPLE
1301  //   tuple_shardings {
1302  //     type: MAXIMAL
1303  //     tile_assignment_dimensions: 1
1304  //     tile_assignment_devices: 0
1305  //   }
1306  //   tuple_shardings {
1307  //     type: MAXIMAL
1308  //     tile_assignment_dimensions: 1
1309  //     tile_assignment_devices: 0
1310  //   }
1311  // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00"
1312  %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
1313  func.return %0 : tensor<8xi32>
1314}
1315
1316//===----------------------------------------------------------------------===//
1317// Nullary op legalizations.
1318//===----------------------------------------------------------------------===//
1319
1320// -----
1321
1322// CHECK-LABEL: @const
1323func.func @const() -> tensor<2xi32> {
1324  // CHECK: mhlo.constant dense<0> : tensor<2xi32>
1325  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
1326  func.return %0: tensor<2xi32>
1327}
1328
1329// -----
1330
1331// CHECK-LABEL: @const_dynamic_output
1332func.func @const_dynamic_output() -> tensor<*xi32> {
1333  // CHECK: [[CONST:%.*]] = mhlo.constant dense<0> : tensor<2xi32>
1334  // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
1335  %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
1336  // CHECK: return [[CAST]]
1337  func.return %0: tensor<*xi32>
1338}
1339
1340// -----
1341
1342// CHECK-LABEL: @opaque_const
1343func.func @opaque_const() -> tensor<!tf_type.variant<tensor<2xi32>>> {
1344  // CHECK-NOT: mhlo.constant
1345  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
1346  func.return %0 : tensor<!tf_type.variant<tensor<2xi32>>>
1347}
1348
1349//===----------------------------------------------------------------------===//
1350// Matmul op legalizations.
1351//===----------------------------------------------------------------------===//
1352
1353// -----
1354
1355// CHECK-LABEL: matmul_notranspose
1356// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>)
1357func.func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> {
1358  // CHECK: "mhlo.dot"(%[[A]], %[[B]])
1359  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32>
1360
1361  func.return %0 : tensor<5x11xf32>
1362}
1363
1364// -----
1365
1366// CHECK-LABEL: matmul_transpose_b
1367// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>)
1368func.func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
1369  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
1370  // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]])
1371  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>
1372
1373  func.return %0 : tensor<5x11xf32>
1374}
1375
1376// -----
1377
1378// CHECK-LABEL: matmul_transpose_both
1379// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>)
1380func.func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
1381  // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
1382  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
1383  // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]])
1384  %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>
1385
1386  func.return %0 : tensor<5x11xf32>
1387}
1388
1389// Verify that MatMul with ranked inputs is lowered to HLO.
1390// CHECK-LABEL: matmul_ranked
1391func.func @matmul_ranked(%a: tensor<?x7xf32>, %b: tensor<7x?xf32>) -> tensor<?x?xf32> {
1392  // CHECK: "mhlo.dot"
1393  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<?x7xf32>, tensor<7x?xf32>) -> tensor<?x?xf32>
1394
1395  func.return %0 : tensor<?x?xf32>
1396}
1397
1398// Verify that MatMul with unranked inputs is lowered to HLO.
1399// CHECK-LABEL: matmul_unranked
1400func.func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> {
1401  // CHECK: "mhlo.dot"
1402  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
1403
1404  func.return %0 : tensor<*xf32>
1405}
1406
1407// Verify SparseMatMul is legalized to dot.
1408// CHECK-LABEL: test_sparse_mat_mul
1409func.func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
1410  // CHECK: "mhlo.dot"
1411  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
1412  func.return %0: tensor<3x5xf32>
1413}
1414
1415// SparseMatMul where one operand needs to be transposed and the other does not.
1416//
1417// CHECK-LABEL:   @test_sparse_mat_mul_with_transpose
1418// CHECK-SAME:      %[[ARG0:.*]]: tensor<3x4xf32>
1419// CHECK-SAME:      %[[ARG1:.*]]: tensor<5x4xf32>
1420// CHECK-SAME:      -> tensor<3x5xf32>
1421// CHECK:           %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
1422// CHECK-SAME:        permutation = dense<[1, 0]>
1423// CHECK-SAME:        -> tensor<4x5xf32>
1424// CHECK:           %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
1425// CHECK-SAME:        -> tensor<3x5xf32>
1426// CHECK:           return %[[RESULT]]
1427func.func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
1428  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
1429  func.return %0: tensor<3x5xf32>
1430}
1431
1432// SparseMatMul where one operand needs to be cast and the other does not.
1433//
1434// CHECK-LABEL:   @test_sparse_mat_mul_with_cast
1435// CHECK-SAME:      %[[ARG0:.*]]: tensor<3x4xf32>
1436// CHECK-SAME:      %[[ARG1:.*]]: tensor<4x5xbf16>
1437// CHECK-SAME:      -> tensor<3x5xf32>
1438// CHECK:           %[[CAST:.*]] = mhlo.convert(%[[ARG1]])
1439// CHECK-SAME:        -> tensor<4x5xf32>
1440// CHECK:           %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
1441// CHECK-SAME:        -> tensor<3x5xf32>
1442// CHECK:           return %[[RESULT]]
1443func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
1444  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
1445  func.return %0: tensor<3x5xf32>
1446}
1447
1448//===----------------------------------------------------------------------===//
1449// MatrixBandPart op legalizations.
1450//===----------------------------------------------------------------------===//
1451
1452// -----
1453
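// MatrixBandPart lowers to a mask built from two iotas: with
// offset = col - row, an element is kept iff -lower <= offset <= upper, where
// a negative band size is first widened to the corresponding matrix dimension.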
1454// CHECK-LABEL: matrix_band_part
1455// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
1456func.func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
1457  // CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
1458  // CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>
1459
1460  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
1461  // CHECK-DAG: %[[A:.*]] = mhlo.compare LT, %[[LOWER]], %[[ZERO]] : (tensor<i64>, tensor<i64>) -> tensor<i1>
1462  // CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
1463
1464  // CHECK-DAG: %[[C:.*]] = mhlo.compare LT, %[[UPPER]], %[[ZERO]] : (tensor<i64>, tensor<i64>) -> tensor<i1>
1465  // CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
1466  // CHECK-DAG: %[[F:.*]] = mhlo.negate %[[B]] : tensor<i64>
1467
1468  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64>
1469  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64>
1470  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64>
1471  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1>
1472
1473  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1>
1474
1475  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>
1476
1477  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>
1478
1479  // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]]
1480  // CHECK-DAG: return %[[R]]
1481  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
1482  func.return %0 : tensor<64x64xbf16>
1483}
1484
1485// -----
1486
1487// CHECK-LABEL: matrix_band_part_2
1488// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
1489func.func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> {
1490  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64>
1491  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64>
1492  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64>
1493
1494  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1>
1495
1496  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1>
1497  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>
1498
1499  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>
1500
1501  // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]]
1502  // CHECK-DAG: return %[[R]]
1503  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<12x24x48xbf16>
1504  func.return %0 : tensor<12x24x48xbf16>
1505}
1506
1507// -----
1508
1509// CHECK-LABEL: matrix_band_part_3
1510// CHECK-SAME: (%[[INPUT:.*]]: tensor<*xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
1511func.func @matrix_band_part_3(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
1512  // CHECK: "tf.MatrixBandPart"
1513  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
1514  func.return %0 : tensor<*xbf16>
1515}
1516
1517// -----
1518
1519// CHECK-LABEL: matrix_band_part_4
1520// CHECK-SAME: (%[[INPUT:.*]]: tensor<24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
1521func.func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<24x48xbf16> {
1522  // The input rank and shape are static, so this one should lower.
1523  // CHECK-NOT: "tf.MatrixBandPart"
1524  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<24x48xbf16>
1525  func.return %0 : tensor<24x48xbf16>
1526}
1527
1528//===----------------------------------------------------------------------===//
1529// MaxPool op legalizations.
1530//===----------------------------------------------------------------------===//
1531
1532// -----
1533
1534// CHECK-LABEL: maxpool_valid_padding
1535// CHECK-SAME: %[[ARG:.*]]: tensor
1536func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
1537  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
1538  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
1539  // CHECK: mhlo.maximum
1540  // CHECK: mhlo.return
1541  // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}
1542
1543  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
1544  func.return %0 : tensor<2x3x5x7xi32>
1545}
1546
1547// -----
1548
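// SAME padding pads each spatial dimension by
// max((out - 1) * stride + ksize - in, 0), split low/high with the extra
// element going high: H needs (4-1)*4 + 2 - 13 = 1 -> [0, 1] and W needs
// (7-1)*4 + 3 - 25 = 2 -> [1, 1], matching the attribute checked below.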
1549// CHECK-LABEL: maxpool_same_padding
1550// CHECK-SAME: %[[ARG:.*]]: tensor
1551func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
1552  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
1553
1554  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
1555  func.return %0 : tensor<2x4x7x7xi32>
1556}
1557
1558// -----
1559
1560// CHECK-LABEL: maxpool_3d_valid_padding
1561// CHECK-SAME: %[[ARG:.*]]: tensor
1562func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> {
1563  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
1564  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
1565  // CHECK: mhlo.maximum
1566  // CHECK: mhlo.return
1567  // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}
1568
1569  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32>
1570  func.return %0 : tensor<2x8x3x5x7xf32>
1571}
1572
1573// -----
1574
1575// CHECK-LABEL: maxpool_3d_same_padding
1576// CHECK-SAME: %[[ARG:.*]]: tensor
1577func.func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> {
1578  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
1579
1580  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32>
1581  func.return %0 : tensor<2x8x4x7x7xf32>
1582}
1583
1584// -----
1585
1586// CHECK-LABEL: maxpool_explicit_padding
1587func.func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
1588  // CHECK: tf.MaxPool
1589  // TODO(b/165938852): need to support explicit padding in max_pool.
1590
1591  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
1592  func.return %0 : tensor<2x3x5x7xi32>
1593}
1594
1595//===----------------------------------------------------------------------===//
1596// MaxPoolGrad op legalizations.
1597//===----------------------------------------------------------------------===//
1598
1599// -----
1600
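// MaxPoolGrad lowers to mhlo.select_and_scatter: a GE-compare select region
// picks the max position within each window, and an add scatter region
// accumulates gradients routed to the same input element.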
1601// CHECK-LABEL: @max_pool_grad_valid
1602// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32>
1603func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
1604  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
1605  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({
1606  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
1607  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor<f32>, tensor<f32>) -> tensor<i1>
1608  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<i1>
1609  // CHECK: },  {
1610  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
1611  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
1612  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<f32>
1613  // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
1614  // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
1615  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
1616     data_format = "NHWC",
1617     ksize = [1, 2, 2, 1],
1618     padding = "VALID",
1619     strides = [1, 2, 2, 1]
1620  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32>
1621  func.return %result : tensor<10x24x24x64xf32>
1622}
1623
1624// -----
1625
1626// CHECK-LABEL: @max_pool_3d_grad_valid
1627// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
1628func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
1629  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
1630  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({
1631  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
1632  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor<f32>, tensor<f32>) -> tensor<i1>
1633  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<i1>
1634  // CHECK: },  {
1635  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
1636  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
1637  // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<f32>
1638  // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
1639  // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
1640  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
1641  func.return %result : tensor<10x8x24x24x64xf32>
1642}
1643
1644// -----
1645
1646// CHECK-LABEL: @max_pool_grad_same
1647func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
1648  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
1649  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
1650     data_format = "NHWC",
1651     ksize = [1, 2, 3, 1],
1652     padding = "SAME",
1653     strides = [1, 4, 4, 1]
1654  } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32>
1655  func.return %result : tensor<2x13x25x7xf32>
1656}
1657
1658// -----
1659
1660// CHECK-LABEL: @max_pool_3d_grad_same
1661func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
1662  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
1663  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32>
1664  func.return %result : tensor<2x8x13x25x7xf32>
1665}
1666
1667//===----------------------------------------------------------------------===//
1668// OneHot op legalizations.
1669//===----------------------------------------------------------------------===//
1670
1671// -----
1672
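// OneHot lowers to an iota along the depth axis compared EQ against the
// broadcast indices, selecting between broadcast on/off values.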
1673// CHECK-LABEL: one_hot
1674func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> {
1675  // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32>
1676  // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32>
1677  // CHECK: %[[COMPARE:.*]] = mhlo.compare EQ, %[[BCAST_ARG0]], %[[IOTA]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1>
1678  // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
1679  // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
1680  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
1681  // CHECK: return %[[RESULT]] : tensor<3x5xf32>
1682  %depth = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
1683  %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<3x5xf32>
1684  func.return %result : tensor<3x5xf32>
1685}
1686
1687//===----------------------------------------------------------------------===//
1688// tf.OutfeedEnqueueTuple legalization
1689//===----------------------------------------------------------------------===//
1690
1691// -----
1692
1693// CHECK-LABEL: func @outfeed_enqueue_tuple
1694// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>)
1695func.func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () {
1696// CHECK: [[TOKEN:%.*]] = mhlo.create_token  : !mhlo.token
1697// CHECK: "mhlo.outfeed"([[VAL_0]], [[VAL_1]], [[TOKEN]]) {outfeed_config = ""} : (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token
1698  "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()
1699  func.return
1700}
1701
1702//===----------------------------------------------------------------------===//
1703// Pack op legalizations.
1704//===----------------------------------------------------------------------===//
1705
1706// -----
1707
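// Pack reshapes each operand to rank+1 with a unit axis dimension and
// concatenates along that dimension.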
1708// CHECK-LABEL: func @pack
1709func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
1710  // CHECK: mhlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32>
1711  // CHECK: mhlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32>
1712  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
1713
1714  %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
1715  func.return %0 : tensor<2x2xi32>
1716}

//===----------------------------------------------------------------------===//
// PartitionedCall op legalization.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @partitioned_call
func.func @partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_func(%arg0) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor<i32>) -> (tensor<i32>)
  func.return %0 : tensor<i32>
}

func.func @pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  func.return %arg0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @partitioned_call_multi_input
func.func @partitioned_call_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
  func.return %0 : tensor<i32>
}

func.func @pcall_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  func.return %arg0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @partitioned_call_multi_in_out
func.func @partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  func.return %0, %1 : tensor<i32>, tensor<i32>
}

func.func @pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  func.return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: func @unhandled_partitioned_call
func.func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // The argument types don't match the parameter types of the
  // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not
  // for a standard CallOp, so this op can't be lowered (see the sketch after
  // this function).
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  func.return %0, %1 : tensor<i32>, tensor<i32>
}
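
// For illustration, a direct lowering would have to produce the call below,
// which is invalid because func.call operand types must exactly match the
// callee signature (a sketch, not part of the checked output):
//   %0:2 = func.call @pcall_multi_in_out(%arg0, %arg1)
//            : (tensor<*xi32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)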

// CHECK-LABEL: func @unhandled_partitioned_call_2
func.func @unhandled_partitioned_call_2(%arg0: tensor<i32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  func.return %0, %1 : tensor<i32>, tensor<i32>
}

// -----

//===----------------------------------------------------------------------===//
// ReverseV2 op legalization.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @reverse_func_32
func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  func.return %reversed : tensor<5xi32>
}

// -----

// CHECK-LABEL: @reverse_func_64
func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  func.return %reversed : tensor<5xi32>
}

// -----

// CHECK-LABEL: @reverse_func_neg
func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
  %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>

  // CHECK: return [[VAL]] : tensor<5x5xi32>
  func.return %reversed : tensor<5x5xi32>
}

//===----------------------------------------------------------------------===//
// StatefulPartitionedCall op legalization.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @stateful_partitioned_call
// CHECK-SAME: [[ARG:%.+]]: tensor<i32>
func.func @stateful_partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>)
  func.return %0 : tensor<i32>
}

func.func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  func.return %arg0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>, [[ARG1:%.+]]: tensor<i32>)
func.func @stateful_partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  func.return %0, %1 : tensor<i32>, tensor<i32>
}

func.func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  func.return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

//===----------------------------------------------------------------------===//
// Elu op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @elu
func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]]
  // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]])
  // CHECK: return %[[RESULT]]
  %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
  func.return %0: tensor<1xf32>
}
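
// Elu(x) is x for x > 0 and exp(x) - 1 otherwise, hence the select between
// the argument and the mhlo.exponential_minus_one value above.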

// -----

// CHECK-LABEL: func @elu_unranked
func.func @elu_unranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<?xf32>) -> tensor<?xf32>
  // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]]
  // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]])
  // CHECK: return %[[RESULT]]
  %0 = "tf.Elu"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0: tensor<?xf32>
}

// -----

// CHECK-LABEL: func @elu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction GT>}
  // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: %[[MULGRAD:.*]] = mhlo.multiply(%[[GRADIENTS]], %[[ADD1]]) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]])
  // CHECK: return %[[RESULT]]
  %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
  func.return %2 : tensor<4x8xf32>
}

//===----------------------------------------------------------------------===//
// Relu op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @relu
func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
  %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  func.return %0: tensor<1xi32>
}

// -----

// CHECK-LABEL: func @relu_unranked
func.func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
  %0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  func.return %0: tensor<?xi32>
}

// -----

// CHECK-LABEL: func @relu_unsigned
func.func @relu_unsigned(%arg0: tensor<?xui32>) -> tensor<?xui32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<ui32>
  // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<ui32>, tensor<?xui32>) -> tensor<?xui32>
  %0 = "tf.Relu"(%arg0) : (tensor<?xui32>) -> tensor<?xui32>
  func.return %0: tensor<?xui32>
}

// -----

// CHECK-LABEL: func @relu6
func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-DAG: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32>
  // CHECK: mhlo.clamp %[[ZERO]], %arg0, %[[SIX]] : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
  %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  func.return %0: tensor<1xi32>
}

// -----

// CHECK-LABEL: func @relu6_unranked
func.func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-DAG: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32>
  // CHECK: mhlo.clamp %[[ZERO]], %arg0, %[[SIX]] : (tensor<i32>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
  %0 = "tf.Relu6"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  func.return %0: tensor<?xi32>
}

// -----

// CHECK-LABEL: func @relu6_unsigned
func.func @relu6_unsigned(%arg0: tensor<?xui32>) -> tensor<?xui32> {
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<ui32>
  // CHECK-DAG: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<ui32>
  // CHECK: mhlo.clamp %[[ZERO]], %arg0, %[[SIX]] : (tensor<ui32>, tensor<?xui32>, tensor<ui32>) -> tensor<?xui32>
  %0 = "tf.Relu6"(%arg0) : (tensor<?xui32>) -> tensor<?xui32>
  func.return %0: tensor<?xui32>
}

// -----

// CHECK-LABEL: func @relu_grad_unranked
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<?x?xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func.func @relu_grad_unranked(%gradients: tensor<?x?xf32>, %features: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FEATURES]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %[[FEATURES]], %[[ZERO]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
  // CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x?xf32>
  %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  func.return %2 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @leaky_relu
func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} {
  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
  // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32>
  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
  // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
  %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
  func.return %0 : tensor<1x4x4x3xf32>
}

// -----

// CHECK-LABEL: func @leaky_relu_unranked
func.func @leaky_relu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} {
  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<*xf32>
  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
  // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: return %[[RES]] : tensor<*xf32>
  %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @leaky_relu_grad
func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} {
  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
  // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32>
  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
  // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
  %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
  func.return %0 : tensor<1x4x4xf32>
}

// -----

// CHECK-LABEL: func @leaky_relu_grad_unranked
func.func @leaky_relu_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} {
  // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<*xf32>
  // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
  // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: return %[[RES]] : tensor<*xf32>
  %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @softsign
func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> {
  // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32>
  // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32>
  // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32>
  // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32>
  %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32>
  func.return %0 : tensor<4x10xf32>
}

// -----

// CHECK-LABEL: func @softsign_unranked
func.func @softsign_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<*xf32>
  // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<*xf32>
  // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<*xf32>
  // CHECK-NEXT: return %[[DIV]] : tensor<*xf32>
  %0 = "tf.Softsign"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @softsign_grad
func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> {
  // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32>
  // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32>
  // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32>
  %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
  func.return %0 : tensor<4x10xf32>
}

//===----------------------------------------------------------------------===//
// Roll op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @Roll_0D
func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor<i32>) -> tensor<512xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
  //      CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  //      CHECK-DAG: %[[AXIS_SIZE:.*]] = mhlo.constant dense<512> : tensor<i32>
  //      CHECK: %[[T1:.+]] = mhlo.remainder %arg1, %[[AXIS_SIZE]] : tensor<i32>
  //      CHECK: %[[T2:.+]] = mhlo.add %[[T1]], %[[AXIS_SIZE]] : tensor<i32>
  //      CHECK: %[[T3:.+]] = mhlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor<i32>
  //      CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) {dimension = 0 : i64}
  //      CHECK: %[[OFFSET:.+]] = mhlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor<i32>
  //      CHECK: "mhlo.dynamic_slice"(%[[CONCAT]], %[[OFFSET]])
  // CHECK-SAME:    {slice_sizes = dense<512> : tensor<1xi64>}
  // CHECK-SAME:    (tensor<1024xi32>, tensor<i32>) -> tensor<512xi32>
  %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor<i32>, tensor<i32>) -> tensor<512xi32>
  func.return %0 : tensor<512xi32>
}
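
// Worked example of the offset arithmetic above for shift = -5 and axis size
// 512: remainder(-5, 512) = -5, then (-5 + 512) mod 512 = 507, and the slice
// offset is 512 - 507 = 5, so output[i] = input[(i + 5) mod 512], i.e. a roll
// of 5 toward lower indices.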

//===----------------------------------------------------------------------===//
// Select op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @select_batch_static
func.func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> {
  // CHECK: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %{{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<3xindex>) -> tensor<2x6x8xi1>
  // CHECK: "mhlo.select"(%[[BCAST]], %arg1, %arg2)
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32>
  func.return %0: tensor<2x6x8xi32>
}

// -----

// CHECK-LABEL: func @select_batch_static_r1
func.func @select_batch_static_r1(%arg0: tensor<i1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> {
  // CHECK: "mhlo.select"(%arg0, %arg1, %arg2)
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32>
  func.return %0: tensor<2x6x8xi32>
}

// -----

// CHECK-LABEL: func @select_batch_static_all_same
func.func @select_batch_static_all_same(%arg0: tensor<2x6x8xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> {
  // CHECK: "mhlo.select"(%arg0, %arg1, %arg2)
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x6x8xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32>
  func.return %0: tensor<2x6x8xi32>
}

// -----

// CHECK-LABEL: func @select_batch_dynamic_r1
func.func @select_batch_dynamic_r1(%arg0: tensor<?xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> {
  // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?xi1> -> tensor<1xindex>
  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex>
  // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-NEXT: %[[HEAD:.*]], %[[TAIL:.*]] = "shape.split_at"(%[[SHAPE1]], %[[C1]]) : (tensor<3xindex>, index) -> (tensor<1xindex>, tensor<2xindex>)
  // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<1xindex>
  // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]]
  // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) {
  // CHECK-NEXT: %[[SHAPE1E:.*]] = shape.to_extent_tensor %[[SHAPE1]] : tensor<3xindex> -> tensor<3xindex>
  // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi1>, tensor<3xindex>) -> tensor<?x?x8xi1>
  // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCAST]], %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32>
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  func.return %0: tensor<?x?x8xi32>
}
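
// In words: the data operands must agree on their full shape and the rank-1
// predicate must match their leading extent; both constraints are deferred to
// runtime via shape.cstr_eq / shape.assuming before the select is emitted.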

// -----

// CHECK-LABEL: func @select_batch_dynamic
func.func @select_batch_dynamic(%arg0: tensor<?x?x8xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> {
  // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?x8xi1> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]]
  // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) {
  // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32>
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  func.return %0: tensor<?x?x8xi32>
}

// -----

// CHECK-LABEL: testSelectInvalidUnranked
func.func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
  // CHECK-NEXT: tf.Select
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16>
  func.return %0: tensor<*xf16>
}

// -----

// CHECK-LABEL: testSelectThenUnranked
func.func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> {
  // CHECK-NEXT: tf.Select
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16>
  func.return %0: tensor<*xf16>
}

// -----

// CHECK-LABEL: testSelectElseUnranked
func.func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
  // CHECK-NEXT: tf.Select
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16>
  func.return %0: tensor<*xf16>
}

// -----

// CHECK-LABEL: func @selectv2_dynamic_ranked
func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
  // CHECK: chlo.broadcast_select
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
  func.return %0: tensor<2x?x8xi32>
}

// -----

// CHECK-LABEL: func @selectv2_unranked
func.func @selectv2_unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK: chlo.broadcast_select
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32>
  func.return %0: tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// Fast Fourier Transform op legalization.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @fft_1D
func.func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<8xcomplex<f32>>
  %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  func.return %0 : tensor<8xcomplex<f32>>
}

// -----

// CHECK-LABEL: func @ifft_1D
func.func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type IFFT>} : (tensor<8xcomplex<f32>>
  %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  func.return %0 : tensor<8xcomplex<f32>>
}

// -----

// CHECK-LABEL: func @rfft_1D
func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
  func.return %0 : tensor<5xcomplex<f32>>
}

// -----

// CHECK-LABEL: func @rfft_1D_padded
func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
  func.return %0 : tensor<5xcomplex<f32>>
}

// -----

// CHECK-LABEL: func @rfft_1D_sliced
func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32>
  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<2x8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex<f32>>
  func.return %0 : tensor<2x5xcomplex<f32>>
}

// -----

// CHECK-LABEL: func @irfft_1D
func.func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xf32> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex<f32>>) -> tensor<5xcomplex<f32>>
  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type IRFFT>} : (tensor<5xcomplex<f32>>
  %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex<f32>>, tensor<1xi32>) -> tensor<8xf32>
  func.return %0 : tensor<8xf32>
}

// -----

// CHECK-LABEL: fft_1D_dynamic
func.func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "tf.FFT"
  %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  func.return %0 : tensor<8xcomplex<f32>>
}

// -----

// CHECK-LABEL: rfft_1D_dynamic
func.func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: "tf.RFFT"
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
  func.return %0 : tensor<8xcomplex<f32>>
}
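
// Both dynamic cases above are left as TF ops: mhlo.fft takes fft_length as a
// static attribute, and the pad/slice that adapts the input to that length
// (see rfft_1D_padded and rfft_1D_sliced) cannot be formed when the input
// dimension is unknown at compile time.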

//===----------------------------------------------------------------------===//
// Shape op legalization.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @shape_1D
func.func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<1xindex> to tensor<1xi32>
  %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>

  // CHECK: return [[TENSOR]]
  func.return %0 : tensor<1xi32>
}

// -----

// CHECK-LABEL: func @shape_2D
func.func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<2xindex> to tensor<2xi32>
  %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>

  // CHECK: return [[TENSOR]]
  func.return %0 : tensor<2xi32>
}

// -----

// CHECK-LABEL: func @shape_rankless
func.func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<?xindex> to tensor<?xi32>
  %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>

  // CHECK: return [[TENSOR]]
  func.return %0 : tensor<?xi32>
}

//===----------------------------------------------------------------------===//
// Transpose op legalization.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @transpose_noop
func.func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: return %arg0
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32>
  func.return %0 : tensor<2x3xf32>
}

// -----

// CHECK-LABEL: @transpose_2d
func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
  func.return %0 : tensor<3x2xf32>
}

// -----

// CHECK-LABEL: @transpose_3d_int32
func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32>
  func.return %0 : tensor<3x2x1xf32>
}

// -----

// CHECK-LABEL: @transpose_3d
func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32>
  func.return %0 : tensor<3x2x1xf32>
}

// -----

// CHECK-LABEL: @transpose_dynamic_2d
func.func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32>
  func.return %0 : tensor<4x?xf32>
}

// -----

// CHECK-LABEL: @transpose_unranked_2d
func.func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

//===----------------------------------------------------------------------===//
// Unary op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @abs
func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.abs %arg0 : tensor<2xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @abs_dynamic
func.func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.abs %arg0 : tensor<?xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @abs_unranked
func.func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.abs %arg0 : tensor<*xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @acos
// CHLO-LABEL: @acos
func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  chlo.acos %arg0 : tensor<2xf32>
// CHLO:   %[[VAL_1:.*]] = mhlo.compare NE, {{.*}}
// CHLO:   %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00>
// CHLO:   %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00>
// CHLO:   %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0
// CHLO:   %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]]
// CHLO:   %[[VAL_7:.*]] = mhlo.sqrt %[[VAL_6]]
// CHLO:   %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00>
// CHLO:   %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0
// CHLO:   %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]]
// CHLO:   %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]]
// CHLO:   %[[VAL_12:.*]] = mhlo.constant dense<3.14159274>
// CHLO:   %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]])
// CHLO:   return %[[VAL_13]] : tensor<2xf32>
  %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}
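
// The CHLO expansion above computes acos(x) = 2 * atan2(sqrt(1 - x^2), 1 + x)
// when the compare feeding the select holds (x != -1), and pi otherwise.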

// -----

// CHECK-LABEL: @acos_complex
// CHLO-LABEL: @acos_complex
func.func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHLO: tf.Acos
  %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  func.return %0 : tensor<2xcomplex<f32>>
}

// -----

// CHECK-LABEL: @acos_dynamic
// CHLO-LABEL: @acos_dynamic
func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  chlo.acos %arg0 : tensor<*xf32>
  // `tf.Acos` is lowered via `chlo.constant_like` operations, which can only
  // be lowered further on ranked tensors. Unranked CHLO must be transformed
  // to ranked code before further lowering (see the sketch after this
  // function).
  // CHLO: "tf.Acos"
  %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}
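
// For a ranked operand the constant can materialize directly; a minimal
// sketch, assuming a static shape such as tensor<2xf32>:
//   %one = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
// No such concrete type exists for a tensor<*xf32> operand.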

// -----

// CHECK-LABEL: @tan
// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
// CHLO-LABEL: @tan
// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
func.func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<2xf32>
  // CHLO: %[[SINE:.*]] = mhlo.sine %[[ARG]]
  // CHLO: %[[COSINE:.*]] = mhlo.cosine %[[ARG]]
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %result : tensor<2xf32>
}

// -----

// CHECK-LABEL: @tan_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
// CHLO-LABEL: @tan_unranked
// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
func.func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<*xf32>
  // CHLO: %[[SINE:.*]] = mhlo.sine %[[ARG]]
  // CHLO: %[[COSINE:.*]] = mhlo.cosine %[[ARG]]
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %result : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @cast_dynamic_i2f
func.func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  // CHECK: mhlo.convert(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @cast_i2f
func.func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> {
  // CHECK: mhlo.convert(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @cast_c2f
func.func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK: mhlo.convert(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: @ceil
func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.ceil %arg0 : tensor<2xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @ceil_dynamic
func.func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.ceil %arg0 : tensor<?xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @ceil_unranked
func.func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.ceil %arg0 : tensor<*xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @complex_abs
func.func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK:  mhlo.abs(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: @cos
func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.cosine %arg0 : tensor<2xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @cos_dynamic
func.func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.cosine %arg0 : tensor<?xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @cos_unranked
func.func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.cosine %arg0 : tensor<*xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @exp
func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.exponential %arg0 : tensor<2xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: @expm1
func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.exponential_minus_one %arg0 : tensor<2xf32>
  %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @exp_dynamic
func.func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.exponential %arg0 : tensor<?xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @exp_unranked
func.func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.exponential %arg0 : tensor<*xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @floor
func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.floor %arg0 : tensor<2xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @floor_dynamic
func.func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.floor %arg0 : tensor<?xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @floor_unranked
func.func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.floor %arg0 : tensor<*xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @invert_op_unranked
func.func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK:  mhlo.not %arg0 : tensor<*xi32>
  %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  func.return %0 : tensor<*xi32>
}

// -----

// CHECK-LABEL: @is_finite
func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
  // CHECK:  mhlo.is_finite(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  func.return %0 : tensor<2xi1>
}

// -----

// CHECK-LABEL: func @is_finite_dynamic
func.func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
  // CHECK:  mhlo.is_finite(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  func.return %0 : tensor<?xi1>
}

// -----

// CHECK-LABEL: func @is_finite_unranked
func.func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
  // CHECK:  mhlo.is_finite(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  func.return %0 : tensor<*xi1>
}

// -----

// CHECK-LABEL: @log
func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.log %arg0 : tensor<2xf32>
  %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @log_dynamic
func.func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.log %arg0 : tensor<?xf32>
  %0 = "tf.Log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @log_unranked
func.func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.log %arg0 : tensor<*xf32>
  %0 = "tf.Log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @log1p
func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.log_plus_one %arg0 : tensor<2xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @log1p_dynamic
func.func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.log_plus_one %arg0 : tensor<?xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @log1p_unranked
func.func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.log_plus_one %arg0 : tensor<*xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @not_op_unranked
func.func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> {
  // CHECK:  mhlo.not %arg0 : tensor<*xi1>
  %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
  func.return %0 : tensor<*xi1>
}

// -----

// CHECK-LABEL: @neg
func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.negate %arg0 : tensor<2xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @neg_dynamic
func.func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.negate %arg0 : tensor<?xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @neg_unranked
func.func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.negate %arg0 : tensor<*xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @sigmoid
func.func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK: mhlo.logistic
  %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: @sigmoid_complex
func.func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHECK: mhlo.logistic
  %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  func.return %0 : tensor<2xcomplex<f32>>
}

// -----

// CHECK-LABEL: @sigmoid_unranked
func.func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: mhlo.logistic
  %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @sigmoid_grad
func.func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32>
  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32>
  // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
  // CHECK: return [[MUL1]]
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}
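
// Implements SigmoidGrad(y, dy) = dy * y * (1 - y), the derivative of the
// logistic function written in terms of its output y.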

// -----

// CHECK-LABEL: @sigmoid_grad_complex
func.func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>>
  // CHECK: return [[MUL1]]
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  func.return %0 : tensor<2xcomplex<f32>>
}

// -----

// CHECK-LABEL: @sigmoid_grad_dynamic
func.func @sigmoid_grad_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: @sin
func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.sine %arg0 : tensor<2xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @sin_dynamic
func.func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.sine %arg0 : tensor<?xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @sin_unranked
func.func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.sine %arg0 : tensor<*xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @rsqrt
func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.rsqrt %arg0 : tensor<2xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @rsqrt_dynamic
func.func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.rsqrt %arg0 : tensor<?xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @rsqrt_unranked
func.func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.rsqrt %arg0 : tensor<*xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @sqrt
func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.sqrt %arg0 : tensor<2xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @sqrt_dynamic
func.func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.sqrt %arg0 : tensor<?xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @sqrt_unranked
func.func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.sqrt %arg0 : tensor<*xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @tanh
func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.tanh %arg0 : tensor<2xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @tanh_dynamic
func.func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.tanh %arg0 : tensor<?xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @tanh_unranked
func.func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.tanh %arg0 : tensor<*xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @bitcast
func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// -----

// CHECK-LABEL: func @bitcast_dynamic
func.func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  mhlo.bitcast_convert %arg0 : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @bitcast_unranked
func.func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  mhlo.bitcast_convert %arg0 : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @bitcast_same_widths
func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
  // CHECK:  mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xi32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}

// -----

// CHECK-LABEL: func @bitcast_smaller_input_width
func.func @bitcast_smaller_input_width(%arg0: tensor<8xi8>) -> tensor<i64> {
  // CHECK:  mhlo.bitcast_convert %arg0 : (tensor<8xi8>) -> tensor<i64>
  %0 = "tf.Bitcast"(%arg0) : (tensor<8xi8>) -> tensor<i64>
  func.return %0 : tensor<i64>
}

// -----

// CHECK-LABEL: func @bitcast_smaller_output_width
func.func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2x2xf16> {
  // CHECK:  mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2x2xf16>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf16>
  func.return %0 : tensor<2x2xf16>
}

// -----

// CHECK-LABEL: reshape
func.func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> {
  // CHECK:  mhlo.reshape
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32>
  func.return %0 : tensor<2x1xf32>
}

// -----

// CHECK-LABEL: not_lowering_reshape
func.func @not_lowering_reshape(%arg0: tensor<!tf_type.string>, %arg1: tensor<1xi32>) -> tensor<1x!tf_type.string> {
  // CHECK:  "tf.Reshape"
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<!tf_type.string>, tensor<1xi32>) -> tensor<1x!tf_type.string>
  func.return %0 : tensor<1x!tf_type.string>
}

// -----

// CHECK-LABEL: reshape_dynamic
func.func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
  // CHECK:  "chlo.dynamic_reshape"
  // CHLO:  mhlo.compute_reshape_shape
  // CHLO:  mhlo.dynamic_reshape
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  func.return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: reshape_unranked
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[TARGET_SHAPE:.*]]: tensor<2xi32>
func.func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
  // CHECK:  "chlo.dynamic_reshape"
  // CHLO:  shape.shape_of
  // CHLO:  shape.num_elements
  // CHLO:  mhlo.cstr_reshapable
  // CHLO:  assuming{{.*}}{
  // CHLO:   mhlo.compute_reshape_shape
  // CHLO:   mhlo.dynamic_reshape
  // CHLO:  }
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  func.return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: squeeze
func.func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> {
  // CHECK: mhlo.reshape
  %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
  func.return %0 : tensor<1x10xf32>
}

// -----

// CHECK-LABEL: squeeze_ranked
func.func @squeeze_ranked(%arg0: tensor<?x?x?xf32>) -> tensor<?xf32> {
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?xf32>
  // CHECK: %[[T:.*]] = tensor.from_elements %[[D2]] : tensor<1xindex>
  // CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor<?x?x?xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK: return %[[R]] : tensor<?xf32>
  %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [0, 1] }: (tensor<?x?x?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: squeeze_ranked_negative
func.func @squeeze_ranked_negative(%arg0: tensor<?x?x10xf32>) -> tensor<?x10xf32> {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x10xf32>
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x10xf32>
  // CHECK: %[[T:.*]] = tensor.from_elements %[[D0]], %[[D2]] : tensor<2xindex>
  // CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor<?x?x10xf32>, tensor<2xindex>) -> tensor<?x10xf32>
  // CHECK: return %[[R]] : tensor<?x10xf32>
  %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [-2] }: (tensor<?x?x10xf32>) -> tensor<?x10xf32>
  func.return %0 : tensor<?x10xf32>
}
3104
3105// -----
3106
3107// CHECK-LABEL: squeeze_ranked_dynamic
3108func.func @squeeze_ranked_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
3109  // CHECK: "tf.Squeeze"
3110  %0 = "tf.Squeeze"(%arg0) : (tensor<?x?xf32>) -> tensor<?xf32>
3111  func.return %0 : tensor<?xf32>
3112}
3113
3114// -----
3115
3116// CHECK-LABEL: squeeze_dynamic
3117func.func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> {
3118  // CHECK: "tf.Squeeze"
3119  %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
3120  func.return %0 : tensor<*xf32>
3121}
3122
3123// -----
3124
3125// CHECK-LABEL: expand_dims
3126func.func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
3127  // CHECK: mhlo.reshape
3128  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32>
3129  func.return %0 : tensor<1x2xf32>
3130}
3131
3132// -----
3133
3134// CHECK-LABEL: expand_dims_dynamic
3135func.func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> {
3136  %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> (tensor<i32>)
3137
3138  // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0
3139  // CHECK-DAG: %[[CST0:.+]] = arith.constant 0
3140  // CHECK-DAG: %[[CST1:.+]] = arith.constant 1
3141  // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]]
3142  // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1
3143  // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]]
3144  // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]]
3145  // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]]
3146  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?xf32>, tensor<i32>) -> tensor<?x1x?xf32>
3147
3148  // CHECK: return %[[RESHAPE]]
3149  func.return %0 : tensor<?x1x?xf32>
3150}
3151
3152// -----
3153
3154// CHECK-LABEL: expand_dynamic_dims_rank1_axis
3155func.func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> {
3156  %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3157
3158  // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0
3159  // CHECK-DAG: %[[CST0:.+]] = arith.constant 0
3160  // CHECK-DAG: %[[CST1:.+]] = arith.constant 1
3161  // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]]
3162  // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1
3163  // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]]
3164  // CHECK-DAG: %[[CST2:.+]] = arith.constant 2
3165  // CHECK-DAG: %[[GETEXTENT2:.+]] = tensor.extract %[[SHAPEOF]][%[[CST2]]]
3166  // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]], %[[GETEXTENT2]]
3167  // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]]
3168  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32>
3169
3170  // CHECK: return %[[RESHAPE]]
3171  func.return %0 : tensor<?x1x?x4xf32>
3172}
3173
3174// -----
3175
3176// CHECK-LABEL: func @sign
3177// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32>
3178func.func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
3179  // CHECK: [[SIGN:%.*]] = mhlo.sign [[ARG]]
3180  // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32>
3181  %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>)
3182  func.return %0 : tensor<1x2x3x4xf32>
3183}
3184
3185// -----
3186
3187// CHECK-LABEL: func @sign_dynamic
3188func.func @sign_dynamic(%arg0: tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32> {
3189  // CHECK: [[SIGN:%.*]] = mhlo.sign %arg0 : tensor<?x2x3x?xf32>
3190  // CHECK: return [[SIGN]] : tensor<?x2x3x?xf32>
3191  %0 = "tf.Sign"(%arg0) : (tensor<?x2x3x?xf32>) -> (tensor<?x2x3x?xf32>)
3192  func.return %0 : tensor<?x2x3x?xf32>
3193}
3194
3195// -----
3196
3197// CHECK-LABEL: slice_constant_start
3198func.func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
3199  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i64>
3200  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
3201  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
3202  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
3203  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64>
3204  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
3205  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]])
3206  // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} :
3207  // CHECK-DAG-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
3208  // CHECK: return %[[RESULT]] : tensor<2xi32>
3209  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
3210  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
3211  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
3212  func.return %0 : tensor<2xi32>
3213}
3214
3215// -----
3216
3217// CHECK-LABEL: slice_i32_consts
3218func.func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
3219  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i32>
3220  // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i32>) -> tensor<2xi32>
3221  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
3222  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
3223  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
3224  func.return %0 : tensor<2xi32>
3225}
3226
3227// -----
3228
3229// CHECK-LABEL: slice_constant_start_negative_one_size
3230func.func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
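  // A slice size of -1 means "take all remaining elements from start"; with a
  // constant start of 1 on a tensor<4xi32> this resolves to 4 - 1 = 3, hence
  // the static slice_sizes = dense<3> below.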
3231  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32>
3233  // CHECK: return %[[RESULT]] : tensor<3xi32>
3234  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
3235  %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
3236  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>
3237  func.return %0 : tensor<3xi32>
3238}
3239
3240// -----
3241
3242// CHECK-LABEL: slice_constant_start_dynamic_shape
3243func.func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
3244  // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor<i64>
3245  // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor<i64>
3246  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"
3247  // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]])
3248  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} :
3249  // CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
3250  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
3251  %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
3252  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
3253  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
3254  func.return %0 : tensor<1x4xi32>
3255}
3256
3257// -----
3258
3259// CHECK-LABEL: slice_variable_start
3260func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
3261  // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1)
3262  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
3263  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
3264  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
3265  // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor<i64>
3266  // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1)
3267  // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>,
3268  // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>,
3269  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
3270  // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor<i64>
3271  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
3272  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
3273  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
3274  %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
3275  func.return %0 : tensor<1x4xi32>
3276}
3277
3278// -----
3279
3280// CHECK-LABEL: slice_mhlo_sizes
3281func.func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> {
3282  // CHECK-NOT: "tf.Slice"
3283  %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
3284  %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32>
3285  func.return %1 : tensor<1x512x4xf32>
3286}
3287
3288// -----
3289
3290// CHECK-LABEL: slice_variable_start_negative_one_size
3291func.func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
3292  // CHECK: %[[RESULT:.*]] = "tf.Slice"
3293  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
3294  %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>)
3295  %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
3296  func.return %0 : tensor<1x4xi32>
3297}
3298
3299// -----
3300
3301// CHECK-LABEL: slice_real_dynamic_slice
3302func.func @slice_real_dynamic_slice(%arg0: tensor<4xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<*xi32> {
3303  // CHECK: tensor.extract {{.*}} : tensor<1xi64>
3304  // CHECK: tensor.extract {{.*}} : tensor<1xi64>
3305  // CHECK: arith.index_cast {{.*}} : index to i64
3306  // CHECK: arith.cmpi eq, {{.*}} : i64
3307  // CHECK: arith.addi {{.*}} : i64
3308  // CHECK: tensor.dim {{.*}} : tensor<4xi32>
3309  // CHECK: arith.index_cast {{.*}} : index to i64
3310  // CHECK: select {{.*}} : i64
3311  // CHECK: arith.index_cast {{.*}} : i64 to index
3312  // CHECK: arith.index_cast {{.*}} : i64 to index
3313  // CHECK: tensor.from_elements {{.*}} : tensor<1xindex>
3314  // CHECK: tensor.from_elements {{.*}} : tensor<1xindex>
3315  // CHECK: tensor.from_elements {{.*}} : tensor<1xindex>
3316  %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xi32>
3317  func.return %0 : tensor<*xi32>
3318}
3319
3320//===----------------------------------------------------------------------===//
3321// StridedSlice op legalizations.
3322//===----------------------------------------------------------------------===//
3323
3324// -----
3325
3326// CHECK-LABEL: simple_strided_slice
3327func.func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
3328  %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3329  %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3330  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3331
3332  // CHECK: mhlo.slice
3333  // CHECK-DAG-SAME: start_indices = dense<[0, 1]>
3334  // CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
3335  // CHECK-DAG-SAME: strides = dense<[1, 3]>
3336  // CHECK-SAME: -> tensor<3x2xf32>
3337
3338  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
3339      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
3340  func.return %output : tensor<3x2xf32>
3341}
3342
3343// -----
3344
3345// CHECK-LABEL: dynamic_strided_slice
3346func.func @dynamic_strided_slice(%input: tensor<?x8xf32>) -> tensor<?x2xf32> {
3347  %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3348  %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3349  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3350
3351  // CHECK: "tf.StridedSlice"
3352  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
3353      : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32>
3354  func.return %output : tensor<?x2xf32>
3355}
3356
3357// -----
3358
3359// CHECK-LABEL: strided_slice_negative_indices
3360func.func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
3361  %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3362  %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3363  %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3364
3365  // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
3366
3367  // CHECK: mhlo.slice
3368  // CHECK-DAG-SAME: start_indices = dense<[0, 1]>
3369  // CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
3370  // CHECK-DAG-SAME: strides = dense<[1, 3]>
3371  // CHECK-SAME: -> tensor<3x2xf32>
3372
3373  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
3374      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
3375  func.return %output : tensor<3x2xf32>
3376}
3377
3378// -----
3379
3380// CHECK-LABEL: dynamic_strided_slice_negative_indices
3381func.func @dynamic_strided_slice_negative_indices(%input: tensor<?x8xf32>) -> tensor<?x2xf32> {
3382  %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3383  %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3384  %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3385
3386  // CHECK: tf.StridedSlice
3387  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
3388      : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32>
3389  func.return %output : tensor<?x2xf32>
3390}
3391
3392// -----
3393
3394// CHECK-LABEL: strided_slice_range_clamping
3395func.func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> {
3396  %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3397  %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3398  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
3399
3400  // CHECK: mhlo.slice
3401  // CHECK-DAG-SAME: start_indices = dense<[0, 0]>
3402  // CHECK-DAG-SAME: limit_indices = dense<[1, 8]>
3403  // CHECK-DAG-SAME: strides = dense<[1, 3]>
3404  // CHECK-SAME: -> tensor<1x3xf32>
3405  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
3406      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
3407  func.return %output : tensor<1x3xf32>
3408}
3409
3410// -----
3411
3412// CHECK-LABEL: strided_slice_empty
3413func.func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> {
3414  %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>)
3415  %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
3416  %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
3417
3418  // CHECK: mhlo.constant dense<> : tensor<0xf32>
3419  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
3420      : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32>
3421  func.return %output : tensor<0xf32>
3422}
3423
3424// -----
3425
3426// CHECK-LABEL: strided_slice_begin_end_mask
3427// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32>
3428func.func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) {
3429
3430  // For StridedSlice
3431  // Dim #:        0,   1,    2
3432  // Input shape: [4, 128, 1024]
3433  // Begin:        1,   4,   -3
3434  // End:          8,  65,   42
3435  // Stride:       1,   4,   -1
3436  // Begin mask:   0,   0,    1  (= 1)
3437  // End mask:     1,   0,    0  (= 4)
3438
3439  // So result shape:
3440  // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
3441  // Dim #1: 4 to 65 stride 4: so 16
3442  // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
3443  // result shape: [4, 16, 1022]
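  // (Informally, each output extent is ceil((end - begin) / stride) once the
  // masks are applied and negative indices are canonicalized, e.g.
  // dim #1: ceil((65 - 4) / 4) = 16 and dim #2: ceil((-1 - 1021) / -1) = 1022.)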
3444
3445  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
3446  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
3447  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)
3448
3449  // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]])
3450
3451  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]])
3452  // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]>
3453  // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]>
3454  // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
3455  // CHECK-SAME: -> tensor<4x16x1022xf32>
3456
3457  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32>
3458
3459  // CHECK: mhlo.reshape %[[SLICE]]
3460  // CHECK-SAME: -> tensor<4x16x1022xf32>
3461
3462  func.return
3463}
3464
3465// -----
3466
3467// CHECK-LABEL: strided_slice_shrink_axis_mask
3468// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32>
3469func.func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) {
3470
3471  // For StridedSlice
3472  // Dim #:            0,   1,    2
3473  // Input shape:     [4, 128, 1024]
3474  // Begin:            1,   4,   -3
3475  // End:              8,  65,   42
3476  // Stride:           1,   4,   -1
3477  // Begin mask:       1,   0,    0  (= 1)
3478  // End mask:         0,   0,    1  (= 4)
3479  // Shrink axis mask: 1,   0,    1  (= 5)
3480
3481  // So result shape:
3482  // Dim #0: shrink axis, take value at [1]
3483  // Dim #1: 4 to 65 stride 4: so 16
3484  // Dim #2: shrink axis, take value at [-3]
3485  // result shape: [16]
3486
  // As the output shape of StridedSlice differs, a reshape will follow.
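  // (shrink_axis_mask = 5 = 0b101, so the size-1 extents produced for dims #0
  // and #2 by the slice below are dropped by the trailing reshape.)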
3488
3489  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
3490  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
3491  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)
3492
3493  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
3494  // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]>
3495  // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]>
3496  // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
3497  // CHECK-SAME: -> tensor<1x16x1xf32>
3498
3499  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32>
3500
3501  // CHECK: mhlo.reshape %[[SLICE]]
3502  // CHECK-SAME: -> tensor<16xf32>
3503
3504  func.return
3505}
3506
3507// -----
3508
3509// CHECK-LABEL: strided_slice_ellipsis_mask
3510// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
3511func.func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
3512  // For StridedSlice input[1, ..., 8:, :10, 2:6:2]
  // The ellipsis mask is applied to dims #1 and #2, i.e., we get the
  // canonicalized slice input[1, :, :, 8:, :10, 2:6:2].
3515
  // The start_indices, limit_indices, and strides attributes of mhlo.slice
  // reflect the canonicalized slice.
  // As the output shape of StridedSlice differs, a reshape will follow.
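  // (The 5-entry sparse spec expands to the 6-entry dense spec checked below;
  // the ellipsis contributes the full-range entries 0:4:1 and 0:8:1 for dims
  // #1 and #2.)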
3519
3520  %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)
3521  %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>)
3522  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)
3523
3524  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
3525  // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
3526  // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
3528  // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
3529  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32>
3530
3531  // CHECK: mhlo.reshape %[[SLICE]]
3532  // CHECK-SAME: -> tensor<4x8x8x10x2xf32>
3533
3534  func.return
3535}
3536
3537// -----
3538
3539// CHECK-LABEL: strided_slice_new_axis_mask
3540// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
3541func.func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
3542  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // The new axis mask has bits set at indices 1 and 6 of the sparse spec, so
  // new_axis_mask = 2^1 + 2^6 = 66.
  // The ellipsis mask is applied to dims #1 and #2 of the input, i.e., we get
  // the canonicalized slice input[1, :, :, 8:, :10, 2:6:2].
  // This is then reshaped to add the new axes.
3548
  // The start_indices, limit_indices, and strides attributes of mhlo.slice
  // reflect the canonicalized slice.
  // As the output shape of StridedSlice differs, a reshape will follow to
  // reflect the new axes added.
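  // (new_axis_mask = 66 = 0b1000010: bits 1 and 6 of the sparse spec, which
  // become the leading and trailing size-1 axes of the
  // tensor<1x4x8x8x10x2x1xf32> result after the dim-0 shrink.)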
3553
3554  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
3555  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
3556  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)
3557
3558  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
3559  // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
3560  // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
3562  // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
3563  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32>
3564
3565  // CHECK: mhlo.reshape %[[SLICE]]
3566  // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32>
3567
3568  func.return
3569}
3570
3571// -----
3572
3573// CHECK-LABEL: strided_slice_implicit_ellipsis_mask(
3574// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32>
3575func.func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> {
  // StridedSlice gets input[8:10], which is the same as input[8:10, ...].
  // The start_indices, limit_indices, and strides attributes of mhlo.slice
3578  // reflect the canonicalized slice.
3579  %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
3580  %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
3581  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3582  // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]])
3583  // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64>
3584  // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64>
3585  // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64>
3586  // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[SLICE]] : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32>
3587  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32>
3588  // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32>
3589  func.return %0 : tensor<2x16x2xf32>
3590}
3591
3592// -----
3593
3594// CHECK-LABEL: strided_slice_nonconstant_begin_end
3595func.func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) {
  // In this case, the `begin` and `end` inputs are unknown at compile time, so
  // the lowering must slice these vectors at runtime and use the result as
  // input to an HLO dynamic slice.
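  // The CHECK lines below sketch the runtime index computation: extract the
  // begin value and, if it is negative, wrap it by the dimension size
  // (index < 0 ? index + 32 : index) before feeding mhlo.dynamic_slice.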
3599  %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
3600  %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3601  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3602  %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
3603  %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
3604  // CHECK: %[[A:.*]] = mhlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
3605  // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]])
3606  // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32>
3607  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
3608  // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]])
3609  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
3610  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
3611  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
3612  // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor<i32>
3613  // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
3614  // CHECK-DAG-SAME: {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
3615  // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32>
3616  // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
3617  // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
3618  // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
3619  // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice"
3620  // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]])
3621  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} :
3622  // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x1x97xi32>
3623  // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32>
3624  %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3625  // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32>
3626  func.return %result : tensor<1x97xi32>
3627}
3628
3629// -----
3630
3631// CHECK-LABEL: strided_slice_nonconstant_begin_end_with_start_end_mask
3632// CHECK-SAME: (%[[INPUT:.*]]: tensor<32x1x97xi32>, %[[BEGIN:.*]]: tensor<3xi32>, %[[END:.*]]: tensor<3xi32>)
3633func.func @strided_slice_nonconstant_begin_end_with_start_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<3xi32>, %end: tensor<3xi32>) -> (tensor<1x97xi32>) {
3634  %strides = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
3635
3636  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
3637  // CHECK: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]])
3638  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
3639  // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
3640  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
3641  // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor<i32>
3642  // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
3643  // CHECK-DAG-SAME: {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
3644  // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32>
3645  // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
3646  // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
3647  // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
3648  // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice"
3649  // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]])
3650  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} :
3651  // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x1x97xi32>
3652  // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32>
3653  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x97xi32>
3654  func.return %result : tensor<1x97xi32>
3655}
3656
3657// -----
3658
3659// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1
3660func.func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // Dynamic stride: when the `begin` and `end` inputs are unknown at compile
  // time, `strides` must be known.
3663  // CHECK: tf.StridedSlice
3664  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3665  func.return %result : tensor<1x97xi32>
3666}
3667
3668// -----
3669
3670// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2
3671func.func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // Invalid stride (not equal to 1): when the `begin` and `end` inputs are
  // unknown at compile time, `strides` must be known to be all 1s.
3674  %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
3675  // CHECK: tf.StridedSlice
3676  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3677  func.return %result : tensor<1x97xi32>
3678}
3679
3680// -----
3681
3682// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count
3683func.func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> {
3684  %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
3685  // When begin/end are dynamic, the number of output elements must be equal to
3686  // the number of input elements sliced.
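  // Here the requested output has 6 * 10 = 60 elements while a stride-1 slice
  // of the 4x8 input can produce at most 32, so the op is left unlegalized.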
3687  // CHECK: tf.StridedSlice
3688  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32>
3689  func.return %0 : tensor<6x10xf32>
3690}
3691
3692// -----
3693
3694// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask
3695func.func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // New axis mask: when the `begin` and `end` inputs are unknown at compile
  // time, a new_axis mask is not supported.
3698  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3699  // CHECK: tf.StridedSlice
3700  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3701  func.return %result : tensor<1x97xi32>
3702}
3703
3704// -----
3705
3706// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask
3707func.func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3708  // This ellipsis mask is not supported because it does not refer to the last
3709  // dimension.
3710  // [0, 1, 0] = 2
3711  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3712  // CHECK: tf.StridedSlice
3713  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3714  func.return %result : tensor<1x97xi32>
3715}
3716
3717// -----
3718
3719// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask
3720func.func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3721  // This ellipsis mask is supported because it refers to the last dimension.
3722  // [1, 0, 0] = 4
3723  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3724  // CHECK: mhlo.dynamic_slice
3725  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3726  func.return %result : tensor<1x97xi32>
3727}
3728
3729// -----
3730
3731// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask
3732func.func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3733  // This shrink_axis mask is supported because it refers to a major dimension.
3734  // [1, 1, 1] = 7
3735  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3736  // CHECK: mhlo.dynamic_slice
3737  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3738  func.return %result : tensor<1x97xi32>
3739}
3740
3741// -----
3742
3743// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask
3744func.func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3745  // This shrink_axis mask is unsupported because it does not refer to a major
3746  // dimension.
3747  // [0, 1, 0] = 2
3748  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3749  // CHECK: tf.StridedSlice
3750  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3751  func.return %result : tensor<1x97xi32>
3752}
3753
3754
3755//===----------------------------------------------------------------------===//
3756// Reduction op legalizations.
3757//===----------------------------------------------------------------------===//
3758
3759// -----
3760
3761// CHECK-LABEL: func @mean
3762func.func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3763  // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3764  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
3765  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3766  // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
3767  // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16>
3768  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16>
3769  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
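  // i.e. mean = sum / 8: the f16 input is accumulated in f32 (note the
  // surrounding converts) and divided by the reduced extent before being cast
  // back to f16.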
3770  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3771  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3772  func.return %0 : tensor<4x1xf16>
3773}
3774
3775// -----
3776
3777// CHECK-LABEL: func @mean_scalar_dim
3778func.func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // Verify that a tf.Mean op with a scalar reduction dimension is lowered
  // successfully.
3780
3781  // CHECK-NOT: tf.Mean
3782  %dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64>
3783  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16>
3784  func.return %0 : tensor<4x1xf16>
3785}
3786
3787// -----
3788
3789// CHECK-LABEL: func @mean_dynamic
3790func.func @mean_dynamic(%arg0: tensor<?x?xf16>) -> tensor<?x1xf16> {
3791  // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<?x?xf16>) -> tensor<?x?xf32>
3792  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
3793  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
3794  // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?xf16> -> tensor<2xindex>
3795  // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
3796  // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
3797  // CHECK: %[[REDUCED_DIM:.*]] = tensor.extract %[[SHAPE0]][%[[C1_2]]] : tensor<2xindex>
3798  // CHECK: %[[MUL:.*]] = arith.muli %[[C1_1]], %[[REDUCED_DIM]] : index
3799  // CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[MUL]] : index to i64
3800  // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor<i64>
3801  // CHECK: %[[CONVERT:.*]] = mhlo.convert(%[[TENSOR]]) : (tensor<i64>) -> tensor<f32>
3802  // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
3803  // CHECK: %[[MEAN_CONVERTED:.*]] = mhlo.convert(%[[MEAN]]) : (tensor<?xf32>) -> tensor<?xf16>
3804  // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor<?xf16> -> tensor<1xindex>
3805  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
3806  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
3807  // CHECK: %[[UNREDUCED_DIM:.*]] = tensor.extract %[[SHAPE1]][%[[C0]]] : tensor<1xindex>
3808  // CHECK: %[[RESULT_SHAPE:.*]] = tensor.from_elements %[[UNREDUCED_DIM]], %[[C1]] : tensor<2xindex>
3809  // CHECK: %[[RESULT:.*]] = mhlo.dynamic_reshape %[[MEAN_CONVERTED]], %[[RESULT_SHAPE]] : (tensor<?xf16>, tensor<2xindex>) -> tensor<?x1xf16>
3810  // CHECK: return %[[RESULT]] : tensor<?x1xf16>
3811  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3812  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<?x?xf16>, tensor<1xi64>) -> tensor<?x1xf16>
3813  func.return %0 : tensor<?x1xf16>
3814}
3815
3816// -----
3817
3818// CHECK-LABEL: func @sum
3819func.func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3820  // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3821  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
3822  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3823  // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
3824  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16>
3825  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3826  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3827  %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3828  func.return %0 : tensor<4x1xf16>
3829}
3830
3831// -----
3832
3833// CHECK-LABEL: func @sum_dynamic
3834func.func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32>
  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x?xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3841  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3842  %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3843  func.return %0 : tensor<4x1xf16>
3844}
3845
3846// -----
3847
3848// CHECK-LABEL: func @max
3849func.func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3850  // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16>
3851  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16>
3852  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16>
3853  // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16>
3854  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16>
3855  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3856  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3857  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3858  func.return %0 : tensor<4x1xf16>
3859}
3860
3861// -----
3862
3863// CHECK-LABEL: func @max_qint
3864// Regression test to ensure we don't crash getting the initial value for
3865// tf.Max when using quantized integer types.
3866func.func @max_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
3867  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3868  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
3869  func.return %0 : tensor<4x1x!tf_type.qint8>
3870}
3871
3872// -----
3873
3874// CHECK-LABEL: func @max_dynamic
3875func.func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x?xf16>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16>
  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x?xf16>, tensor<f16>) -> tensor<4xf16>
  // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3882  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3883  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3884  func.return %0 : tensor<4x1xf16>
3885}
3886
3887// -----
3888
3889// CHECK-LABEL: func @min
3890func.func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3891  // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16>
3892  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor<f16>
3893  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.minimum across dimensions = [1] : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16>
3894  // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16>
3895  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16>
3896  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3897  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3898  %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3899  func.return %0 : tensor<4x1xf16>
3900}
3901
3902// -----
3903
3904// CHECK-LABEL: func @min_qint
3905// Regression test to ensure we don't crash getting the initial value for
3906// tf.Min when using quantized integer types.
3907func.func @min_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
3908  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3909  %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
3910  func.return %0 : tensor<4x1x!tf_type.qint8>
3911}
3912
3913// -----
3914
3915// CHECK-LABEL: func @prod
3916func.func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3917  // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3918  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
3919  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.multiply across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3920  // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
3921  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16>
3922  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3923  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3924  %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3925  func.return %0 : tensor<4x1xf16>
3926}
3927
3928// -----
3929
3930// CHECK-LABEL: func @prod_qint
3931// Regression test to ensure we don't crash getting the initial value for
3932// tf.Prod when using quantized integer types.
3933func.func @prod_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
3934  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3935  %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
3936  func.return %0 : tensor<4x1x!tf_type.qint8>
3937}
3938
3939// -----
3940
3941// CHECK-LABEL: @all
3942func.func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> {
3943  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3944  // CHECK: %[[INIT:.*]] = mhlo.constant dense<true> : tensor<i1>
3945  // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.and across dimensions = [1] : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
3946  %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
3947  func.return %0 : tensor<4xi1>
3948}
3949
3950// -----
3951
3952// CHECK-LABEL: @all_keep_dim
3953func.func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
3954  // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1>
3955  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3956  %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3957  func.return %0 : tensor<4x1xi1>
3958}
3959
3960// -----
3961
3962// CHECK-LABEL: @all_dynamic
3963func.func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
3964  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3965  // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1>
3966  // CHECK: mhlo.reduce(%[[ARG]]
3967  %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3968  func.return %0 : tensor<4x1xi1>
3969}
3970
3971// -----
3972
3973// CHECK-LABEL: @any
3974func.func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> {
3975  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3976  // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1>
3977  // CHECK: mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.or across dimensions = [1] : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
3978  %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
3979  func.return %0 : tensor<4xi1>
3980}
3981
3982// -----
3983
3984// CHECK-LABEL: @any_keep_dim
3985func.func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
3986  // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1>
3987  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3988  %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3989  func.return %0 : tensor<4x1xi1>
3990}
3991
3992// -----
3993
3994// CHECK-LABEL: @any_dynamic
3995func.func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
3996  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3997  // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1>
3998  // CHECK: mhlo.reduce(%[[ARG]]
3999  %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
4000  func.return %0 : tensor<4x1xi1>
4001}
4002
4003//===----------------------------------------------------------------------===//
4004// Tile op legalizations.
4005//===----------------------------------------------------------------------===//
4006
4007// -----
4008
4009// CHECK-LABEL: func @tile_by_reshape
4010func.func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> {
4011  // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32>
4012  // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[BROADCASTED]] : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32>
4013  // CHECK: return %[[RESULT]] : tensor<28x24xf32>
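  // Tiling by [7, 3] interleaves the multiples into the broadcast shape
  // (7x4x3x8) and the reshape collapses each pair: 7 * 4 = 28, 3 * 8 = 24.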
4014  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64>
4015  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32>
4016  func.return %0 : tensor<28x24xf32>
4017}
4018
4019// -----
4020
4021// CHECK-LABEL: func @tile_just_broadcast
4022func.func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> {
4023  // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32>
4024  // CHECK: return %[[RESULT]] : tensor<7x3xf32>
4025  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64>
4026  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32>
4027  func.return %0 : tensor<7x3xf32>
4028}
4029
4030// -----
4031
4032// CHECK-LABEL: func @tile_dynamic_shape
4033func.func @tile_dynamic_shape(%arg0: tensor<?x8xf32>) -> tensor<?x24xf32> {
4034  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32>
4035  // CHECK: tensor.dim {{.*}} : tensor<?x8xf32>
4036  // CHECK: tensor.from_elements  {{.*}} : tensor<4xindex>
4037  // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<?x8xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
4038  // CHECK: muli {{.*}} : index
4039  // CHECK: tensor.from_elements {{.*}} : tensor<2xindex>
4040  // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor<?x?x?x?xf32>, tensor<2xindex>) -> tensor<?x24xf32>
4041  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<?x8xf32>, tensor<2xi32>) -> tensor<?x24xf32>
4042  func.return %0 : tensor<?x24xf32>
4043}
4044
4045//===----------------------------------------------------------------------===//
4046// ArgMax/ArgMin op legalizations.
4047//===----------------------------------------------------------------------===//
4048
4049// -----
4050
4051// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0
4052func.func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> {
4053  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor<i64>
4054  // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
4055  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex>
4056  // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x7xi32>
4057  // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]])
4058  // CHECK: (%[[ARG1:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<i64>) (%[[ARG2:.*]]: tensor<i32>, %[[ARG4:.*]]: tensor<i32>)
4059  // CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
4060  // CHECK:  %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
4061  // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
4062  // CHECK:  %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]]
4063  // CHECK:  %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
4064  // CHECK:  %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
4065  // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor<i64>, tensor<i32>
4066  // CHECK: return %[[REDUCE]]#1 : tensor<7xi32>
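  // Ties are broken toward the smaller index: on an EQ comparison the
  // mhlo.minimum of the two candidate indices is selected (see %[[RESULT3]]).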
4067  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
4068  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32>
4069  func.return %0 : tensor<7xi32>
4070}
4071
4072// -----
4073
4074// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1
4075func.func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> {
4076  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
4077  // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant  dense<0> : tensor<i64>
4078  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xf32> -> tensor<2xindex>
4079  // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x7xi64>
4080  // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]])
4081  // CHECK: return %[[REDUCE]]#1 : tensor<3xi64>
4082  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
4083  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor<i32>) -> tensor<3xi64>
4084  func.return %0 : tensor<3xi64>
4085}
4086
4087// -----
4088
4089// CHECK-LABEL: func @argmax_i1_input_i64_output_axis_1
4090func.func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3xi64> {
4091  // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1>
4092  // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant  dense<0> : tensor<i64>
4093  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi1> -> tensor<2xindex>
4094  // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x7xi64>
4095  // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]])
4096  // CHECK: return %[[REDUCE]]#1 : tensor<3xi64>
4097  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
4098  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi1>, tensor<i32>) -> tensor<3xi64>
4099  func.return %0 : tensor<3xi64>
4100}
4101
4102// -----
4103
4104// CHECK-LABEL: func @argmax_dynamic_shape_input_output
4105func.func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor<?xi32> {
4106  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
4107  // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
4108  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex>
4109  // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x?xi32>
4110  // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]])
4111  // CHECK: return %[[REDUCE]]#1 : tensor<?xi32>
4112  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
4113  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<?xi32>
4114  func.return %0 : tensor<?xi32>
4115}
4116
4117// -----
4118
4119// CHECK-LABEL: func @argmax_dynamic_shape_input
4120func.func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> {
4121  // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
4122  // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
4123  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex>
4124  // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x?xi32>
4125  // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]])
4126  // CHECK: return %[[REDUCE]]#1 : tensor<3xi32>
4127  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
4128  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32>
4129  func.return %0 : tensor<3xi32>
4130}
4131
4132// -----
4133
4134// CHECK-LABEL: func @argmin_i64_input_i32_output_axis_0
4135func.func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> {
4136  // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor<i64>
4137  // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
4138  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex>
4139  // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x7xi32>
4140  // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]])
4141  // CHECK: (%[[ARG1:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<i64>) (%[[ARG2:.*]]: tensor<i32>, %[[ARG4:.*]]: tensor<i32>)
4142  // CHECK: %[[COMPARE:.*]] = mhlo.compare LE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
4143  // CHECK:  %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
4144  // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
4145  // CHECK:  %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]]
4146  // CHECK:  %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
4147  // CHECK:  %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
4148  // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor<i64>, tensor<i32>
4149  // CHECK: return %[[REDUCE]]#1 : tensor<7xi32>
4150  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
4151  %0 = "tf.ArgMin"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32>
4152  func.return %0 : tensor<7xi32>
4153}
4154
4155//===----------------------------------------------------------------------===//
4156// Random op legalizations.
4157//===----------------------------------------------------------------------===//
4158
4159// -----
4160
4161// CHECK-LABEL: func @rng_uniform
4162func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
4163  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4164  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
4165  // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
4166  // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32>
4167  %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32>
4168  // CHECK: return %[[F32]]
4169  func.return %0 : tensor<12x?x64xf32>
4170}
4171
4172// -----
4173
4174// CHECK-LABEL: func @rng_std_normal
4175func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
4176  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4177  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
4178  // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
4179  // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*NORMAL.*}} -> tensor<12x?x64xf32>
4180  %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32>
4181  // CHECK: return %[[F32]]
4182  func.return %0 : tensor<12x?x64xf32>
4183}
4184
4185//===----------------------------------------------------------------------===//
4186// Range op legalizations.
4187//===----------------------------------------------------------------------===//
4188
4189// -----
4190
4191// CHECK-LABEL: func @range
4192// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32>
4193func.func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
4194  %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32>
4195  // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"
4196  // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>}
4197  // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>}
4198  %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
4199  func.return %3 : tensor<5xf32>
4200}
4201
4202// -----
4203
4204// CHECK-LABEL: func @range_dynamic
4205// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32>
4206func.func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<?xf32> {
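  // The element count is computed as ceil(|limit - start| / delta) and
  // reshaped into the 1-d shape operand of mhlo.dynamic_iota; the iota is
  // then scaled by the delta and offset by the start (checked below).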
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0
  // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]]
  // CHECK-DAG: [[CONVERT1:%.+]] = mhlo.convert [[ABS1]]
  // CHECK-DAG: [[CONVERT2:%.+]] = mhlo.convert %arg2
  // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]]
  // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]]
  // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert([[CEIL]])
  // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT3]]
  // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
  // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert %arg0
  // CHECK-DAG: [[CONVERT4:%.+]] = mhlo.convert %arg2
  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>

  // CHECK: return [[ADD]]
  func.return %2 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @range_int_dynamic
// CHECK-SAME: [[START:%.*]]: tensor<i32>, [[DELTA:%.*]]: tensor<i32>
func.func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> {
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0
  // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]]
  // CHECK-DAG: [[CONVERT1:%.+]] = mhlo.convert([[ABS1]])
  // CHECK-DAG: [[CONVERT2:%.+]] = mhlo.convert(%arg2)
  // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]]
  // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]]
  // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert([[CEIL]])
  // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT3]]
  // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
  // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert %arg0
  // CHECK-DAG: [[CONVERT4:%.+]] = mhlo.convert %arg2
  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>

  // CHECK: return [[ADD]]
  func.return %2 : tensor<?xi32>
}

// -----

// CHECK-LABEL: func @linspace_static
// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32>
func.func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
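  // step = (stop - start) / (num - 1); the result is iota * step + start
  // (checked below).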
  // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4>
  // CHECK-DAG: [[NUM_F32:%.*]] = mhlo.convert([[NUM]])
  // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]]
  // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]]
  // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]]
  // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64}
  // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK: return [[LINSPACE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}

// -----

// CHECK-LABEL: func @linspace_dynamic
func.func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>) -> tensor<?xf32> {
  // CHECK: "tf.LinSpace"
  %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @linspace_invalid_num
func.func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> {
  // CHECK: mhlo.constant dense<> : tensor<0xi32>
  // CHECK: "tf.LinSpace"
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32>
  func.return %1 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// LegacyCall op legalizations.
//===----------------------------------------------------------------------===//

// -----

func.func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> {
  func.return %arg0: tensor<10x2xf32>
}

// CHECK-LABEL: testSimpleLegacyCallOp
func.func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> {
  // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32>
  %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32>
  // CHECK: return %[[RESULT]]
  func.return %0: tensor<10x2xf32>
}

// -----

func.func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> {
  func.return %arg0: tensor<10x2xf32>
}

// CHECK-LABEL: testMultiInputLegacyCallOp
func.func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> {
  // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32>
  %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32>
  // CHECK: return %[[RESULT]]
  func.return %0: tensor<10x2xf32>
}

//===----------------------------------------------------------------------===//
// Conv op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: conv_simple
func.func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
  // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[0, 1], [2, 3]], rhs_dilate = [2, 3]}
  // CHECK-SAME: batch_group_count = 1
  // CHECK-SAME: feature_group_count = 2

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
  func.return %0 : tensor<256x8x7x16xf32>
}

// -----

// CHECK-LABEL: conv3d_simple
func.func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f]
  // CHECK-SAME{LITERAL}: window = {stride = [5, 6, 7], pad = [[1, 2], [2, 3], [2, 3]], rhs_dilate = [2, 3, 4]}
  // CHECK-SAME: batch_group_count = 1
  // CHECK-SAME: feature_group_count = 2

  %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32>
  func.return %0 : tensor<256x7x6x5x16xf32>
}

// -----

// CHECK-LABEL: depthwiseconv_simple
func.func @depthwiseconv_simple(%arg0: tensor<?x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<?x3x4x9xf32> {
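  // The HxWxCxM depthwise filter is reshaped to HxWx1x(C*M) so the convolution
  // can run with feature_group_count equal to the input channel count.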
  // CHECK: %[[RESHAPED_FILTER:.*]] = mhlo.reshape %arg1 : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
  // CHECK: mhlo.convolution(%arg0, %[[RESHAPED_FILTER]])
  // CHECK-SAME: feature_group_count = 3
  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
    data_format = "NHWC",
    device = "",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1]
  } : (tensor<?x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<?x3x4x9xf32>
  func.return %0 : tensor<?x3x4x9xf32>
}

// -----

// CHECK-LABEL: conv_valid_padding
func.func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
  // CHECK: mhlo.convolution(%arg0, %arg1)

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32>
  func.return %0 : tensor<1x2x3x1xf32>
}

// -----

// CHECK-LABEL: conv_explicit_paddings
func.func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> {

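  // The explicit_paddings entries for the H and W dimensions map directly to
  // the window padding of the HLO convolution.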
  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME{LITERAL}: pad = [[6, 0], [3, 3]]

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32>
  func.return %0 : tensor<256x9x7x16xf32>
}

// -----

// CHECK-LABEL: @conv2d_backprop_input_dynamic
func.func @conv2d_backprop_input_dynamic(%filter: tensor<2x2x1x16xf32>, %out_backprop: tensor<?x256x256x16xf32>) -> tensor<?x512x512x1xf32> {
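  // The input gradient is a convolution of the output gradient with the
  // spatially reversed filter; the forward stride of 2 reappears as
  // lhs_dilate = [2, 2] (checked below).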
  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]])
  // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %cst_0_1d = "tf.Const"() {device = "", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
  %cst_1_0d = "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %cst_1_1d = "tf.Const"() {device = "", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %cst_512_0d = "tf.Const"() {device = "", value = dense<512> : tensor<i32>} : () -> tensor<i32>
  %out_backprop_shape = "tf.Shape"(%out_backprop) {device = ""} : (tensor<?x256x256x16xf32>) -> tensor<4xi32>
  %batch_size = "tf.StridedSlice"(%out_backprop_shape, %cst_0_1d, %cst_1_1d, %cst_1_1d) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
  %input_shape = "tf.Pack"(%batch_size, %cst_512_0d, %cst_512_0d, %cst_1_0d) {axis = 0 : i64, device = ""} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4xi32>
  %result = "tf.Conv2DBackpropInput"(%input_shape, %filter, %out_backprop) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<2x2x1x16xf32>, tensor<?x256x256x16xf32>) -> tensor<?x512x512x1xf32>
  return %result : tensor<?x512x512x1xf32>
}

// -----

// CHECK-LABEL: @conv2d_backprop_input
func.func @conv2d_backprop_input(
    %filter: tensor<3x3x1x32xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<100x28x28x1xf32> {
    // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
    // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]])
    // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]
    // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    // CHECK-SAME: batch_group_count = 1 : i64
    // CHECK-SAME: feature_group_count = 1 : i64
    // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
  func.return %result : tensor<100x28x28x1xf32>
}

// -----

// CHECK-LABEL: @conv2d_backprop_input_grouped
func.func @conv2d_backprop_input_grouped(
    %filter: tensor<2x2x5x21xf32>,
    %out_backprop: tensor<5x2x2x21xf32>
  ) -> tensor<5x3x3x15xf32> {
  %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32>

  // Verify filter transformation for grouped convolution.

  // CHECK: %[[RESHAPE:.*]] = mhlo.reshape %arg0 : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32>
  // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]])
  // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]>
  // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32>
  // CHECK: mhlo.reshape %[[TRANSPOSE]] : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32>

  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32>
  func.return %result : tensor<5x3x3x15xf32>
}

// -----

// CHECK-LABEL: @conv3d_backprop_input
func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]])
  // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64,
  // CHECK-SAME: feature_group_count = 1 : i64

  // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
  func.return %result : tensor<2x8x8x8x1xf32>
}

// -----

// CHECK-LABEL: @conv2d_backprop_filter
func.func @conv2d_backprop_filter(
    %input: tensor<100x28x28x1xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<3x3x1x32xf32> {
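  // The filter gradient is a convolution that exchanges the batch and feature
  // roles of the forward input, as the [f, 0, 1, b]x[i, 0, 1, o] dimension
  // numbers below show.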
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
  // CHECK-SAME:  batch_group_count = 1 : i64
  // CHECK-SAME:  feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32>
  func.return %result : tensor<3x3x1x32xf32>
}

// -----

// CHECK-LABEL: @conv2d_backprop_filter_grouped
func.func @conv2d_backprop_filter_grouped(
    %input: tensor<1x2x2x2xf32>,
    %out_backprop: tensor<1x1x1x2xf32>
  ) -> tensor<2x2x1x2xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME:  batch_group_count = 2 : i64
  // CHECK-SAME:  feature_group_count = 1 : i64

  %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32>
  func.return %result : tensor<2x2x1x2xf32>
}

// -----

// CHECK-LABEL: @conv3d_backprop_filter
func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32>
  func.return %result : tensor<3x3x3x1x6xf32>
}

// -----

// CHECK-LABEL: @collective_permute
func.func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
  %source_target_pairs = "tf.Const" () {
    value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32>
  } : () -> tensor<3x2xi32>

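  // The i32 source/target pairs are converted to an i64 attribute on the
  // mhlo op (checked below).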
4559  // CHECK: "mhlo.collective_permute"
4560  // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
4561  %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) {
4562  } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32>
4563
4564  func.return %0 : tensor<128x32xf32>
4565}
4566
4567// -----
4568
4569// CHECK-LABEL: @cross_replica_sum
4570func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
4571  %replica_groups = "tf.Const" () {
4572    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
4573  } : () -> tensor<2x4xi32>
4574
4575  // CHECK: mhlo.cross-replica-sum
4576  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
4577  %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32>
4578  func.return %result : tensor<10xf32>
4579}
4580
4581// -----
4582
4583// CHECK-LABEL: conv_dynamic
4584func.func @conv_dynamic(%arg0: tensor<?x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32> {
4585  // CHECK: "mhlo.dynamic_conv"
4586  // CHECK-SAME: {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<?x8x7x16xf32>
4587  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32>
4588  func.return %0 : tensor<?x8x7x16xf32>
4589}
4590
4591//===----------------------------------------------------------------------===//
4592// tf.Split legalization
4593//===----------------------------------------------------------------------===//
4594
4595// -----
4596
4597// CHECK-LABEL: @split_not_match_non_const_split_dim
4598func.func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
4599  // CHECK: tf.Split
4600  %0:2 = "tf.Split"(%split_dim, %input) : (tensor<i32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
4601  func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32>
4602}
4603
4604// -----
4605
4606// CHECK-LABEL: @split_not_match_unknown_input_dim
4607func.func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) {
4608  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
4609  // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32>
4610  // CHECK: arith.divsi {{.*}} : index
4611  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
4612  // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32>
4613  // CHECK: muli {{.*}} : index
4614  // CHECK: muli {{.*}} : index
4615  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
4616  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
4617  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
4618  // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32>
4619  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>)
4620  func.return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32>
4621}
4622
4623// -----
4624
4625// CHECK-LABEL: @split_match_and_split_into_two
4626func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) {
4627  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
4628  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
4629  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
4630  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>)
4631  // CHECK: return %[[ONE]], %[[TWO]]
4632  func.return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32>
4633}
4634
4635// -----
4636
4637// CHECK-LABEL: @split_match_and_split_into_two_dynamic
4638func.func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) {
4639  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
4640  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
4641  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
4642  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>)
4643  // CHECK: return %[[ONE]], %[[TWO]]
4644  func.return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32>
4645}
4646
4647// -----
4648
4649// CHECK-LABEL: @split_match_and_split_into_three
4650// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
4651func.func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
4652  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
4653  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
4654  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
4655  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
4656  %0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
4657  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
4658  func.return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
4659}
4660
4661//===----------------------------------------------------------------------===//
4662// tf.TopKV2 legalization
4663//===----------------------------------------------------------------------===//
4664
4665// -----
4666
4667// CHECK-LABEL: topk_v2_non_const_k
4668func.func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
4669  // CHECK: tf.TopKV2
4670  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
4671  func.return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
4672}
4673
4674// -----
4675
4676// CHECK-LABEL: topk_v2_unknown_input_last_dim
4677func.func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
4678  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
4679  // CHECK: tf.TopKV2
4680  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
4681  func.return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
4682}
4683
4684// -----
4685
4686// CHECK-LABEL: topk_v2
4687// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
4688func.func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
4689  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
4690
4691  // CHECK:     chlo.top_k(%[[INPUT]], k = 8)
4692  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
4693  func.return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
4694}
4695
4696//===----------------------------------------------------------------------===//
4697// tf.SplitV legalization
4698//===----------------------------------------------------------------------===//
4699
4700// -----
4701
4702// CHECK-LABEL: @splitv_match_and_split_into_three
4703// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
4704func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
4705  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
4706  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
4707  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32>
4708  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
4709  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32>
4710  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
4711  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
4712  func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
4713}
4714
4715// -----
4716
4717// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic
4718func.func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) {
4719  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
4720  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
4721  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32>
4722  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32>
4723  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32>
4724  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>)
4725  func.return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>
4726}
4727
4728// -----
4729
4730// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes
4731func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
4732  %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
4733  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
4734  // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>
4735  // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>
4736  // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>
4737  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
4738  func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
4739}
4740
4741//===----------------------------------------------------------------------===//
4742// tf.Assert legalization
4743//===----------------------------------------------------------------------===//
4744
4745// -----
4746
4747// CHECK-LABEL: @assert
4748func.func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
4749  // CHECK-NOT: tf.Assert
4750  "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> ()
4751  func.return
4752}
4753
4754//===----------------------------------------------------------------------===//
4755// tf.Unpack legalization
4756//===----------------------------------------------------------------------===//
4757
4758// -----
4759
4760// CHECK-LABEL: @unpack
4761func.func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
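  // Unpacking along axis 1 becomes one unit slice per result, each followed
  // by a reshape that drops the unit dimension (checked below).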
  // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES1:.*]] = mhlo.reshape %[[SLICE1]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES2:.*]] = mhlo.reshape %[[SLICE2]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES3:.*]] = mhlo.reshape %[[SLICE3]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32>

  %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
  // CHECK: return %[[RES1]], %[[RES2]], %[[RES3]]
  func.return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
}

// -----

// CHECK-LABEL: func @unpack_dynamic
func.func @unpack_dynamic(%arg0: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<2xi32>
  // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<3xi32>
  // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<2xi32>
  // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  // CHECK: return {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>
  %0:2 = "tf.Unpack"(%arg0) {axis = -1 : i64} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
  func.return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// CHECK-LABEL: @unpack_unranked
func.func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

  // CHECK: tf.Unpack
  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
  func.return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @unsorted_segment_sum
// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32>
// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32>
func.func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) {
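  // Lowers to an mhlo.scatter onto a zero-initialized 4x64 buffer whose
  // region adds colliding updates (checked below).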
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ({
  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
  // CHECK:   [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor<f32>
  // CHECK:   mhlo.return [[ADD]]
  // CHECK: indices_are_sorted = false,
  // CHECK-SAME: scatter_dimension_numbers =
  // CHECK-SAME:   update_window_dims = [2]
  // CHECK-SAME:   inserted_window_dims = [0]
  // CHECK-SAME:   scatter_dims_to_operand_dims = [0]
  // CHECK-SAME:   index_vector_dim = 2
  // CHECK-SAME: unique_indices = false
  // CHECK-SAME: (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32>
  // CHECK: return [[SCATTER]]
  %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor<i32>) -> (tensor<4x64xf32>)
  func.return %0: tensor<4x64xf32>
}

// -----

// CHECK-LABEL: @unsorted_segment_prod
// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32>
// CHECK-SAME: [[SI:%.*]]: tensor<?x16xi32>
func.func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ({
  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
  // CHECK:   [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor<f32>
  // CHECK:   mhlo.return [[MUL]]
  // CHECK: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers =
  // CHECK-SAME:   update_window_dims = [2]
  // CHECK-SAME:   inserted_window_dims = [0]
  // CHECK-SAME:   scatter_dims_to_operand_dims = [0]
  // CHECK-SAME:   index_vector_dim = 2
  // CHECK-SAME: unique_indices = false
  // CHECK-SAME: (tensor<4x64xf32>, tensor<?x16xi32>, tensor<8x?x64xf32>) -> tensor<4x?xf32>
  // CHECK: return [[SCATTER]]
  %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  func.return %0: tensor<4x?xf32>
}

// -----

// CHECK-LABEL: @unsorted_segment_min
func.func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32>
  // CHECK: mhlo.scatter
  // CHECK: mhlo.minimum
  %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  func.return %0: tensor<4x?xf32>
}

// -----

// CHECK-LABEL: @unsorted_segment_max
func.func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32>
  // CHECK: mhlo.scatter
  // CHECK: mhlo.maximum
  %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  func.return %0: tensor<4x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.GatherNd legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @gatherNd_dynamic
func.func @gatherNd_dynamic(%arg0: tensor<?x?x?xi32>, %arg1: tensor<?x6x2xi32>) -> tensor<?x6x?xi32> {
  // CHECK: tensor.dim
  // CHECK: index_cast
  // CHECK: tensor.from_elements
  // CHECK: mhlo.dynamic_gather
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME:   offset_dims = [2]
  // CHECK-SAME:   collapsed_slice_dims = [0, 1]
  // CHECK-SAME:   start_index_map = [0, 1]
  // CHECK-SAME:   index_vector_dim = 2
  // CHECK-SAME: indices_are_sorted = false
  %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<?x?x?xi32>, tensor<?x6x2xi32>) -> tensor<?x6x?xi32>
  func.return %0 : tensor<?x6x?xi32>
}

// -----

// CHECK-LABEL: func @gatherNd_static
func.func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> {
  // CHECK:      "mhlo.gather"({{.*}}) {
  // CHECK-SAME:   dimension_numbers =
  // CHECK-SAME:     offset_dims = [1, 2]
  // CHECK-SAME:     collapsed_slice_dims = [0]
  // CHECK-SAME:     start_index_map = [0]
  // CHECK-SAME:     index_vector_dim = 1
  // CHECK-SAME:   indices_are_sorted = false
  // CHECK-SAME:   slice_sizes = dense<[1, 4, 128]>
  // CHECK-SAME: (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32>
  %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32>
  func.return %0 : tensor<2x4x128xf32>
}

//===----------------------------------------------------------------------===//
// tf.GatherV2 legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @gather_v2
//  CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]]
//  CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]
func.func @gather_v2(%params: tensor<16x2x3xf32>, %indices: tensor<16x5xi32>) -> tensor<16x2x5xf32> {
  //      CHECK: mhlo.torch_index_select
  // CHECK-SAME:   %[[PARAMS]], %[[INDICES]]
  // CHECK-SAME:   batch_dims = 1
  // CHECK-SAME:   dim = 2
  %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32>
  func.return %1 : tensor<16x2x5xf32>
}

// -----

// CHECK-LABEL: @gather_v2_dynamic
//  CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]]
//  CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]
func.func @gather_v2_dynamic(%params: tensor<?x?x?xf32>, %indices: tensor<?x?xi32>) -> tensor<*xf32> {
  //      CHECK: mhlo.torch_index_select
  // CHECK-SAME:   %[[PARAMS]], %[[INDICES]]
  // CHECK-SAME:   batch_dims = 1
  // CHECK-SAME:   dim = 2
  %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<1xi32>) -> tensor<*xf32>
  func.return %1 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @gather_v2_dynamic_index_i64
//  CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]]
//  CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]
func.func @gather_v2_dynamic_index_i64(%params: tensor<?x?x?xf32>, %indices: tensor<?x?xi64>) -> tensor<*xf32> {
  //      CHECK: mhlo.torch_index_select
  // CHECK-SAME:   %[[PARAMS]], %[[INDICES]]
  // CHECK-SAME:   batch_dims = 1
  // CHECK-SAME:   dim = 2
  %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi64>, tensor<1xi32>) -> tensor<*xf32>
  func.return %1 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @gather_v2_unranked
func.func @gather_v2_unranked(%params: tensor<*xf32>, %indices: tensor<*xi32>) -> tensor<*xf32> {
  // CHECK: tf.GatherV2
  %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<*xf32>, tensor<*xi32>, tensor<1xi32>) -> tensor<*xf32>
  func.return %1 : tensor<*xf32>
}

// -----

// CHECK-LABEL: @gather_v2_dynamic_shape
//  CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]]
//  CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]
func.func @gather_v2_dynamic_shape(%params: tensor<?x2x3xf32>, %indices: tensor<?x5xi32>) -> tensor<?x2x5xf32> {
  //      CHECK: mhlo.torch_index_select
  // CHECK-SAME:   %[[PARAMS]], %[[INDICES]]
  // CHECK-SAME:   batch_dims = 1
  // CHECK-SAME:   dim = 2
  %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<?x2x3xf32>, tensor<?x5xi32>, tensor<1xi32>) -> tensor<?x2x5xf32>
  func.return %1 : tensor<?x2x5xf32>
}

//===----------------------------------------------------------------------===//
// tf.StridedSliceGrad legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: strided_slice_grad
// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32>
func.func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> {

  // For StridedSlice
  // Dim #:        0,   1,    2
  // Input shape: [4, 128, 1024]
  // Begin:        1,   4,   -3
  // End:          8,  65,   42
  // Stride:       1,   4,   -1
  // Begin mask:   1,   0,    0  (= 1)
  // End mask:     0,   0,    1  (= 4)

  // So result shape:
  // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
  // Dim #1: 4 to 65 stride 4: so 16
  // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
  // result shape: [4, 16, 1022]

  // To pad back:
  // Dim #:        0,   1,   2
  // Pad low:      0,   4,   0
  // Pad interm:   0,   3,   0
  // Pad high:     0,  63,   2

  %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)

  // CHECK: [[RESHAPE:%.*]] = mhlo.reshape %arg0 : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
  // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor<f32>) -> tensor<4x128x1024xf32>

  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32>
  // CHECK: return [[PAD]]
  func.return %0: tensor<4x128x1024xf32>
}

// -----

// CHECK-LABEL: strided_slice_grad_shrink_axis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32>
func.func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> {
  // Input to StridedSlice was of shape 4x8xf32
  // Strided slice gets input[2:3, 0:8]
  // shrink_axis_mask is 1, denoting that dim #0 is shrunk. So the output is
  // 8xf32, which is the shape of the gradient.
  // StridedSliceGrad reshapes the gradient to 1x8xf32 and
  // then pads it to match the input shape 4x8xf32.

  %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<8xf32>) -> tensor<1x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64>
  // CHECK-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64>
  // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
5059  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32>
5060
5061  // CHECK: return [[PAD]] : tensor<4x8xf32>
5062  func.return %0 : tensor<4x8xf32>
5063}

// -----

// CHECK-LABEL: strided_slice_grad_new_axis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32>
func.func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> {
  // Input to StridedSlice was of shape 8xf32
  // Strided slice gets input[tf.new_axis, 2:4]
  // new_axis_mask is 1, denoting that a new axis is inserted at dim #0. So the
  // output is 1x2xf32, which is the shape of the gradient.
  // StridedSliceGrad reshapes the gradient to 2xf32 and
  // then pads it to match the input shape, 8xf32.
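  //
  // Worked padding arithmetic: pad low = begin = 2 and pad high = 8 - 4 = 4,
  // matching the mhlo.pad attributes below.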

  %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x2xf32>) -> tensor<2xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64>
  // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64>
  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32>

  // CHECK: return [[PAD]] : tensor<8xf32>
  func.return %0 : tensor<8xf32>
}

// -----

// CHECK-LABEL: strided_slice_grad_ellipsis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32>
func.func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> {
  // Input to StridedSlice was of shape 4x4x8xf32
  // Strided slice gets input[2:4, ...]
  // ellipsis_mask is 2, denoting that the slice contains all elements in dim #1
  // and dim #2, ignoring the begin and end indices for these dimensions. So the
  // output is 2x4x8xf32, which is the shape of the gradient.
  // StridedSliceGrad pads the gradient to match the shape of the
  // input, 4x4x8xf32.
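  //
  // Only dim #0 needs padding: pad low = begin = 2 and pad high = 4 - 4 = 0,
  // so edge_padding_high below is all zeros.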

  %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64>
  // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64>
  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32>

  // CHECK: return [[PAD]] : tensor<4x4x8xf32>
  func.return %0 : tensor<4x4x8xf32>
}

// -----

// CHECK-LABEL: strided_slice_grad_all_masks
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32>
func.func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> {
  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // New axes are at indices 1 and 6 of the sparse spec, so
  // new_axis_mask = 2^1 + 2^6 = 66
  // The ellipsis mask is applied to dims #1 and #2 of the input, i.e., we get
  // the canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
  // The StridedSliceGrad op propagates the gradient for the sliced tensor
  // back to the original input tensor by padding with zeroes.
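  //
  // Worked arithmetic for dim #5 (2:6:2 over size 64): 2 elements are kept, so
  // pad low = 2, interior padding = stride - 1 = 1, and
  // pad high = 64 - 2 - (2 + 1) = 59.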

  %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>)
  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)

  // Remove the 2 new axes (at indices 1 and 6) and the 1 shrink axis (at index 0)
  // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // The edge_padding_low, edge_padding_high, and interior_padding attributes
  // of mhlo.pad reflect the padding required to recover the shape of the
  // input of the StridedSlice op.
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]])
  // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64>
  // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32>

  // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32>
  func.return %0 : tensor<2x4x8x16x32x64xf32>
}

// -----

// CHECK-LABEL: @tensor_scatter_update
func.func @tensor_scatter_update(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
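  // The scatter region below just returns the update value, making this an
  // overwrite; the Add/Sub/Min/Max variants that follow differ only in the
  // combiner used inside the region.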
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    mhlo.return %arg4 : tensor<f32>
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   update_window_dims = [1]
  // CHECK-SAME:   inserted_window_dims = [0, 1]
  // CHECK-SAME:   scatter_dims_to_operand_dims = [0, 1]
  // CHECK-SAME:   index_vector_dim = 1
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: @tensor_scatter_update_scalar_update
func.func @tensor_scatter_update_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor<i32>) -> tensor<4x3xi32> {
  // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64>
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<2xi64>) -> tensor<2x3xi32>
  // CHECK: "mhlo.scatter"
  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<4x3xi32>
  func.return %0 : tensor<4x3xi32>
}

// -----

// CHECK-LABEL: @tensor_scatter_add
func.func @tensor_scatter_add(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    %1 = mhlo.add %arg3, %arg4 : tensor<f32>
  // CHECK:    mhlo.return %1 : tensor<f32>
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   update_window_dims = [1]
  // CHECK-SAME:   inserted_window_dims = [0, 1]
  // CHECK-SAME:   scatter_dims_to_operand_dims = [0, 1]
  // CHECK-SAME:   index_vector_dim = 1
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: @tensor_scatter_add_scalar_update
func.func @tensor_scatter_add_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor<i32>) -> tensor<4x3xi32> {
  // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64>
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<2xi64>) -> tensor<2x3xi32>
  // CHECK: "mhlo.scatter"
  %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<4x3xi32>
  func.return %0 : tensor<4x3xi32>
}

// -----

// CHECK-LABEL: @tensor_scatter_sub
func.func @tensor_scatter_sub(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    %1 = mhlo.subtract %arg3, %arg4 : tensor<f32>
  // CHECK:    mhlo.return %1 : tensor<f32>
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   update_window_dims = [1]
  // CHECK-SAME:   inserted_window_dims = [0, 1]
  // CHECK-SAME:   scatter_dims_to_operand_dims = [0, 1]
  // CHECK-SAME:   index_vector_dim = 1
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: @tensor_scatter_min
func.func @tensor_scatter_min(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    %1 = mhlo.minimum %arg3, %arg4 : tensor<f32>
  // CHECK:    mhlo.return %1 : tensor<f32>
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   update_window_dims = [1]
  // CHECK-SAME:   inserted_window_dims = [0, 1]
  // CHECK-SAME:   scatter_dims_to_operand_dims = [0, 1]
  // CHECK-SAME:   index_vector_dim = 1
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: @tensor_scatter_max
func.func @tensor_scatter_max(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    %1 = mhlo.maximum %arg3, %arg4 : tensor<f32>
  // CHECK:    mhlo.return %1 : tensor<f32>
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   update_window_dims = [1]
  // CHECK-SAME:   inserted_window_dims = [0, 1]
  // CHECK-SAME:   scatter_dims_to_operand_dims = [0, 1]
  // CHECK-SAME:   index_vector_dim = 1
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.RandomShuffle legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @random_shuffle_first_dim_1
// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32>
func.func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> {
  %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>)
  // CHECK-NEXT: return [[INPUT]]
  func.return %0: tensor<1x?xf32>
}

// -----

// CHECK-LABEL: @random_shuffle_1D_16
// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32>
func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> {
  // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64>
  // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>}
  // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ({
  // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor<i32>, [[ARG2:%.*]]: tensor<i32>, {{.*}}: tensor<f32>, {{.*}}: tensor<f32>):
  // CHECK:   mhlo.compare LT, [[ARG1]], [[ARG2]], TOTALORDER
  // CHECK: }) {dimension = -1 : i64, is_stable = {{.*}}} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>)
  // CHECK: return [[SORT]]#1
  %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>)
  func.return %0: tensor<16xf32>
}

// -----

// CHECK-LABEL: @random_shuffle_1D_10240
func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> {
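  // For 10240 elements the lowering emits two rounds of rng+sort rather than
  // one; repeated rounds keep the chance of duplicate sort keys, which would
  // bias the shuffle, low.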
  // CHECK: mhlo.rng{{.*UNIFORM.*}}
  // CHECK: mhlo.sort
  // CHECK: mhlo.rng{{.*UNIFORM.*}}
  // CHECK: mhlo.sort
  %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>)
  func.return %0: tensor<10240xf32>
}

// -----

// CHECK-LABEL: @random_shuffle_3D
// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32>
func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
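  // Only the leading dimension is shuffled: build iota indices, draw random
  // swap targets, permute the indices with a Fisher-Yates style while loop of
  // swaps, then gather the rows using the shuffled indices.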
  // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>

  // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64>
  // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>}

  // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor<i32>

  // CHECK: [[WHILE_OUT:%.*]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[IV_INIT]], [[ITER_ARG1:.*]] = [[SWAPS]], [[ITER_ARG2:.*]] = [[INDICES]])
  // CHECK:   [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK:   [[CMP:%.*]] = mhlo.compare LT, [[ITER_ARG0]], [[LIMIT]], NOTYPE
  // CHECK:   mhlo.return [[CMP]]
  // CHECK: } do {
  // CHECK:   [[SRC_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[ITER_ARG0]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK:   [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK:   [[SWP:%.*]] = mhlo.reshape [[SWP_IDX]] : (tensor<1xi32>) -> tensor<i32>
  // CHECK:   [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) {slice_sizes = dense<1> : tensor<i64>}
  // CHECK:   [[INDICES1:%.*]] = mhlo.dynamic_update_slice [[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]] : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK:   [[INDICES2:%.*]] = mhlo.dynamic_update_slice [[INDICES1]], [[SRC_IDX]], [[SWP]] : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK:   [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32>
  // CHECK:   [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG0]], [[ONE]]
  // CHECK:   mhlo.return [[NEW_IV]], [[ITER_ARG1]], [[INDICES2]]
  // CHECK: }

  // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[WHILE_OUT]]#2)
  // CHECK-SAME:   dimension_numbers =
  // CHECK-SAME:     offset_dims = [1, 2]
  // CHECK-SAME:     collapsed_slice_dims = [0]
  // CHECK-SAME:     start_index_map = [0]
  // CHECK-SAME:     index_vector_dim = 1
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: slice_sizes = dense<[1, -1, 16]>
  // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32>

  // CHECK: return [[GATHER]]

  %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>)
  func.return %0: tensor<4x?x16xf32>
}

//===----------------------------------------------------------------------===//
// tf.AvgPool legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL:   @avgpool_valid_padding
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x12x21x7xf16>
// CHECK:           [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             mhlo.return [[ADD]]
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x3x5x7xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x3x5x7xf32>
// CHECK:           [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x3x5x7xf16>
// CHECK:           return [[CONV16]]
func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16>
  func.return %0 : tensor<2x3x5x7xf16>
}

// -----

// CHECK-LABEL:   @avgpool_3d_valid_padding
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x4x12x21x7xf16>
// CHECK:           [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             mhlo.return [[ADD]]
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x3x5x7xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x4x3x5x7xf32>
// CHECK:           [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x4x3x5x7xf16>
// CHECK:           return [[CONV16]]
func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16>
  func.return %0 : tensor<2x4x3x5x7xf16>
}

// -----

// CHECK-LABEL:   @avgpool_nchw_format
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x7x12x21xf16>
// CHECK:           [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             mhlo.return [[ADD]]
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x7x3x5xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x7x3x5xf32>
// CHECK:           [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x7x3x5xf16>
// CHECK:           return [[CONV16]]
func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16>
  func.return %0 : tensor<2x7x3x5xf16>
}

// -----

// CHECK-LABEL:   @avgpool_3d_ncdhw_format
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x7x4x12x21xf16>
// CHECK:           [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             mhlo.return [[ADD]]
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 2]>
// CHECK-SAME:        window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x7x4x3x5xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x7x4x3x5xf32>
// CHECK:           [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x7x4x3x5xf16>
// CHECK:           return [[CONV16]]
func.func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16>
  func.return %0 : tensor<2x7x4x3x5xf16>
}

// -----

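// With SAME padding, windows can hang over the edges, so a constant divisor
// would over-count. The lowering instead divides by a second reduce_window
// over an all-ones tensor, which counts the valid elements in each window.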
// CHECK-LABEL:   @avgpool_same_padding(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM1]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x6x7xf32>
// CHECK:           %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM2]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x6x7xf32>
// CHECK:           %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32>
// CHECK:           return %[[RESULT]] : tensor<2x4x6x7xf32>
// CHECK:         }
func.func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
  func.return %0 : tensor<2x4x6x7xf32>
}

// -----

// CHECK-LABEL:   @avgpool_3d_same_padding(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM1]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM2]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]]
// CHECK:           return %[[RESULT]] : tensor<2x4x4x6x7xf32>
// CHECK:         }
func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
  func.return %0 : tensor<2x4x4x6x7xf32>
}

//===----------------------------------------------------------------------===//
// AvgPoolGrad op legalizations.
//===----------------------------------------------------------------------===//

// -----

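// AvgPoolGrad lowers to the transpose of the forward pooling: divide the
// output gradient by the per-window element count (a constant for VALID, a
// counting reduce_window for SAME), pad it with interior padding of
// stride - 1, and run a stride-1 sum reduce_window to spread each gradient
// value back over its source window.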
// CHECK-LABEL:   @avgpool_grad_valid_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG:       %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<10x12x16x64xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x25x33x64xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x24x32x64xf32>
// CHECK:           return %[[RESULT]] : tensor<10x24x32x64xf32>
func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32>
  func.return %result : tensor<10x24x32x64xf32>
}

// -----

// CHECK-LABEL:   @avgpool_3d_grad_valid_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG:       %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x8x25x33x64xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x8x24x32x64xf32>
// CHECK:           return %[[RESULT]] : tensor<10x8x24x32x64xf32>
func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 2, 1],
    padding = "VALID",
    strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32>
  func.return %result : tensor<10x8x24x32x64xf32>
}

// -----

// CHECK-LABEL:   @avgpool_grad_same_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG:       %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM1]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x7x9xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 3, 3, 0]>
// CHECK-SAME:        -> tensor<2x14x27x9xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM2]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x13x25x9xf32>
// CHECK:           return %[[RESULT]] : tensor<2x13x25x9xf32>
func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 3, 1],
     padding = "SAME",
     strides = [1, 4, 4, 1]
  } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32>
  func.return %result : tensor<2x13x25x9xf32>
}

// -----

// CHECK-LABEL:   @avgpool_3d_grad_same_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG:       %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM1]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x8x4x7x9xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 3, 3, 0]>
// CHECK-SAME:        -> tensor<2x8x14x27x9xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM2]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x8x13x25x9xf32>
// CHECK:           return %[[RESULT]] : tensor<2x8x13x25x9xf32>
func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 3, 1],
    padding = "SAME",
    strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32>
  func.return %result : tensor<2x8x13x25x9xf32>
}

// -----

// CHECK-LABEL:   @avgpool_grad_nchw_format(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG:       %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM1]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x9x4x7xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 1]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 3, 3]>
// CHECK-SAME:        -> tensor<2x9x14x27xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM2]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x9x13x25xf32>
// CHECK:           return %[[RESULT]] : tensor<2x9x13x25xf32>
func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NCHW",
     ksize = [1, 1, 2, 3],
     padding = "SAME",
     strides = [1, 1, 4, 4]
  } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32>
  func.return %result : tensor<2x9x13x25xf32>
}

// -----

// CHECK-LABEL:   @avgpool_3d_grad_ncdwh_format(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
// CHECK-DAG:       %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG:       %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM1]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x9x8x4x7xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 0, 1]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 0, 1, 1]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 0, 3, 3]>
// CHECK-SAME:        -> tensor<2x9x8x14x27xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM2]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<1> : tensor<5xi64>
// CHECK-SAME:        -> tensor<2x9x8x13x25xf32>
// CHECK:           return %[[RESULT]] : tensor<2x9x8x13x25xf32>
func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NCDHW",
    ksize = [1, 1, 1, 2, 3],
    padding = "SAME",
    strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32>
  func.return %result : tensor<2x9x8x13x25xf32>
}

// -----

// CHECK-LABEL:   @avgpool_grad_bf16(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK-DAG:       %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK-DAG:       %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<10x12x16x64xbf16>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x25x33x64xbf16>
// CHECK:           %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = mhlo.convert(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32>
// CHECK:           %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ({
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             mhlo.return %[[SUM]] : tensor<f32>
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x24x32x64xf32>
// CHECK:           %[[RESULT_CONVERTED:.*]] = mhlo.convert(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16>
// CHECK:           return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16>
func.func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
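  // Note the reduce_window above operates in f32 and converts back to bf16,
  // presumably to avoid accumulating rounding error in bf16 sums.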
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16>
  func.return %result : tensor<10x24x32x64xbf16>
}

// -----

// CHECK-LABEL: xla_sharding
func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // CHECK-NEXT: "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", mhlo.sharding = ""}
  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
  func.return %0 : tensor<4x16xf32>
}

// -----

// CHECK-LABEL: inplace_update_one
func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]]
  // CHECK-DAG: [[UPDATE:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32>

  // CHECK: return [[UPDATE]]
  func.return %0 : tensor<8x4xf32>
}

// -----

// CHECK-LABEL: inplace_update_three
func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]]
  // CHECK-DAG: [[RESHAPE2:%.+]] = mhlo.reshape [[SLICE2]]
  // CHECK-DAG: [[RESHAPE3:%.+]] = mhlo.reshape [[SLICE3]]
  // CHECK-DAG: [[UPDATE1:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]
  // CHECK-DAG: [[UPDATE2:%.+]] = mhlo.dynamic_update_slice [[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]
  // CHECK-DAG: [[UPDATE3:%.+]] = mhlo.dynamic_update_slice [[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32>

  // CHECK:  return [[UPDATE3]] : tensor<8x8x4xf32>
  func.return %0 : tensor<8x8x4xf32>
}

// -----

// CHECK-LABEL: xla_dynamic_update_slice
func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]] : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x16xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32>
  func.return %0 : tensor<4x16xf32>
}

// -----

// CHECK-LABEL: xla_dynamic_update_slice2
func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]] : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32>
  func.return %0 : tensor<4xf32>
}

//===----------------------------------------------------------------------===//
// AllToAll op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @alltoall_basic
// See https://www.tensorflow.org/api_docs/python/tf/raw_ops/AllToAll
func.func @alltoall_basic(%input: tensor<1x2xf32>) -> tensor<2x1xf32> {
  %group_assignment = "tf.Const" () {
    value = dense<[[0, 1]]> : tensor<1x2xi32>
  } : () -> tensor<1x2xi32>
  %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 0 : i64, split_count = 2 : i64, split_dimension = 1 : i64} : (tensor<1x2xf32>, tensor<1x2xi32>) -> tensor<2x1xf32>
  // CHECK: mhlo.all_to_all
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  func.return %result : tensor<2x1xf32>
}

//===----------------------------------------------------------------------===//
// Cumsum op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @cumsum_static
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func.func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> {
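  // An inclusive cumsum over 4 elements becomes a sum reduce_window with
  // window size 4 and low padding 3: output element i sums inputs 0..i.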
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   mhlo.return [[SUM]] : tensor<f32>
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32>
  // CHECK: return [[CONVERT_REDUCE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}

// -----

// CHECK-LABEL: func @cumsum_exclusive
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func.func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> {
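  // The inclusive scan below is made exclusive by the trailing mhlo.pad:
  // low padding of 1 shifts the results right, and high padding of -1 drops
  // the last element.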
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   mhlo.return [[SUM]] : tensor<f32>
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32>
  // CHECK: return [[CONVERT_REDUCE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}

// -----

// CHECK-LABEL: func @cumsum_reverse
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func.func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   mhlo.return [[SUM]] : tensor<f32>
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32>
  // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[REVERSE_BACK]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}

// -----

// CHECK-LABEL: func @cumsum_exclusive_reverse
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func.func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   mhlo.return [[SUM]] : tensor<f32>
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32>
  // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[REVERSE_BACK]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}
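
// For reference, the two flags compose: tf.Cumsum on [1, 2, 3, 4] with
// exclusive = true and reverse = true yields [9, 7, 4, 0], i.e. the reversed
// input is summed, shifted right with the zero init, and reversed back.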

// -----

// CHECK-LABEL: func @cumsum_empty
func.func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> {
  %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>

  // CHECK: mhlo.constant dense<> : tensor<0xf32>
  %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor<i32>) -> tensor<0xf32>
  func.return %1 : tensor<0xf32>
}

// -----

// CHECK-LABEL: func @cumsum_dynamic
func.func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> {
  // CHECK: "tf.Cumsum"
  %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor<?xf32>, tensor<i32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// Cumprod op legalizations.
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @cumprod
func.func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ({
  // CHECK:   mhlo.mul
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}
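
// For reference, cumprod follows the same reduce_window recipe as cumsum, but
// with mhlo.mul as the reduction and the multiplicative identity 1.0 as the
// init value: tf.Cumprod on [1, 2, 3, 4] yields [1, 2, 6, 24].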

//===----------------------------------------------------------------------===//
// Qr op legalization
//===----------------------------------------------------------------------===//

// CHECK:  func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
func.func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) {
  // The tf.Qr lowering expands into a full algorithm that is not practical to
  // verify with FileCheck. Just verify that it converted.
  // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is
  // really only applicable to certain legacy uses.
  // CHECK-NOT: "tf.Qr"
  %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
  func.return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}

//===----------------------------------------------------------------------===//
// tf.Softplus legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @softplus_f16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
func.func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]]
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]]
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]]
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]]
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf16>
  func.return %0 : tensor<8x16xf16>
}
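
// For reference, softplus(x) = log(1 + exp(x)). The checks above encode its
// numerically stable evaluation: with threshold = log(epsilon) + 2 for an
// element-type-dependent epsilon constant, the lowering selects x itself when
// x > -threshold (there log1p(exp(x)) rounds to x), exp(x) when x < threshold
// (there log1p(exp(x)) rounds to exp(x)), and log1p(exp(x)) otherwise. The
// bf16/f32/f64 variants below differ only in the epsilon constant.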

// -----

// CHECK-LABEL: func @softplus_bf16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
func.func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]]
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]]
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]]
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]]
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xbf16>
  func.return %0 : tensor<8x16xbf16>
}

// -----

// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func.func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]]
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]]
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]]
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]]
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf32>
  func.return %0 : tensor<8x16xf32>
}

// -----

// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func.func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]]
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]]
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]]
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]]
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf64>
  func.return %0 : tensor<8x16xf64>
}

// -----

// CHECK-LABEL: @xla_gather
func.func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME:   offset_dims = [0, 1]
  // CHECK-SAME:   collapsed_slice_dims = [0]
  // CHECK-SAME:   start_index_map = [0, 1]
  // CHECK-SAME:   index_vector_dim = 1
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32>
  func.return %0 : tensor<1x300x10xf32>
}
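
// The dimension_numbers string above is a serialized GatherDimensionNumbers
// proto; as the CHECK-SAME lines confirm, it decodes to offset_dims = [0, 1],
// collapsed_slice_dims = [0], start_index_map = [0, 1], index_vector_dim = 1.
// Each 2-element row of the 10x2 index tensor thus selects a 1x1x300 slice
// whose collapsed 1x300 shape lands in output dims 0 and 1, with the ten
// gathered slices stacked along the trailing output dimension.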

// -----

// CHECK-LABEL: @xla_gather_i32
func.func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME:   offset_dims = [0, 1]
  // CHECK-SAME:   collapsed_slice_dims = [0]
  // CHECK-SAME:   start_index_map = [0, 1]
  // CHECK-SAME:   index_vector_dim = 1
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<1x300x10xf32>
  func.return %0 : tensor<1x300x10xf32>
}

// CHECK: func @stridedslice_with_i32
func.func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} {
  // CHECK-NOT: tf.StridedSlice
  // CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic_slice
  // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[DYNSLICE]]
  // CHECK: return [[RESHAPE]]
  %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf_type.shape<>], device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf_type.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32>
  func.return %6 : tensor<4xf32>
}
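
// For reference, with shrink_axis_mask = 1 this strided slice extracts a
// single row of the 2x4 constant at a runtime-computed offset, which is why
// it legalizes to mhlo.dynamic_slice (a 1x4 slice) followed by mhlo.reshape
// to drop the leading unit dimension.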

// CHECK-LABEL: func @replica_id
func.func @replica_id() -> tensor<i32> {
  // CHECK: %[[ID:.*]] = mhlo.replica_id : tensor<ui32>
  // CHECK: %[[RESULT:.*]] = mhlo.convert(%[[ID]]) : (tensor<ui32>) -> tensor<i32>
  %0 = "tf.XlaReplicaId"() : () -> tensor<i32>
  func.return %0 : tensor<i32>
}

// CHECK: func @angle_c64
// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>)
func.func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> {
  // CHECK: [[IMAG:%.*]] = mhlo.imag([[ARG0]])
  // CHECK: [[REAL:%.*]] = mhlo.real([[ARG0]])
  // CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]]
  %0 = "tf.Angle"(%arg0): (tensor<complex<f32>>) -> tensor<f32>
  func.return %0 : tensor<f32>
}
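
// For reference, tf.Angle computes arg(z) = atan2(imag(z), real(z)); e.g. for
// z = 1 + 1i the result is pi/4.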

//===----------------------------------------------------------------------===//
// tf.ApproximateEqual legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @approximateequal_f64
func.func @approximateequal_f64(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) -> tensor<?xi1> {
  // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor<?xf64>
  // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor<?xf64>
  // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK: %[[CONVERT:.*]] = mhlo.convert(%[[CST]]) : (tensor<f32>) -> tensor<f64>
  // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?xf64>, tensor<f64>) -> tensor<?xi1>
  // CHECK: return %[[LE]] : tensor<?xi1>
  %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xi1>
  func.return %equal : tensor<?xi1>
}

// CHECK-LABEL: func @approximateequal_i32
func.func @approximateequal_i32(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> {
  // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor<?xi32>
  // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor<?xi32>
  // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK: %[[CONVERT:.*]] = mhlo.convert(%[[CST]]) : (tensor<f32>) -> tensor<i32>
  // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi1>
  // CHECK: return %[[LE]] : tensor<?xi1>
  %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
  func.return %equal : tensor<?xi1>
}

// CHECK-LABEL: func @approximateequal_complex64
func.func @approximateequal_complex64(%arg0: tensor<?xcomplex<f32>>, %arg1: tensor<?xcomplex<f32>>) -> tensor<?xi1> {
  // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor<?xcomplex<f32>>
  // CHECK: %[[ABS:.*]] = mhlo.abs(%[[SUB]]) : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
  // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : tensor<f32>
  // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
  // CHECK: return %[[LE]] : tensor<?xi1>
  %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor<?xcomplex<f32>>, tensor<?xcomplex<f32>>) -> tensor<?xi1>
  func.return %equal : tensor<?xi1>
}
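
// For reference, tf.ApproximateEqual checks |x - y| < tolerance elementwise.
// The f32 tolerance attribute is materialized as a constant and converted to
// the operand element type; for complex operands mhlo.abs already returns the
// real-valued magnitude, so the comparison happens directly in f32.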

//===----------------------------------------------------------------------===//
// tf.XlaConvV2 legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: xla_conv_v2
func.func @xla_conv_v2(%lhs: tensor<8x4x16x16x16xf32>, %rhs: tensor<4x3x3x16x16xf32>) -> (tensor<4x4x14x14x16xf32>) {
  %feature_group_count = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
  %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  // CHECK: mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], window = {stride = [3, 1, 1], pad = {{\[\[}}0, 0], {{\[}}0, 0], {{\[}}0, 0]], lhs_dilate = [4, 1, 1], rhs_dilate = [1, 1, 1]} {batch_group_count = 2 : i64, feature_group_count = 1 : i64, precision_config = []} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>) -> tensor<4x4x14x14x16xf32>
  %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {batch_group_count = 2 : i64, dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<4x4x14x14x16xf32>
  func.return %0 : tensor<4x4x14x14x16xf32>
}
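
// Shape arithmetic for the convolution above: the lhs spatial dims (4, 16, 16)
// are lhs-dilated by (4, 1, 1) to (13, 16, 16); with kernel spatial dims
// (4, 3, 3), zero padding, and window strides (3, 1, 1) the output spatial
// dims are ((13 - 4) / 3 + 1, 16 - 3 + 1, 16 - 3 + 1) = (4, 14, 14), while
// batch_group_count = 2 splits the batch of 8 into an output batch of 4.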

//===----------------------------------------------------------------------===//
// tf.XlaDot legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @xladot_matmul(
// CHECK-SAME:    %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32>
func.func @xladot_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> {
  // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {
  // CHECK-SAME:  dot_dimension_numbers = #mhlo.dot<
  // CHECK-NOT:     lhs_batching_dimensions =
  // CHECK-NOT:     rhs_batching_dimensions =
  // CHECK-SAME:    lhs_contracting_dimensions = [1]
  // CHECK-SAME:    rhs_contracting_dimensions = [0]
  // CHECK-SAME:  precision_config = []
  %res = "tf.XlaDot"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32>
  func.return %res : tensor<64x16xi32>
}
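
// The dimension_numbers string above is a serialized DotDimensionNumbers
// proto; as checked, it decodes to lhs_contracting_dimensions = [1] and
// rhs_contracting_dimensions = [0] with no batching dimensions, i.e. a plain
// matrix multiply: 64x32 times 32x16 gives 64x16, accumulated in i32.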

//===----------------------------------------------------------------------===//
// tf.XlaDotV2 legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @xladotv2_matmul(
// CHECK-SAME:    %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32>
func.func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> {
  // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {
  // CHECK-SAME:  dot_dimension_numbers = #mhlo.dot<
  // CHECK-NOT:     lhs_batching_dimensions =
  // CHECK-NOT:     rhs_batching_dimensions =
  // CHECK-SAME:    lhs_contracting_dimensions = [1]
  // CHECK-SAME:    rhs_contracting_dimensions = [0]
  // CHECK-SAME:  precision_config = []
  %res = "tf.XlaDotV2"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32>
  func.return %res : tensor<64x16xi32>
}

//===----------------------------------------------------------------------===//
// tf.XlaDynamicSlice legalization
//===----------------------------------------------------------------------===//
// -----

// CHECK-LABEL: xla_dynamic_slice_constant_start
func.func @xla_dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i64>
  // CHECK-NEXT: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]])
  // CHECK-SAME: {slice_sizes = dense<2> : tensor<1xi64>} :
  // CHECK-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
  // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32>
  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}
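
// For reference, with starts = [1] and sizes = [2] on a 4-element operand the
// slice selects elements 1..2, e.g. [a, b, c, d] -> [b, c]; because the starts
// are constant they fold to the scalar start index operand of
// mhlo.dynamic_slice.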

// -----

// CHECK-LABEL: xla_dynamic_slice_i32_consts
func.func @xla_dynamic_slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i32>
  // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i32>) -> tensor<2xi32>
  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}

// -----

// CHECK-LABEL: xla_dynamic_slice_constant_start_dynamic_shape
func.func @xla_dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
  // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor<i64>
  // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"
  // CHECK-SAME: (%arg0, %[[START1]], %[[START2]])
  // CHECK-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} :
  // CHECK-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
  %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
  func.return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: xla_dynamic_slice_variable_start
func.func @xla_dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
  // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1)
  // CHECK-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1)
  // CHECK-SAME: {limit_indices = dense<2> : tensor<1xi64>,
  // CHECK-SAME: start_indices = dense<1> : tensor<1xi64>,
  // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
  func.return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: xla_dynamic_slice_mhlo_sizes
func.func @xla_dynamic_slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> {
  // CHECK-NOT: "tf.XlaDynamicSlice"
  %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
  %1 = "tf.XlaDynamicSlice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32>
  func.return %1 : tensor<1x512x4xf32>
}

//===----------------------------------------------------------------------===//
// tf.XlaEinsum legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func @xlaeinsum
func.func @xlaeinsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
  // CHECK-NEXT:  mhlo.einsum
  %0 = "tf.XlaEinsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
  func.return %0: tensor<2x4xf32>
}

//===----------------------------------------------------------------------===//
// tf.XlaReduceWindow legalization
//===----------------------------------------------------------------------===//
// -----
// CHECK-LABEL: @test_xla_reduce_window
func.func @test_xla_reduce_window(%arg0: tensor<7xf32>, %arg1: tensor<f32>) -> tensor<10xf32> {
  %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
  %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
  %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
  %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.reduce_window"(%arg0, %arg1) ({
  // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
  // CHECK-NEXT:   %[[SUM:.*]] = func.call @sum_reducer3(%[[ARG0]], %[[ARG1]]){{.*}}
  // CHECK-NEXT:   mhlo.return %[[SUM]] : tensor<*xf32>
  // CHECK-NEXT: }) {base_dilations = dense<3> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<4> : tensor<1xi64>, window_dimensions = dense<1> : tensor<1xi64>, window_strides = dense<2> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<10xf32>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor<f32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32>
  func.return %0 : tensor<10xf32>
}

func.func private @sum_reducer3(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  func.return %0 : tensor<*xf32>
}
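
// Shape arithmetic for the reduce_window above: base_dilations = 3 expands
// the 7-element input to (7 - 1) * 3 + 1 = 19 elements, the dilated window
// spans (1 - 1) * 4 + 1 = 1 element, and with window stride 2 and no padding
// the output has (19 - 1) / 2 + 1 = 10 elements, matching tensor<10xf32>.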

//===----------------------------------------------------------------------===//
// tf.XlaSort legalization
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @xlasort_int
// CHECK-SAME: %[[INPUT:.*]]: tensor<16xi32>
func.func @xlasort_int(%input: tensor<16xi32>) -> (tensor<16xi32>) {
  // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({
  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<i32>, %[[RHS:.*]]: tensor<i32>)
  // CHECK-NEXT:   %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], NOTYPE
  // CHECK-NEXT:   mhlo.return %[[CMP]]
  // CHECK-NEXT: }) {dimension = -1 : i64, is_stable = false} : (tensor<16xi32>) -> tensor<16xi32>
  // CHECK-NEXT: return %[[SORT]]
  %output = "tf.XlaSort"(%input) : (tensor<16xi32>) -> (tensor<16xi32>)
  func.return %output : tensor<16xi32>
}

// -----

// CHECK-LABEL: @xlasort_float
// CHECK-SAME: %[[INPUT:.*]]: tensor<8xf64>
func.func @xlasort_float(%input: tensor<8xf64>) -> (tensor<8xf64>) {
  // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({
  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f64>, %[[RHS:.*]]: tensor<f64>)
  // CHECK-NEXT:   %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], TOTALORDER
  // CHECK-NEXT:   mhlo.return %[[CMP]]
  // CHECK-NEXT: }) {dimension = -1 : i64, is_stable = false} : (tensor<8xf64>) -> tensor<8xf64>
  // CHECK-NEXT: return %[[SORT]]
  %output = "tf.XlaSort"(%input) : (tensor<8xf64>) -> (tensor<8xf64>)
  func.return %output : tensor<8xf64>
}

// -----

// CHECK-LABEL: @xlasort_const
func.func @xlasort_const() -> (tensor<2x3xi64>) {
  // CHECK: [2, 4, 3], [6, 5, 1]
  %input = "tf.Const"() {value = dense<[[2, 4, 3], [6, 5, 1]]> : tensor<2x3xi64>} : () -> (tensor<2x3xi64>)
  // CHECK-NEXT: [2, 3, 4], [1, 5, 6]
  %output = "tf.XlaSort"(%input): (tensor<2x3xi64>) -> (tensor<2x3xi64>)
  func.return %output : tensor<2x3xi64>
}

//===----------------------------------------------------------------------===//
// tf.XlaRngBitGenerator legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @xla_rng_bit_generator
// CHECK-SAME: %[[STATE:.*]]: tensor<2xui64>
func.func @xla_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0,_retval1"}} {
  // CHECK-NEXT: %0 = mhlo.constant dense<[10, 12]> : tensor<2xi32>
  %cst = "tf.Const"() {value = dense<[10, 12]> : tensor<2xi32>} : () -> tensor<2xi32>
  // CHECK-NEXT: %1 = mhlo.constant dense<3> : tensor<i32>
  %cst_0 = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
  // CHECK-NEXT: %[[OUTPUT_STATE:.*]], %[[OUTPUT:.*]] = "mhlo.rng_bit_generator"(%[[STATE]]) {rng_algorithm = #mhlo.rng_algorithm<DEFAULT>} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>)
  // CHECK-NEXT: return %[[OUTPUT_STATE]], %[[OUTPUT]] : tensor<2xui64>, tensor<10x12xui32>
  %output_key, %output = "tf.XlaRngBitGenerator"(%cst_0, %arg0, %cst) : (tensor<i32>, tensor<2xui64>, tensor<2xi32>) -> (tensor<2xui64>, tensor<10x12xui32>)
  func.return %output_key, %output : tensor<2xui64>, tensor<10x12xui32>
}

//===----------------------------------------------------------------------===//
// tf.XlaVariadicReduceV2 legalization
//===----------------------------------------------------------------------===//

// -----
// CHECK-LABEL: @xla_variadic_reduce_v2
func.func @xla_variadic_reduce_v2(%arg0: tensor<2x3xcomplex<f64>>, %arg1: tensor<complex<f64>>) -> tensor<3xcomplex<f64>> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} {
  // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1)
  // CHECK-SAME: dimensions = [0]
  // CHECK-NEXT: (%[[ARG0:.*]]: tensor<complex<f64>>, %[[ARG1:.*]]: tensor<complex<f64>>)
  // CHECK-NEXT:   %[[SUM:.*]] = func.call @sum_reducer(%[[ARG0]], %[[ARG1]]){{.*}}
  // CHECK-NEXT:   mhlo.return %[[SUM]] : tensor<complex<f64>>
  // CHECK: return %[[REDUCE]]
  %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operand_segment_sizes = array<i32: 1, 1>, reducer = @sum_reducer} : (tensor<2x3xcomplex<f64>>, tensor<complex<f64>>) -> tensor<3xcomplex<f64>>
  func.return %0 : tensor<3xcomplex<f64>>
}

func.func private @sum_reducer(%arg0: tensor<complex<f64>>, %arg1: tensor<complex<f64>>) -> tensor<complex<f64>> {
  %0 = "tf.AddV2"(%arg1, %arg0) : (tensor<complex<f64>>, tensor<complex<f64>>) -> tensor<complex<f64>>
  func.return %0 : tensor<complex<f64>>
}
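
// For reference, reducing the 2x3 operand over dimensions = [0] folds the two
// rows into a single 3-element result; the reducer is kept as a separate
// function and invoked from the mhlo.reduce region through func.call.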

// -----

// CHECK-LABEL: @xla_variadic_reduce_v2_dynamic
func.func @xla_variadic_reduce_v2_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} {
  // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1)
  // CHECK-SAME: dimensions = [0]
  // CHECK-NEXT: (%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<i32>)
  // CHECK-NEXT:   %[[SUM:.*]] = func.call @sum_reducer2(%[[ARG0]], %[[ARG1]]){{.*}}
  // CHECK-NEXT:   mhlo.return %[[SUM]] : tensor<i32>
  // CHECK: return %[[REDUCE]]
  %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operand_segment_sizes = array<i32: 1, 1>, reducer = @sum_reducer2} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  func.return %0 : tensor<*xi32>
}

func.func private @sum_reducer2(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %0 = "tf.AddV2"(%arg1, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  func.return %0 : tensor<i32>
}

//===----------------------------------------------------------------------===//
// tf.XlaVariadicSort legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @xla_variadic_sort
// CHECK-SAME: %[[INPUT:.*]]: tensor<2x3x4xui8>
func.func @xla_variadic_sort(%arg0: tensor<2x3x4xui8>) -> tensor<2x3x4xui8> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} {
  // CHECK-NEXT: {{.*}} = mhlo.constant dense<0> : tensor<i32>
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({
  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<ui8>, %[[RHS:.*]]: tensor<ui8>)
  // CHECK-NEXT:   %[[CMP:.*]] = func.call @compare_lt(%[[LHS]], %[[RHS]]) : (tensor<ui8>, tensor<ui8>) -> tensor<i1>
  // CHECK-NEXT:   mhlo.return %[[CMP]]
  // CHECK-NEXT: }) {dimension = 0 : i64, is_stable = false} : (tensor<2x3x4xui8>) -> tensor<2x3x4xui8>
  // CHECK-NEXT: return %[[SORT]]
  %0 = "tf.XlaVariadicSort"(%arg0, %cst) {_XlaHasReferenceVars = false, comparator = @compare_lt, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", is_stable = false} : (tensor<2x3x4xui8>, tensor<i32>) -> tensor<2x3x4xui8>
  func.return %0 : tensor<2x3x4xui8>
}

func.func private @compare_lt(%arg0: tensor<ui8>, %arg1: tensor<ui8>) -> tensor<i1> attributes {tf._disable_call_shape_inference = true} {
  %0 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor<ui8>, tensor<ui8>) -> tensor<i1>
  func.return %0 : tensor<i1>
}

//===----------------------------------------------------------------------===//
// tf.NextAfter legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @nextafter
func.func @nextafter(%arg0: tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> {
  // CHECK-NEXT:  %0 = chlo.broadcast_next_after %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  // CHECK-NEXT:  return %0 : tensor<2xf32>
  %0 = "tf.NextAfter"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  func.return %0: tensor<2xf32>
}

//===----------------------------------------------------------------------===//
// tf.XlaReduceScatter legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @xla_reduce_scatter
func.func @xla_reduce_scatter(%arg0: tensor<128x128xf32>) -> tensor<64x128xf32> {
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %cst_0 = "tf.Const"() {value = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
  // CHECK:          "mhlo.reduce_scatter"(%arg0)
  // CHECK{LITERAL}: replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]>
  // CHECK-SAME:     scatter_dimension = 0
  //
  %1 = "tf.XlaReduceScatter"(%arg0, %cst_0, %cst) {reduce_op = "Add"} : (tensor<128x128xf32>, tensor<4x2xi32>, tensor<i32>) -> tensor<64x128xf32>
  func.return %1 : tensor<64x128xf32>
}
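
// For reference, each replica group above ([0, 4], [1, 5], [2, 6], [3, 7])
// has two members, so reduce-scattering over dimension 0 leaves each member
// with 128 / 2 = 64 rows, matching tensor<64x128xf32>.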

//===----------------------------------------------------------------------===//
// tf.XlaSelectAndScatter legalization
//===----------------------------------------------------------------------===//
func.func @test_xla_select_and_scatter(%arg0: tensor<4x5x1x1xbf16>, %arg1: tensor<2x2x1x1xbf16>, %arg2: tensor<bf16>) -> tensor<?x?x?x?xbf16> {
  %cst = "tf.Const"() {value = dense<0> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
  %cst_0 = "tf.Const"() {value = dense<[2, 2, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
  %cst_1 = "tf.Const"() {value = dense<[2, 3, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
  // CHECK: %[[SELECT_AND_SCATTER:.*]] = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
  // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor<*xbf16>, %[[ARG1:.*]]: tensor<*xbf16>)
  // CHECK-NEXT:   %[[RES:.*]] = func.call @ge_select(%[[ARG0]], %[[ARG1]]){{.*}}
  // CHECK-NEXT:   mhlo.return %[[RES]] : tensor<*xi1>
  // CHECK-NEXT: },  {
  // CHECK-NEXT: ^{{.*}}(%[[ARG2:.*]]: tensor<*xbf16>, %[[ARG3:.*]]: tensor<*xbf16>)
  // CHECK-NEXT:   %[[RES:.*]] = func.call @add_scatter(%[[ARG2]], %[[ARG3]]){{.*}}
  // CHECK-NEXT:   mhlo.return %[[RES]] : tensor<*xbf16>
  // CHECK-NEXT: }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>} : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor<bf16>) -> tensor<?x?x?x?xbf16>
  // CHECK-NEXT: return %[[SELECT_AND_SCATTER]]
  %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor<bf16>) -> tensor<?x?x?x?xbf16>
  func.return %0 : tensor<?x?x?x?xbf16>
}

func.func private @add_scatter(%arg0: tensor<*xbf16>, %arg1: tensor<*xbf16>) -> tensor<*xbf16> {
  %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xbf16>
  func.return %0 : tensor<*xbf16>
}

func.func private @ge_select(%arg0: tensor<*xbf16>, %arg1: tensor<*xbf16>) -> tensor<*xi1> {
  %0 = "tf.GreaterEqual"(%arg0, %arg1) {device = ""} : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xi1>
  func.return %0 : tensor<*xi1>
}
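
// Window arithmetic for the select-and-scatter above: a 2x3 window with
// strides (2, 2) over the 4x5 input yields a 2x2 grid of windows, matching
// the 2x2x1x1 source operand; @ge_select picks the element of each window
// that receives the scattered value and @add_scatter combines collisions.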

//===----------------------------------------------------------------------===//
// tf.XlaOptimizationBarrier legalization
//===----------------------------------------------------------------------===//

func.func @test_xla_optimization_barrier(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) {
  // CHECK: %[[OPT_BARRIER:.*]]:2 = mhlo.optimization_barrier %arg0, %arg1
  // CHECK-NEXT: return %[[OPT_BARRIER]]#0, %[[OPT_BARRIER]]#1
  %0, %1 = "tf.XlaOptimizationBarrier"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>)
  func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xi32>
}
