1// RUN: xla-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" -split-input-file %s | FILECHECK_OPTS="" FileCheck %s 2// RUN: xla-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all 3// This test runs twice: 4// 1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled since verifying 5// that the chlo ops emit produces more useful tests. 6// 2. With chlo legalization enabled, verifying diagnostics to pick up any 7// issues with the full lowering (can catch some broadcasting corner 8// cases which emit with a warning). 9 10//===----------------------------------------------------------------------===// 11// BatchNorm op legalizations. 12//===----------------------------------------------------------------------===// 13 14// ----- 15 16// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same 17// code), so only do a couple of basic checks. 18 19// CHECK-LABEL: fusedBatchNormV2_noTraining 20func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 21 // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> 22 %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 23 func.return %0#0 : tensor<8x8x8x8xf32> 24} 25 26// ----- 27 28// CHECK-LABEL: fusedBatchNormV2_training 29func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 30 // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 31 %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 32 // CHECK: mhlo.constant 33 // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> 34 func.return %0#0 : tensor<8x8x8x8xf32> 35} 36 37// ----- 38 39// CHECK-LABEL: fusedBatchNormV3_noTraining 40func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 41 // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> 42 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 43 func.return %0#0 : tensor<8x8x8x8xf32> 44} 45 46// ----- 47 48// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision 49// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>) 50func.func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) { 51 // CHECK: [[CONVERT_X:%.*]] = mhlo.convert([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> 52 // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} 53 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) 54 // CHECK: [[Y_CONVERT:%.*]] = mhlo.convert([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> 55 // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32> 56 // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32> 57 // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]] 58 func.return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32> 59} 60 61// ----- 62 63// CHECK-LABEL: fusedBatchNormV3_training 64func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 65 // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 66 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 67 // CHECK: mhlo.constant 68 // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> 69 func.return %0#0 : tensor<8x8x8x8xf32> 70} 71 72// ----- 73 74// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance 75func.func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> { 76 // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 77 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 78 // CHECK: return %[[VAR]] 79 func.return %0#4 : tensor<8xf32> 80} 81 82// ----- 83 84// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor 85func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { 86 // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 87 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 88 // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694> 89 // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[VAR]], %[[FACTOR]] 90 91 // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988> 92 // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01> 93 94 // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3 95 // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[MEAN]] 96 // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] 97 98 // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4 99 // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] 100 // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] 101 102 // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[MEAN]], %[[VAR]] 103 func.return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> 104} 105 106// ----- 107 108// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision 109func.func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { 110 // CHECK: mhlo.convert(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> 111 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 112 // CHECK: mhlo.convert({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> 113 func.return %0#0 : tensor<8x8x8x8xbf16> 114} 115 116// ----- 117 118// CHECK-LABEL: fusedBatchNormV3_NCHW 119func.func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 120 // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 121 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 122 func.return %0#0 : tensor<8x8x8x8xf32> 123} 124 125// ----- 126 127// CHECK-LABEL: fusedBatchNormV3_NDHWC 128func.func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>) { 129 // CHECK: feature_index = 4 : i64 130 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NDHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 131 func.return %0#0 : tensor<8x8x8x8x8xf32> 132} 133 134// ----- 135 136// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported 137func.func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) { 138 // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32> 139 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) 140 func.return %0#0 : tensor<?x?x?x?xf32> 141} 142 143// ----- 144 145// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1 146func.func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) { 147 // CHECK: tf.FusedBatchNormV3 148 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) 149 func.return %0#0 : tensor<?x?x?x?xf32> 150} 151 152// ----- 153 154// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2 155func.func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor<?x6x?x?xf32>, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor<?x6x?x?xf32>) { 156 // CHECK: tf.FusedBatchNormV3 157 %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) 158 func.return %0#0 : tensor<?x6x?x?xf32> 159} 160 161// ----- 162 163// CHECK-LABEL: fusedBatchNormGrad_noTraining 164func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 165 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 166 // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> 167 // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32> 168 169 // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> 170 // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> 171 172 // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 173 // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> 174 // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> 175 // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> 176 // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32> 177 // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 178 // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 179 // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> 180 181 // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> 182 // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 183 // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> 184 // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> 185 186 // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> 187 // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32> 188 // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 189 // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 190 // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32> 191 192 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32> 193 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> 194 195 %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 196 func.return %0#0 : tensor<8x8x8x8xf32> 197} 198 199// ----- 200 201// CHECK-LABEL: fusedBatchNormGrad_Training 202func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 203 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 204 // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> 205 // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 206 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> 207 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> 208 209 %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 210 func.return %0#0 : tensor<8x8x8x8xf32> 211} 212 213// ----- 214 215// CHECK-LABEL: fusedBatchNormGradV2_noTraining 216func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 217 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 218 // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> 219 // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32> 220 221 // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> 222 // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> 223 224 // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 225 // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> 226 // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> 227 // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> 228 // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32> 229 // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 230 // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 231 // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> 232 233 // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> 234 // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 235 // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> 236 237 // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> 238 239 // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> 240 // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32> 241 // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 242 // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 243 // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32> 244 245 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32> 246 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> 247 248 %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 249 func.return %0#0 : tensor<8x8x8x8xf32> 250} 251 252// ----- 253 254// CHECK-LABEL: fusedBatchNormGradV2_Training 255func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 256 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 257 // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> 258 // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 259 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> 260 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> 261 262 %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 263 func.return %0#0 : tensor<8x8x8x8xf32> 264} 265 266// ----- 267 268// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision 269func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { 270 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 271 // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> 272 273 // CHECK: %[[x_backprop:.*]] = mhlo.convert({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> 274 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> 275 276 %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 277 func.return %0#0 : tensor<8x8x8x8xbf16> 278} 279 280// ----- 281 282// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision 283func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { 284 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 285 // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> 286 // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 287 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert(%[[grad_operand]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> 288 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> 289 290 %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 291 func.return %0#0 : tensor<8x8x8x8xbf16> 292} 293 294// ----- 295 296// CHECK-LABEL: fusedBatchNormGradV3_noTraining 297func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 298 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 299 // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> 300 // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32> 301 302 // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> 303 // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> 304 305 // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 306 // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> 307 // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> 308 // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> 309 // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32> 310 // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 311 // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 312 // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> 313 314 // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> 315 // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 316 // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> 317 318 // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> 319 320 // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> 321 // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32> 322 // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 323 // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 324 // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32> 325 326 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32> 327 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> 328 329 %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 330 func.return %0#0 : tensor<8x8x8x8xf32> 331} 332 333// ----- 334 335// CHECK-LABEL: fusedBatchNormGradV3_Training 336func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { 337 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 338 // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> 339 // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 340 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> 341 // CHECK: return %[[x_backprop]] 342 // CHECK-SAME: tensor<8x8x8x8xf32> 343 344 %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>) 345 func.return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> 346} 347 348// ----- 349 350// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision 351func.func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { 352 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 353 // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> 354 355 // CHECK: %[[x_backprop:.*]] = mhlo.convert({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> 356 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> 357 358 %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 359 func.return %0#0 : tensor<8x8x8x8xbf16> 360} 361 362// ----- 363 364// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision 365func.func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { 366 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 367 // CHECK-NEXT: %[[act:.*]] = mhlo.convert(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> 368 // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 369 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert(%[[grad_operand]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> 370 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16> 371 372 %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 373 func.return %0#0 : tensor<8x8x8x8xbf16> 374} 375 376// ----- 377 378// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW 379func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 380 // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> 381 // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> 382 // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32> 383 384 // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> 385 // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> 386 387 // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 388 // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> 389 // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> 390 // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64> 391 // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32> 392 // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 393 // CHECK-NEXT: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[init]]) applies mhlo.add across dimensions = [0, 2, 3] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 394 // CHECK-NEXT: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> 395 396 // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> 397 // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 398 // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> 399 400 // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> 401 402 // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64> 403 // CHECK-NEXT: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32> 404 // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 405 // CHECK-NEXT: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[init2]]) applies mhlo.add across dimensions = [0, 2, 3] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 406 // CHECK-NEXT: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32> 407 408 // CHECK-NEXT: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32> 409 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> 410 411 %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 412 func.return %0#0 : tensor<8x8x8x8xf32> 413} 414 415// ----- 416 417// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW 418func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 419 // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) 420 %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 421 func.return %0#0 : tensor<8x8x8x8xf32> 422} 423 424//===----------------------------------------------------------------------===// 425// Bias op legalizations. 426//===----------------------------------------------------------------------===// 427 428// ----- 429 430// CHECK-LABEL: func @biasAdd_default 431func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { 432 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 433 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 434 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 435 // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} 436 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 437 %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> 438 func.return %0 : tensor<1x32x10x32xi32> 439} 440 441// ----- 442 443// CHECK-LABEL: func @biasAdd_NHWC 444func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { 445 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 446 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 447 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 448 // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} 449 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 450 %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> 451 func.return %0 : tensor<1x32x10x32xi32> 452} 453 454// ----- 455 456// CHECK-LABEL: func @biasAdd_NCHW 457func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { 458 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 459 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 460 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 461 // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} 462 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 463 %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> 464 func.return %0 : tensor<1x32x10x32xi32> 465} 466 467// ----- 468 469// CHECK-LABEL: func @biasAdd_dynamic 470func.func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> { 471 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 472 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 473 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 474 // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} 475 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 476 %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32> 477 func.return %0 : tensor<?x?x?x?xi32> 478} 479 480// ----- 481 482// CHECK-LABEL: func @biasAdd_partial_dynamic 483func.func @biasAdd_partial_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<512xi32>) -> tensor<?x?x?x512xi32> { 484 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 485 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 486 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 487 // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} 488 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 489 // CHECK: %[[CAST:.+]] = tensor.cast %[[RESULT]] : tensor<?x?x?x?xi32> to tensor<?x?x?x512xi32> 490 // CHECK: return %[[CAST]] : tensor<?x?x?x512xi32> 491 %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<?x?x?x?xi32>, tensor<512xi32>) -> tensor<?x?x?x512xi32> 492 func.return %0 : tensor<?x?x?x512xi32> 493} 494 495 496//===----------------------------------------------------------------------===// 497// ClipByValue 498//===----------------------------------------------------------------------===// 499 500// ----- 501 502// CHECK-LABEL: @clip 503func.func @clip(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> { 504 // CHECK: [[VAL:%.+]] = mhlo.clamp %arg1, %arg0, %arg2 505 506 %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32> 507 // CHECK: return [[VAL]] 508 func.return %0 : tensor<f32> 509} 510 511// ----- 512 513// CHECK-LABEL: @clip_dynamic 514func.func @clip_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> { 515 // CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp %arg1, %arg0, %arg2 516 %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 517 518 // CHECK: return [[CLAMP]] 519 func.return %0 : tensor<?xf32> 520} 521 522// ----- 523 524// CHECK-LABEL: @clip_static_broadcast 525func.func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<5xf32> { 526 // CHECK-DAG: [[SHPIDX:%.+]] = mhlo.constant dense<5> 527 // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} 528 // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} 529 // CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] 530 %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor<f32>, tensor<f32>) -> tensor<5xf32> 531 532 // CHECK: return [[CLAMP]] 533 func.return %0 : tensor<5xf32> 534} 535 536 537// CHECK-LABEL: @clip_dynamic_broadcast 538func.func @clip_dynamic_broadcast(%arg0 : tensor<?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?xf32> { 539 // CHECK: [[SHP:%.+]] = shape.shape_of %arg0 540 // CHECK: [[SHPIDX:%.+]] = arith.index_cast [[SHP]] : tensor<1xindex> to tensor<1xi32> 541 // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} 542 // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} 543 // CHECK-DAG: [[CLAMP:%.+]] = mhlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] 544 %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> 545 546 // CHECK: return [[CLAMP]] 547 func.return %0 : tensor<?xf32> 548} 549 550//===----------------------------------------------------------------------===// 551// DiagPart 552//===----------------------------------------------------------------------===// 553 554// ----- 555 556// CHECK-LABEL: func @diag_part 557// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> 558func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { 559 // CHECK: %[[RS:.*]] = mhlo.reshape %[[ARG]] : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> 560 // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32> 561 // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32> 562 // CHECK-DAG: %[[COMP:.*]] = mhlo.compare EQ, %[[IOTA0]], %[[IOTA1]], NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> 563 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 564 // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor<f32>) -> tensor<12x12xf32> 565 // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32> 566 // CHECK-DAG: %[[RED:.*]] = mhlo.reduce(%[[SEL]] init: %[[ZERO]]) applies mhlo.add across dimensions = [0] : (tensor<12x12xf32>, tensor<f32>) -> tensor<12xf32> 567 // CHECK-DAG: %[[RES:.*]] = mhlo.reshape %[[RED]] : (tensor<12xf32>) -> tensor<4x3xf32> 568 // CHECK-DAG: return %[[RES]] : tensor<4x3xf32> 569 %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32> 570 func.return %0: tensor<4x3xf32> 571} 572 573//===----------------------------------------------------------------------===// 574// MatrixDiagPart 575//===----------------------------------------------------------------------===// 576 577// ----- 578 579// CHECK-LABEL: func @matrix_diag_part 580// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32> 581func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { 582 // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32> 583 // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32> 584 // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32> 585 // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32> 586 // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32> 587 // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 588 // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense<false> : tensor<i1> 589 // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1> 590 // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense<true> : tensor<i1> 591 // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1> 592 // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor<i32> 593 // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 594 // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor<i32> 595 // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 596 // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor<i32> 597 // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 598 // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor<i32> 599 // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 600 // CHECK-DAG: %[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32> 601 // CHECK-DAG: %[[V19:.*]] = mhlo.negate %[[V18]] : tensor<1x22x128xi32> 602 // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32> 603 // CHECK-DAG: %[[V21:.*]] = mhlo.add %[[V13]], %[[V20]] : tensor<1x22x128xi32> 604 // CHECK-DAG: %[[V22:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32> 605 // CHECK-DAG: %[[V23:.*]] = mhlo.subtract %[[V15]], %[[V22]] : tensor<1x22x128xi32> 606 // CHECK-DAG: %[[V24:.*]] = mhlo.minimum %[[V21]], %[[V23]] : tensor<1x22x128xi32> 607 // CHECK-DAG: %[[V25:.*]] = chlo.broadcast_compare %[[V18]], %[[V5]] {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> 608 // CHECK-DAG: %[[V26:.*]] = mhlo.subtract %[[V17]], %[[V24]] : tensor<1x22x128xi32> 609 // CHECK-DAG: %[[V27:.*]] = "mhlo.select"(%[[V25]], %[[V26]], %[[V5]]) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> 610 // CHECK-DAG: %[[V28:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32> 611 // CHECK-DAG: %[[V29:.*]] = mhlo.subtract %[[V28]], %[[V27]] : tensor<1x22x128xi32> 612 // CHECK-DAG: %[[V30:.*]] = mhlo.maximum %[[V19]], %[[V5]] : tensor<1x22x128xi32> 613 // CHECK-DAG: %[[V31:.*]] = mhlo.subtract %[[V30]], %[[V27]] : tensor<1x22x128xi32> 614 // CHECK-DAG: %[[V32:.*]] = mhlo.add %[[V3]], %[[V29]] : tensor<1x22x128xi32> 615 // CHECK-DAG: %[[V33:.*]] = mhlo.add %[[V3]], %[[V31]] : tensor<1x22x128xi32> 616 // CHECK-DAG: %[[V34:.*]] = chlo.broadcast_compare %[[V32]], %[[V5]] {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> 617 // CHECK-DAG: %[[V35:.*]] = chlo.broadcast_compare %[[V32]], %[[V15]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> 618 // CHECK-DAG: %[[V36:.*]] = mhlo.and %[[V34]], %[[V35]] : tensor<1x22x128xi1> 619 // CHECK-DAG: %[[V37:.*]] = chlo.broadcast_compare %[[V33]], %[[V5]] {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> 620 // CHECK-DAG: %[[V38:.*]] = chlo.broadcast_compare %[[V33]], %[[V13]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> 621 // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1> 622 // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1> 623 // CHECK-DAG: %[[V41:.*]] = mhlo.reshape %[[V40]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> 624 // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> 625 // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = #mhlo.gather<offset_dims = [0], collapsed_slice_dims = [1, 2], start_index_map = [1, 2]>, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> 626 // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1> 627 // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<7x22x128xi32> 628 // CHECK: %[[V46:.*]] = "mhlo.select"(%[[V44]], %[[V43]], %[[V45]]) : (tensor<7x22x128xi1>, tensor<7x22x128xi32>, tensor<7x22x128xi32>) -> tensor<7x22x128xi32> 629 // CHECK: return %[[V46]] : tensor<7x22x128xi32> 630 %0 = mhlo.constant dense<42> : tensor<i32> // padding value 631 %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k 632 %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { 633 T = i32, align = "RIGHT_LEFT" 634 } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32> 635 func.return %2: tensor<7x22x128xi32> 636} 637 638// ----- 639 640// CHECK-LABEL: func @matrix_diag_part_single_diagonal 641func.func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> { 642 %0 = mhlo.constant dense<42> : tensor<i32> // padding value 643 %1 = mhlo.constant dense<0> : tensor<2xi32> // k 644 %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { 645 T = i32, align = "RIGHT_LEFT" 646 } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x128xi32> 647 // CHECK: %[[result:.*]] = mhlo.reshape {{.*}} : (tensor<7x1x128xi32>) -> tensor<7x128xi32> 648 // CHECK: return %[[result]] : tensor<7x128xi32> 649 func.return %2: tensor<7x128xi32> 650} 651 652// ----- 653 654// CHECK-LABEL: func @matrix_diag_part_align_ll 655func.func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { 656 %0 = mhlo.constant dense<42> : tensor<i32> // padding value 657 %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k 658 %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { 659 T = i32, align = "LEFT_LEFT" 660 } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32> 661 // CHECK: %[[false:.*]] = mhlo.constant dense<false> : tensor<i1> 662 // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1> 663 // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_false]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> 664 func.return %2: tensor<7x22x128xi32> 665} 666 667// ----- 668 669// CHECK-LABEL: func @matrix_diag_part_align_lr 670func.func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { 671 %0 = mhlo.constant dense<42> : tensor<i32> // padding value 672 %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k 673 %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { 674 T = i32, align = "LEFT_RIGHT" 675 } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32> 676 // CHECK: %[[le:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> 677 // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[le]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> 678 func.return %2: tensor<7x22x128xi32> 679} 680 681// ----- 682 683// CHECK-LABEL: func @matrix_diag_part_align_rl 684func.func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { 685 %0 = mhlo.constant dense<42> : tensor<i32> // padding value 686 %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k 687 %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { 688 T = i32, align = "RIGHT_LEFT" 689 } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32> 690 // CHECK: %[[ge:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> 691 // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[ge]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> 692 func.return %2: tensor<7x22x128xi32> 693} 694 695// ----- 696 697// CHECK-LABEL: func @matrix_diag_part_align_rr 698func.func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { 699 %0 = mhlo.constant dense<42> : tensor<i32> // padding value 700 %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k 701 %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { 702 T = i32, align = "RIGHT_RIGHT" 703 } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32> 704 // CHECK: %[[true:.*]] = mhlo.constant dense<true> : tensor<i1> 705 // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1> 706 // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_true]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> 707 func.return %2: tensor<7x22x128xi32> 708} 709 710// ----- 711 712// CHECK-LABEL: func @matrix_diag_part_align_7d 713// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> 714func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> { 715 %0 = mhlo.constant dense<-1.> : tensor<f32> // padding value 716 %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32> // k 717 %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { 718 T = f32, align = "LEFT_RIGHT" 719 } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor<f32>) -> tensor<3x5x7x9x11x4x10xf32> 720 func.return %2: tensor<3x5x7x9x11x4x10xf32> 721} 722 723//===----------------------------------------------------------------------===// 724// Erf 725//===----------------------------------------------------------------------===// 726 727// ----- 728 729// CHECK-LABEL: func @erf 730func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { 731 // CHECK: chlo.erf %arg0 : tensor<2x3xf32> 732 %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> 733 func.return %0 : tensor<2x3xf32> 734} 735 736//===----------------------------------------------------------------------===// 737// Erfc 738//===----------------------------------------------------------------------===// 739 740// ----- 741 742// CHECK-LABEL: func @erfc 743func.func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { 744 // CHECK: chlo.erfc %arg0 : tensor<2x3xf32> 745 %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> 746 func.return %0 : tensor<2x3xf32> 747} 748 749//===----------------------------------------------------------------------===// 750// Einsum. 751//===----------------------------------------------------------------------===// 752 753// ----- 754 755// CHECK-LABEL: func @einsum 756func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { 757 // CHECK: mhlo.einsum 758 %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> 759 func.return %0: tensor<2x4xf32> 760} 761 762// ----- 763 764// CHECK-LABEL: func @unary_einsum 765func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { 766 // CHECK: mhlo.unary_einsum 767 %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> 768 func.return %0: tensor<2x2xf32> 769} 770 771//===----------------------------------------------------------------------===// 772// FloorDiv and FloorMod. 773//===----------------------------------------------------------------------===// 774 775// ----- 776 777// CHECK-LABEL: func @floordiv_broadcast_i32 778func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { 779 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 780 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 781 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #mhlo<comparison_direction NE>} 782 // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> 783 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>} 784 // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> 785 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>} 786 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>} 787 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 788 // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> 789 // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] 790 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]]) 791 // CHECK: return [[SELECT]] 792 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> 793 func.return %0: tensor<2x3xi32> 794} 795 796// ----- 797 798// CHECK-LABEL: func @floordiv_reverse_broadcast_i32 799func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { 800 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 801 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]] 802 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>} 803 // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> 804 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>} 805 // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> 806 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>} 807 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>} 808 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 809 // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> 810 // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] 811 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]]) 812 // CHECK: return [[SELECT]] 813 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> 814 func.return %0: tensor<2x3xi32> 815} 816 817// ----- 818 819// CHECK-LABEL: func @floordiv_f32 820func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { 821 // CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0 822 // CHECK-NEXT: %[[FLOOR:.*]] = mhlo.floor %[[DIV]] 823 // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> 824 %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> 825 func.return %0: tensor<2xf32> 826} 827 828// ----- 829 830// CHECK-LABEL: func @floordiv_bf16 831func.func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { 832 // CHECK-NEXT: mhlo.convert 833 // CHECK-NEXT: mhlo.convert 834 // CHECK-NEXT: chlo.broadcast_divide 835 // CHECK-NEXT: mhlo.floor 836 // CHECK-NEXT: mhlo.convert 837 // CHECK-NEXT: return 838 %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> 839 func.return %0: tensor<2xbf16> 840} 841 842// ----- 843 844// CHECK-LABEL: func @floordiv_f16_broadcast 845func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { 846 // CHECK-NEXT: chlo.broadcast_divide 847 // CHECK-NEXT: mhlo.floor 848 // CHECK-NEXT: return 849 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> 850 func.return %0: tensor<2x3xf16> 851} 852 853// ----- 854 855// CHECK-LABEL: func @floordiv_dynamic 856func.func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> { 857 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 858 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 859 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #mhlo<comparison_direction NE>} 860 // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> 861 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>} 862 // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> 863 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>} 864 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>} 865 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 866 // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> 867 // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] 868 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]]) 869 // CHECK: return [[SELECT]] 870 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32> 871 func.return %0: tensor<?x?xi32> 872} 873 874// ----- 875 876// CHECK-LABEL: func @floordiv_unsigned 877func.func @floordiv_unsigned(%arg0: tensor<?x?xui32>, %arg1: tensor<?xui32>) -> tensor<?x?xui32> { 878 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 879 // CHECK: return [[DIV]] 880 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xui32>, tensor<?xui32>) -> tensor<?x?xui32> 881 func.return %0: tensor<?x?xui32> 882} 883 884// ----- 885 886// CHECK-LABEL: func @floordiv_unranked 887func.func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { 888 // CHECK-NOT: tf.FloorDiv 889 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 890 func.return %0: tensor<*xf32> 891} 892 893// ----- 894 895// CHECK-LABEL: func @floordiv_int 896func.func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { 897 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> 898 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> 899 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> 900 // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> : tensor<i32> 901 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1> 902 // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> : tensor<i32> 903 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1> 904 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #mhlo<comparison_direction NE>} 905 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 906 // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> : tensor<i32> 907 // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] 908 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]]) 909 // CHECK: return [[SELECT]] 910 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> 911 func.return %0: tensor<*xi32> 912} 913 914// ----- 915 916// CHECK-LABEL: func @floormod_broadcast_numerator 917func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { 918 // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 919 // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> 920 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>} 921 // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> 922 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>} 923 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>} 924 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #mhlo<comparison_direction NE>} 925 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 926 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] 927 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) 928 // CHECK-NEXT: return [[SELECT]] 929 %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> 930 func.return %0: tensor<2x3xi32> 931} 932 933// ----- 934 935// CHECK-LABEL: func @floormod_broadcast_denominator 936func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { 937 // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 938 // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> 939 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>} 940 // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> 941 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>} 942 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>} 943 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>} 944 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 945 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} 946 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) 947 // CHECK-NEXT: return [[SELECT]] 948 %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> 949 func.return %0: tensor<2x3xi32> 950} 951 952// ----- 953 954// CHECK-LABEL: func @floormod_unsigned_broadcast_denominator 955func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> { 956 // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 957 // CHECK-NEXT: return [[REM]] 958 %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32> 959 func.return %0: tensor<2x3xui32> 960} 961 962// ----- 963 964// CHECK-LABEL: func @floormod_dynamic_broadcast_numerator 965func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> { 966 // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 967 // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> 968 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>} 969 // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> 970 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>} 971 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>} 972 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #mhlo<comparison_direction NE>} 973 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 974 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} 975 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) 976 // CHECK-NEXT: return [[SELECT]] 977 %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32> 978 func.return %0: tensor<?x?xi32> 979} 980 981// ----- 982 983// CHECK-LABEL: func @floormod_dynamic_broadcast_denominator 984func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { 985 // CHECK-NOT: tf.FloorMod 986 // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> 987 // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 988 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xi1> 989 // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 990 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xi1> 991 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xi1> 992 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?x?xi1>, tensor<?x?x?xi1>) -> tensor<?x?x?xi1> 993 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] : (tensor<?x?x?xi1>, tensor<?x?x?xi1>) -> tensor<?x?x?xi1> 994 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> 995 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> 996 // CHECK-NEXT: return [[SELECT]] : tensor<?x?x?xf32> 997 %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> 998 func.return %0: tensor<?x?x?xf32> 999} 1000 1001// ----- 1002 1003// CHECK-LABEL: func @floormod_unranked 1004func.func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { 1005 // CHECK-NOT: tf.FloorMod 1006 %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> 1007 func.return %0: tensor<*xi32> 1008} 1009 1010//===----------------------------------------------------------------------===// 1011// OnesLike 1012//===----------------------------------------------------------------------===// 1013 1014// ----- 1015 1016// CHECK-LABEL: @ones_like 1017// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) 1018func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { 1019 // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 1.0{{.*}}} 1020 // CHECK: return %[[RES]] 1021 %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> 1022 func.return %0 : tensor<2x?xf32> 1023} 1024 1025//===----------------------------------------------------------------------===// 1026// ZerosLike 1027//===----------------------------------------------------------------------===// 1028 1029// ----- 1030 1031// CHECK-LABEL: @zeros_like 1032// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) 1033func.func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { 1034 // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 0.0{{.*}}} 1035 // CHECK: return %[[RES]] 1036 %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> 1037 func.return %0 : tensor<2x?xf32> 1038} 1039 1040//===----------------------------------------------------------------------===// 1041// BroadcastTo. 1042//===----------------------------------------------------------------------===// 1043 1044// ----- 1045 1046// CHECK-LABEL: func @broadcast_to 1047func.func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { 1048 %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> 1049 1050 // CHECK: [[CST:%.+]] = mhlo.constant 1051 // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, [[CST]]) 1052 // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} 1053 %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> 1054 func.return %0 : tensor<16x16x16x16xf32> 1055} 1056 1057// ----- 1058 1059// CHECK-LABEL: func @broadcast_scalar_to_unranked 1060// CHECK: (%[[ARG0:.*]]: tensor<f32>, %[[SHAPE:.*]]: tensor<?xi32>) 1061func.func @broadcast_scalar_to_unranked(%arg0: tensor<f32>, %shape: tensor<?xi32>) -> tensor<*xf32> { 1062 // CHECK: "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]]) 1063 // CHECK-SAME: {broadcast_dimensions = dense<> : tensor<0xi64>} 1064 %0 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<f32>, tensor<?xi32>) -> tensor<*xf32> 1065 func.return %0 : tensor<*xf32> 1066} 1067 1068//===----------------------------------------------------------------------===// 1069// Complex op legalizations. 1070//===----------------------------------------------------------------------===// 1071 1072// ----- 1073 1074// CHECK-LABEL: func @complex 1075func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> { 1076 // CHECK: chlo.broadcast_complex 1077 %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>> 1078 func.return %1 : tensor<3xcomplex<f32>> 1079} 1080 1081// ----- 1082 1083// CHECK-LABEL: func @imag 1084func.func @imag(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> { 1085 // CHECK: mhlo.imag 1086 %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32> 1087 func.return %1 : tensor<3xf32> 1088} 1089 1090// ----- 1091 1092// CHECK-LABEL: func @real 1093func.func @real(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> { 1094 // CHECK: mhlo.real 1095 %1 = "tf.Real"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32> 1096 func.return %1 : tensor<3xf32> 1097} 1098 1099//===----------------------------------------------------------------------===// 1100// Concat op legalizations. 1101//===----------------------------------------------------------------------===// 1102 1103// ----- 1104 1105// CHECK-LABEL: func @concat_v2 1106func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { 1107 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> 1108 %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64> 1109 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32> 1110 func.return %1 : tensor<6x3xf32> 1111} 1112 1113// ----- 1114 1115// CHECK-LABEL: func @concat_v2_neg_axis 1116func.func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { 1117 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> 1118 1119 %axis = "tf.Const"() { value = dense<-2> : tensor<i64> } : () -> tensor<i64> 1120 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32> 1121 func.return %1 : tensor<6x3xf32> 1122} 1123 1124// ----- 1125 1126// CHECK-LABEL: func @concat_v2_1d_axis 1127func.func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { 1128 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> 1129 1130 %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> 1131 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> 1132 func.return %1 : tensor<3x6xf32> 1133} 1134 1135// ----- 1136 1137// CHECK-LABEL: func @concat_v2_non_const_axis 1138func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor<i64>) -> tensor<3x6xf32> { 1139 // CHECK: "tf.ConcatV2" 1140 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32> 1141 func.return %1 : tensor<3x6xf32> 1142} 1143 1144// ----- 1145 1146// CHECK-LABEL: func @concat_v2_unranked 1147func.func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { 1148 %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64> 1149 // CHECK: "tf.ConcatV2" 1150 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<*xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32> 1151 func.return %1 : tensor<*xf32> 1152} 1153 1154//===----------------------------------------------------------------------===// 1155// Pad op legalizations. 1156//===----------------------------------------------------------------------===// 1157 1158// ----- 1159 1160// CHECK-LABEL: func @padv2_1D 1161func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -> tensor<6xf32> { 1162 %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> 1163 // CHECK: "mhlo.pad"(%arg0, %arg1) { 1164 // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>, 1165 // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>, 1166 // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> 1167 %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<6xf32> 1168 func.return %1 : tensor<6xf32> 1169} 1170 1171// ----- 1172 1173// CHECK-LABEL: func @padv2_2D 1174func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> { 1175 %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> 1176 // CHECK: "mhlo.pad"(%arg0, %arg1) { 1177 // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, 1178 // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, 1179 // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> 1180 %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<6x9xf32> 1181 func.return %1 : tensor<6x9xf32> 1182} 1183 1184// ----- 1185 1186// CHECK-LABEL: func @padv2_i32_paddings 1187func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> { 1188 %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> 1189 // CHECK: "mhlo.pad"(%arg0, %arg1) { 1190 // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, 1191 // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, 1192 // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> 1193 %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32> 1194 func.return %1 : tensor<6x9xf32> 1195} 1196 1197// ----- 1198 1199// CHECK-LABEL: func @padv2_dynamic 1200func.func @padv2_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<1x2xi64>) -> tensor<?xf32> { 1201 // CHECK: "mhlo.transpose"({{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x2xi64>) -> tensor<2x1xi64> 1202 // CHECK: mhlo.reshape {{.*}} : (tensor<2x1xi64>) -> tensor<2xi64> 1203 // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 1204 // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 1205 // CHECK: mhlo.dynamic_pad {{.*}} : (tensor<?xf32>, tensor<f32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<?xf32> 1206 %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor<?xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<?xf32> 1207 func.return %1 : tensor<?xf32> 1208} 1209 1210//===----------------------------------------------------------------------===// 1211// Identity op legalizations. 1212//===----------------------------------------------------------------------===// 1213 1214// ----- 1215 1216// CHECK-LABEL: func @identity 1217func.func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1218 // CHECK-NEXT: return %arg0 : tensor<1xi32> 1219 %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1220 func.return %0: tensor<1xi32> 1221} 1222 1223// ----- 1224 1225// CHECK-LABEL: func @identityN 1226func.func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) { 1227 // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32> 1228 %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) 1229 func.return %0#0, %0#1: tensor<1xi32>, tensor<1xf32> 1230} 1231 1232// ----- 1233 1234// CHECK-LABEL: func @stopgradient 1235func.func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1236 // CHECK-NEXT: return %arg0 : tensor<1xi32> 1237 %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1238 func.return %0: tensor<1xi32> 1239} 1240 1241// ----- 1242 1243// CHECK-LABEL: func @preventgradient 1244func.func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1245 // CHECK-NEXT: return %arg0 : tensor<1xi32> 1246 %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32> 1247 func.return %0: tensor<1xi32> 1248} 1249 1250// ----- 1251 1252// CHECK-LABEL: func @checkNumerics 1253func.func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { 1254 // CHECK-NEXT: return %arg0 : tensor<1xf32> 1255 %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32> 1256 func.return %0: tensor<1xf32> 1257} 1258 1259//===----------------------------------------------------------------------===// 1260// InfeedDequeueTuple legalization 1261//===----------------------------------------------------------------------===// 1262 1263// ----- 1264 1265// CHECK-LABEL: func @infeed_dequeue_tuple 1266func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) { 1267// CHECK: [[TOKEN:%.*]] = mhlo.create_token : !mhlo.token 1268// CHECK: [[INFEED:%.*]]:3 = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""{{.*}}} : (!mhlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !mhlo.token) 1269// CHECK: return [[INFEED]]#0, [[INFEED]]#1 1270 %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) 1271 func.return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32> 1272} 1273 1274// ----- 1275 1276// CHECK-LABEL: func @infeed_dequeue_tuple_dynamic_error 1277func.func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) { 1278 // We expect legalization to fail for dynamic shapes: 1279 // CHECK: [[INFEED:%.*]] = "tf.InfeedDequeueTuple"{{.*}} 1280 %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>) 1281 func.return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32> 1282} 1283 1284// The following op sharding is used: 1285// Proto debug string: 1286// type: TUPLE 1287// tuple_shardings { 1288// type: MAXIMAL 1289// tile_assignment_dimensions: 1 1290// tile_assignment_devices: 0 1291// } 1292// Serialized string: 1293// "\08\02*\08\08\01\1A\01\01\22\01\00" 1294 1295// CHECK-LABEL: infeed_dequeue_tuple_sharding 1296func.func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { 1297 // CHECK: "mhlo.infeed" 1298 // An additional sharding is added at the end to account for token result. 1299 // Proto debug string: 1300 // type: TUPLE 1301 // tuple_shardings { 1302 // type: MAXIMAL 1303 // tile_assignment_dimensions: 1 1304 // tile_assignment_devices: 0 1305 // } 1306 // tuple_shardings { 1307 // type: MAXIMAL 1308 // tile_assignment_dimensions: 1 1309 // tile_assignment_devices: 0 1310 // } 1311 // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00" 1312 %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> 1313 func.return %0 : tensor<8xi32> 1314} 1315 1316//===----------------------------------------------------------------------===// 1317// Nullary op legalizations. 1318//===----------------------------------------------------------------------===// 1319 1320// ----- 1321 1322// CHECK-LABEL: @const 1323func.func @const() -> tensor<2xi32> { 1324 // CHECK: mhlo.constant dense<0> : tensor<2xi32> 1325 %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) 1326 func.return %0: tensor<2xi32> 1327} 1328 1329// ----- 1330 1331// CHECK-LABEL: @const_dynamic_output 1332func.func @const_dynamic_output() -> tensor<*xi32> { 1333 // CHECK: [[CONST:%.*]] = mhlo.constant dense<0> : tensor<2xi32> 1334 // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32> 1335 %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) 1336 // CHECK: return [[CAST]] 1337 func.return %0: tensor<*xi32> 1338} 1339 1340// ----- 1341 1342// CHECK-LABEL: @opaque_const 1343func.func @opaque_const() -> tensor<!tf_type.variant<tensor<2xi32>>> { 1344 // CHECK-NOT: mhlo.constant 1345 %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>> 1346 func.return %0 : tensor<!tf_type.variant<tensor<2xi32>>> 1347} 1348 1349//===----------------------------------------------------------------------===// 1350// Matmul op legalizations. 1351//===----------------------------------------------------------------------===// 1352 1353// ----- 1354 1355// CHECK-LABEL: matmul_notranspose 1356// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>) 1357func.func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> { 1358 // CHECK: "mhlo.dot"(%[[A]], %[[B]]) 1359 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32> 1360 1361 func.return %0 : tensor<5x11xf32> 1362} 1363 1364// ----- 1365 1366// CHECK-LABEL: matmul_transpose_b 1367// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>) 1368func.func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { 1369 // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} 1370 // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]]) 1371 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> 1372 1373 func.return %0 : tensor<5x11xf32> 1374} 1375 1376// ----- 1377 1378// CHECK-LABEL: matmul_transpose_both 1379// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>) 1380func.func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { 1381 // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>} 1382 // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} 1383 // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]]) 1384 %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> 1385 1386 func.return %0 : tensor<5x11xf32> 1387} 1388 1389// Verify that MatMul with ranked inputs are lowered to HLO. 1390// CHECK-LABEL: matmul_ranked 1391func.func @matmul_ranked(%a: tensor<?x7xf32>, %b: tensor<7x?xf32>) -> tensor<?x?xf32> { 1392 // CHECK: "mhlo.dot" 1393 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<?x7xf32>, tensor<7x?xf32>) -> tensor<?x?xf32> 1394 1395 func.return %0 : tensor<?x?xf32> 1396} 1397 1398// Verify that MatMul with unranked inputs are lowered to HLO. 1399// CHECK-LABEL: matmul_unranked 1400func.func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> { 1401 // CHECK: "mhlo.dot" 1402 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 1403 1404 func.return %0 : tensor<*xf32> 1405} 1406 1407// Verify SparseMatMul is legalized to dot. 1408// CHECK-LABEL: test_sparse_mat_mul 1409func.func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> { 1410 // CHECK: "mhlo.dot" 1411 %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> 1412 func.return %0: tensor<3x5xf32> 1413} 1414 1415// SparseMatMul where one operand needs to be transposed and the other one not. 1416// 1417// CHECK-LABEL: @test_sparse_mat_mul_with_transpose 1418// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> 1419// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32> 1420// CHECK-SAME: -> tensor<3x5xf32> 1421// CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]]) 1422// CHECK-SAME: permutation = dense<[1, 0]> 1423// CHECK-SAME: -> tensor<4x5xf32> 1424// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]]) 1425// CHECK-SAME: -> tensor<3x5xf32> 1426// CHECK: return %[[RESULT]] 1427func.func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> { 1428 %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32> 1429 func.return %0: tensor<3x5xf32> 1430} 1431 1432// SparseMatMul where one operand needs to be casted and the other one not. 1433// 1434// CHECK-LABEL: @test_sparse_mat_mul_with_cast 1435// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> 1436// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16> 1437// CHECK-SAME: -> tensor<3x5xf32> 1438// CHECK: %[[CAST:.*]] = mhlo.convert(%[[ARG1]]) 1439// CHECK-SAME: -> tensor<4x5xf32> 1440// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]]) 1441// CHECK-SAME: -> tensor<3x5xf32> 1442// CHECK: return %[[RESULT]] 1443func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> { 1444 %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32> 1445 func.return %0: tensor<3x5xf32> 1446} 1447 1448//===----------------------------------------------------------------------===// 1449// MatrixBandPart op legalizations. 1450//===----------------------------------------------------------------------===// 1451 1452// ----- 1453 1454// CHECK-LABEL: matrix_band_part 1455// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>) 1456func.func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> { 1457 // CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64> 1458 // CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64> 1459 1460 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64> 1461 // CHECK-DAG: %[[A:.*]] = mhlo.compare LT, %[[LOWER]], %[[ZERO]] : (tensor<i64>, tensor<i64>) -> tensor<i1> 1462 // CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 1463 1464 // CHECK-DAG: %[[C:.*]] = mhlo.compare LT, %[[UPPER]], %[[ZERO]] : (tensor<i64>, tensor<i64>) -> tensor<i1> 1465 // CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 1466 // CHECK-DAG: %[[F:.*]] = mhlo.negate %[[B]] : tensor<i64> 1467 1468 // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64> 1469 // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64> 1470 // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64> 1471 // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1> 1472 1473 // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1> 1474 1475 // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1> 1476 1477 // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16> 1478 1479 // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]] 1480 // CHECK-DAG: return %[[R]] 1481 %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16> 1482 func.return %0 : tensor<64x64xbf16> 1483} 1484 1485// ----- 1486 1487// CHECK-LABEL: matrix_band_part_2 1488// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>) 1489func.func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> { 1490 // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64> 1491 // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64> 1492 // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64> 1493 1494 // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1> 1495 1496 // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1> 1497 // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1> 1498 1499 // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> 1500 1501 // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]] 1502 // CHECK-DAG: return %[[R]] 1503 %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<12x24x48xbf16> 1504 func.return %0 : tensor<12x24x48xbf16> 1505} 1506 1507// ----- 1508 1509// CHECK-LABEL: matrix_band_part_3 1510// CHECK-SAME: (%[[INPUT:.*]]: tensor<*xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>) 1511func.func @matrix_band_part_3(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> { 1512 // CHECK: "tf.MatrixBandPart" 1513 %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16> 1514 func.return %0 : tensor<*xbf16> 1515} 1516 1517// ----- 1518 1519// CHECK-LABEL: matrix_band_part_4 1520// CHECK-SAME: (%[[INPUT:.*]]: tensor<24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>) 1521func.func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<24x48xbf16> { 1522 // This one should lower. 1523 // CHECK-NOT: "tf.MatrixBandPart" 1524 %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<24x48xbf16> 1525 func.return %0 : tensor<24x48xbf16> 1526} 1527 1528//===----------------------------------------------------------------------===// 1529// MaxPool op legalizations. 1530//===----------------------------------------------------------------------===// 1531 1532// ----- 1533 1534// CHECK-LABEL: maxpool_valid_padding 1535// CHECK-SAME: %[[ARG:.*]]: tensor 1536func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { 1537 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32> 1538 // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]]) 1539 // CHECK: mhlo.maximum 1540 // CHECK: mhlo.return 1541 // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} 1542 1543 %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> 1544 func.return %0 : tensor<2x3x5x7xi32> 1545} 1546 1547// ----- 1548 1549// CHECK-LABEL: maxpool_same_padding 1550// CHECK-SAME: %[[ARG:.*]]: tensor 1551func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> { 1552 // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> 1553 1554 %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> 1555 func.return %0 : tensor<2x4x7x7xi32> 1556} 1557 1558// ----- 1559 1560// CHECK-LABEL: maxpool_3d_valid_padding 1561// CHECK-SAME: %[[ARG:.*]]: tensor 1562func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> { 1563 // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32> 1564 // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]]) 1565 // CHECK: mhlo.maximum 1566 // CHECK: mhlo.return 1567 // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>} 1568 1569 %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> 1570 func.return %0 : tensor<2x8x3x5x7xf32> 1571} 1572 1573// ----- 1574 1575// CHECK-LABEL: maxpool_3d_same_padding 1576// CHECK-SAME: %[[ARG:.*]]: tensor 1577func.func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> { 1578 // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> 1579 1580 %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> 1581 func.return %0 : tensor<2x8x4x7x7xf32> 1582} 1583 1584// ----- 1585 1586// CHECK-LABEL: maxpool_explicit_padding 1587func.func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { 1588 // CHECK: tf.MaxPool 1589 // TODO(b/165938852): need to support explicit padding in max_pool. 1590 1591 %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> 1592 func.return %0 : tensor<2x3x5x7xi32> 1593} 1594 1595//===----------------------------------------------------------------------===// 1596// MaxPoolGrad op legalizations. 1597//===----------------------------------------------------------------------===// 1598 1599// ----- 1600 1601// CHECK-LABEL: @max_pool_grad_valid 1602// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> 1603func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { 1604 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 1605 // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({ 1606 // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>): 1607 // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor<f32>, tensor<f32>) -> tensor<i1> 1608 // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<i1> 1609 // CHECK: }, { 1610 // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>): 1611 // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32> 1612 // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<f32> 1613 // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32> 1614 // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> 1615 %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { 1616 data_format = "NHWC", 1617 ksize = [1, 2, 2, 1], 1618 padding = "VALID", 1619 strides = [1, 2, 2, 1] 1620 } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> 1621 func.return %result : tensor<10x24x24x64xf32> 1622} 1623 1624// ----- 1625 1626// CHECK-LABEL: @max_pool_3d_grad_valid 1627// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32> 1628func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> { 1629 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 1630 // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({ 1631 // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>): 1632 // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor<f32>, tensor<f32>) -> tensor<i1> 1633 // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<i1> 1634 // CHECK: }, { 1635 // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>): 1636 // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32> 1637 // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor<f32> 1638 // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32> 1639 // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> 1640 %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> 1641 func.return %result : tensor<10x8x24x24x64xf32> 1642} 1643 1644// ----- 1645 1646// CHECK-LABEL: @max_pool_grad_same 1647func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { 1648 // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> 1649 %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { 1650 data_format = "NHWC", 1651 ksize = [1, 2, 3, 1], 1652 padding = "SAME", 1653 strides = [1, 4, 4, 1] 1654 } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> 1655 func.return %result : tensor<2x13x25x7xf32> 1656} 1657 1658// ----- 1659 1660// CHECK-LABEL: @max_pool_3d_grad_same 1661func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> { 1662 // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> 1663 %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> 1664 func.return %result : tensor<2x8x13x25x7xf32> 1665} 1666 1667//===----------------------------------------------------------------------===// 1668// OneHot op legalizations. 1669//===----------------------------------------------------------------------===// 1670 1671// ----- 1672 1673// CHECK-LABEL:one_hot 1674func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> { 1675 // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> 1676 // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> 1677 // CHECK: %[[COMPARE:.*]] = mhlo.compare EQ, %[[BCAST_ARG0]], %[[IOTA]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> 1678 // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32> 1679 // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32> 1680 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> 1681 // CHECK: return %[[RESULT]] : tensor<3x5xf32> 1682 %depth = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32> 1683 %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<3x5xf32> 1684 func.return %result : tensor<3x5xf32> 1685} 1686 1687//===----------------------------------------------------------------------===// 1688// tf.OutfeedEnqueueTuple legalization 1689//===----------------------------------------------------------------------===// 1690 1691// ----- 1692 1693// CHECK-LABEL: func @outfeed_enqueue_tuple 1694// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) 1695func.func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { 1696// CHECK: [[TOKEN:%.*]] = mhlo.create_token : !mhlo.token 1697// CHECK: "mhlo.outfeed"([[VAL_0]], [[VAL_1]], [[TOKEN]]) {outfeed_config = ""} : (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token 1698 "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () 1699 func.return 1700} 1701 1702//===----------------------------------------------------------------------===// 1703// Pack op legalizations. 1704//===----------------------------------------------------------------------===// 1705 1706// ----- 1707 1708// CHECK-LABEL: func @pack 1709func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { 1710 // CHECK: mhlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> 1711 // CHECK: mhlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> 1712 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> 1713 1714 %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> 1715 func.return %0 : tensor<2x2xi32> 1716} 1717 1718//===----------------------------------------------------------------------===// 1719// PartitionedCall op legalization. 1720//===----------------------------------------------------------------------===// 1721 1722// ----- 1723 1724// CHECK-LABEL: func @partitioned_call 1725func.func @partitioned_call(%arg0: tensor<i32>) -> tensor<i32> { 1726 // CHECK: call @pcall_func(%arg0) : (tensor<i32>) -> tensor<i32> 1727 %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor<i32>) -> (tensor<i32>) 1728 func.return %0 : tensor<i32> 1729} 1730 1731 1732func.func @pcall_func(%arg0: tensor<i32>) -> tensor<i32> { 1733 func.return %arg0 : tensor<i32> 1734} 1735 1736// ----- 1737 1738// CHECK-LABEL: func @partitioned_call_multi_input 1739func.func @partitioned_call_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> { 1740 // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32> 1741 %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor<i32>, tensor<i32>) -> (tensor<i32>) 1742 func.return %0 : tensor<i32> 1743} 1744 1745 1746func.func @pcall_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> { 1747 func.return %arg0 : tensor<i32> 1748} 1749 1750// ----- 1751 1752// CHECK-LABEL: func @partitioned_call_multi_in_out 1753func.func @partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1754 // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1755 %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1756 func.return %0, %1 : tensor<i32>, tensor<i32> 1757} 1758 1759 1760func.func @pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1761 func.return %arg1, %arg0 : tensor<i32>, tensor<i32> 1762} 1763 1764// CHECK-LABEL: func @unhandled_partitioned_call 1765func.func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) { 1766 // The argument types don't match the parameter types for the 1767 // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not 1768 // for a standard CallOp, so this op can't be lowered. 1769 // CHECK: "tf.PartitionedCall" 1770 %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>) 1771 func.return %0, %1 : tensor<i32>, tensor<i32> 1772} 1773 1774 1775// CHECK-LABEL: func @unhandled_partitioned_call_2 1776func.func @unhandled_partitioned_call_2(%arg0: tensor<i32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) { 1777 // CHECK: "tf.PartitionedCall" 1778 %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>) 1779 func.return %0, %1 : tensor<i32>, tensor<i32> 1780} 1781 1782// ----- 1783 1784//===----------------------------------------------------------------------===// 1785// ReverseV2 op legalization. 1786//===----------------------------------------------------------------------===// 1787 1788// ----- 1789 1790// CHECK-LABEL: @reverse_func_32 1791func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { 1792 %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) 1793 1794 // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} 1795 %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> 1796 1797 // CHECK: return [[VAL]] : tensor<5xi32> 1798 func.return %reversed : tensor<5xi32> 1799} 1800 1801// ----- 1802 1803// CHECK-LABEL: @reverse_func_64 1804func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { 1805 %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) 1806 1807 // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} 1808 %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> 1809 1810 // CHECK: return [[VAL]] : tensor<5xi32> 1811 func.return %reversed : tensor<5xi32> 1812} 1813 1814// ----- 1815 1816// CHECK-LABEL: @reverse_func_neg 1817func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { 1818 %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1819 1820 // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} 1821 %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> 1822 1823 // CHECK: return [[VAL]] : tensor<5x5xi32> 1824 func.return %reversed : tensor<5x5xi32> 1825} 1826 1827//===----------------------------------------------------------------------===// 1828// StatefulPartitionedCall op legalization. 1829//===----------------------------------------------------------------------===// 1830 1831// ----- 1832 1833// CHECK-LABEL: func @stateful_partitioned_call 1834// CHECK-SAME: [[ARG:%.+]]: tensor<i32> 1835func.func @stateful_partitioned_call(%arg0: tensor<i32>) -> tensor<i32> { 1836 // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor<i32>) -> tensor<i32> 1837 %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>) 1838 func.return %0 : tensor<i32> 1839} 1840 1841func.func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> { 1842 func.return %arg0 : tensor<i32> 1843} 1844 1845// ----- 1846 1847// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out 1848// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>, [[ARG1:%.+]]: tensor<i32>) 1849func.func @stateful_partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1850 // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1851 %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1852 func.return %0, %1 : tensor<i32>, tensor<i32> 1853} 1854 1855func.func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1856 func.return %arg1, %arg0 : tensor<i32>, tensor<i32> 1857} 1858 1859//===----------------------------------------------------------------------===// 1860// Elu op legalizations. 1861//===----------------------------------------------------------------------===// 1862 1863// ----- 1864 1865// CHECK-LABEL: func @elu 1866func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { 1867 // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1xf32>) -> tensor<1xf32> 1868 // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]] 1869 // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0 1870 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]]) 1871 // CHECK: return %[[RESULT]] 1872 %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> 1873 func.return %0: tensor<1xf32> 1874} 1875 1876// ----- 1877 1878// CHECK-LABEL: func @elu_unranked 1879func.func @elu_unranked(%arg0: tensor<?xf32>) -> tensor<?xf32> { 1880 // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<?xf32>) -> tensor<?xf32> 1881 // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]] 1882 // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0 1883 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]]) 1884 // CHECK: return %[[RESULT]] 1885 %0 = "tf.Elu"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 1886 func.return %0: tensor<?xf32> 1887} 1888 1889// ----- 1890 1891// CHECK-LABEL: func @elu_grad 1892// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>) 1893func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> { 1894 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 1895 // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 1896 // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #mhlo<comparison_direction GT>} 1897 // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>} 1898 // CHECK-DAG: %[[MULGRAD:.*]] = mhlo.multiply(%[[GRADIENTS]], %[[ADD1]]) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32> 1899 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]]) 1900 // CHECK: return %[[RESULT]] 1901 %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32> 1902 func.return %2 : tensor<4x8xf32> 1903} 1904 1905//===----------------------------------------------------------------------===// 1906// Relu op legalizations. 1907//===----------------------------------------------------------------------===// 1908 1909// ----- 1910 1911// CHECK-LABEL: func @relu 1912func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1913 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1914 // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> 1915 %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1916 func.return %0: tensor<1xi32> 1917} 1918 1919// ----- 1920 1921// CHECK-LABEL: func @relu_unranked 1922func.func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { 1923 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1924 // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> 1925 %0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> 1926 func.return %0: tensor<?xi32> 1927} 1928 1929// ----- 1930 1931// CHECK-LABEL: func @relu_unsigned 1932func.func @relu_unsigned(%arg0: tensor<?xui32>) -> tensor<?xui32> { 1933 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<ui32> 1934 // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<ui32>, tensor<?xui32>) -> tensor<?xui32> 1935 %0 = "tf.Relu"(%arg0) : (tensor<?xui32>) -> tensor<?xui32> 1936 func.return %0: tensor<?xui32> 1937} 1938 1939// ----- 1940 1941// CHECK-LABEL: func @relu6 1942func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1943 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1944 // CHECK-DAG: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32> 1945 // CHECK: mhlo.clamp %[[ZERO]], %arg0, %[[SIX]] : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32> 1946 %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1947 func.return %0: tensor<1xi32> 1948} 1949 1950// ----- 1951 1952// CHECK-LABEL: func @relu6_unranked 1953func.func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { 1954 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1955 // CHECK-DAG: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32> 1956 // CHECK: mhlo.clamp %[[ZERO]], %arg0, %[[SIX]] : (tensor<i32>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32> 1957 %0 = "tf.Relu6"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> 1958 func.return %0: tensor<?xi32> 1959} 1960 1961// ----- 1962 1963// CHECK-LABEL: func @relu6_unsigned 1964func.func @relu6_unsigned(%arg0: tensor<?xui32>) -> tensor<?xui32> { 1965 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<ui32> 1966 // CHECK-DAG: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<ui32> 1967 // CHECK: mhlo.clamp %[[ZERO]], %arg0, %[[SIX]] : (tensor<ui32>, tensor<?xui32>, tensor<ui32>) -> tensor<?xui32> 1968 %0 = "tf.Relu6"(%arg0) : (tensor<?xui32>) -> tensor<?xui32> 1969 func.return %0: tensor<?xui32> 1970} 1971 1972// ----- 1973 1974// CHECK-LABEL: func @relu_grad_unranked 1975// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<?x?xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>) 1976func.func @relu_grad_unranked(%gradients: tensor<?x?xf32>, %features: tensor<?x?xf32>) -> tensor<?x?xf32> { 1977 // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> 1978 // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg1, %0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1> 1979 // CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> 1980 // CHECK-DAG: return %[[RESULT]] : tensor<?x?xf32> 1981 %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> 1982 func.return %2 : tensor<?x?xf32> 1983} 1984 1985// ----- 1986 1987// CHECK-LABEL: func @leaky_relu 1988func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} { 1989 // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> 1990 // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> 1991 // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32> 1992 // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1> 1993 // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> 1994 // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32> 1995 %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> 1996 func.return %0 : tensor<1x4x4x3xf32> 1997} 1998 1999// ----- 2000 2001// CHECK-LABEL: func @leaky_relu_unranked 2002func.func @leaky_relu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} { 2003 // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32> 2004 // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> 2005 // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<*xf32> 2006 // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> 2007 // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 2008 // CHECK-NEXT: return %[[RES]] : tensor<*xf32> 2009 %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>) -> tensor<*xf32> 2010 func.return %0 : tensor<*xf32> 2011} 2012 2013// ----- 2014 2015// CHECK-LABEL: func @leaky_relu_grad 2016func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} { 2017 // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> 2018 // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> 2019 // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32> 2020 // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1> 2021 // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> 2022 // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32> 2023 %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> 2024 func.return %0 : tensor<1x4x4xf32> 2025} 2026 2027// ----- 2028 2029// CHECK-LABEL: func @leaky_relu_grad_unranked 2030func.func @leaky_relu_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} { 2031 // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32> 2032 // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> 2033 // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<*xf32> 2034 // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> 2035 // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 2036 // CHECK-NEXT: return %[[RES]] : tensor<*xf32> 2037 %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 2038 func.return %0 : tensor<*xf32> 2039} 2040 2041// ----- 2042 2043// CHECK-LABEL: func @softsign 2044func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { 2045 // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<4x10xf32>) -> tensor<4x10xf32> 2046 // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> 2047 // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32> 2048 // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32> 2049 // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32> 2050 %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32> 2051 func.return %0 : tensor<4x10xf32> 2052} 2053 2054// ----- 2055 2056// CHECK-LABEL: func @softsign_unranked 2057func.func @softsign_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2058 // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> 2059 // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<*xf32> 2060 // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<*xf32> 2061 // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<*xf32> 2062 // CHECK-NEXT: return %[[DIV]] : tensor<*xf32> 2063 %0 = "tf.Softsign"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2064 func.return %0 : tensor<*xf32> 2065} 2066 2067// ----- 2068 2069// CHECK-LABEL: func @softsign_grad 2070func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> { 2071 2072 // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 2073 // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> 2074 // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4x10xf32>) -> tensor<4x10xf32> 2075 // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32> 2076 // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> 2077 // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> 2078 %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> 2079 func.return %0 : tensor<4x10xf32> 2080} 2081 2082//===----------------------------------------------------------------------===// 2083// Roll op legalizations. 2084//===----------------------------------------------------------------------===// 2085 2086// ----- 2087 2088// CHECK-LABEL: func @Roll_0D 2089func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor<i32>) -> tensor<512xi32> { 2090 %axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>) 2091 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 2092 // CHECK-DAG: %[[AXIS_SIZE:.*]] = mhlo.constant dense<512> : tensor<i32> 2093 // CHECK: %[[T1:.+]] = mhlo.remainder %arg1, %[[AXIS_SIZE]] : tensor<i32> 2094 // CHECK: %[[T2:.+]] = mhlo.add %[[T1]], %[[AXIS_SIZE]] : tensor<i32> 2095 // CHECK: %[[T3:.+]] = mhlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor<i32> 2096 // CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) {dimension = 0 : i64} 2097 // CHECK: %[[OFFSET:.+]] = mhlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor<i32> 2098 // CHECK: "mhlo.dynamic_slice"(%[[CONCAT]], %[[OFFSET]]) 2099 // CHECK-SAME: {slice_sizes = dense<512> : tensor<1xi64>} 2100 // CHECK-SAME: (tensor<1024xi32>, tensor<i32>) -> tensor<512xi32> 2101 %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor<i32>, tensor<i32>) -> tensor<512xi32> 2102 func.return %0 : tensor<512xi32> 2103} 2104 2105//===----------------------------------------------------------------------===// 2106// Select op legalizations. 2107//===----------------------------------------------------------------------===// 2108 2109// ----- 2110 2111// CHECK-LABEL: func @select_batch_static 2112func.func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { 2113 // CHECK: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %{{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<3xindex>) -> tensor<2x6x8xi1> 2114 // CHECK: "mhlo.select"(%[[BCAST]], %arg1, %arg2) 2115 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> 2116 func.return %0: tensor<2x6x8xi32> 2117} 2118 2119// ----- 2120 2121// CHECK-LABEL: func @select_batch_static_r1 2122func.func @select_batch_static_r1(%arg0: tensor<i1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { 2123 // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) 2124 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> 2125 func.return %0: tensor<2x6x8xi32> 2126} 2127 2128// ----- 2129 2130// CHECK-LABEL: func @select_batch_static_all_same 2131func.func @select_batch_static_all_same(%arg0: tensor<2x6x8xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { 2132 // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) 2133 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x6x8xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> 2134 func.return %0: tensor<2x6x8xi32> 2135} 2136 2137// ----- 2138 2139// CHECK-LABEL: func @select_batch_dynamic_r1 2140func.func @select_batch_dynamic_r1(%arg0: tensor<?xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> { 2141 // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?xi1> -> tensor<1xindex> 2142 // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex> 2143 // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor<?x?x8xi32> -> tensor<3xindex> 2144 // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> 2145 // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index 2146 // CHECK-NEXT: %[[HEAD:.*]], %[[TAIL:.*]] = "shape.split_at"(%[[SHAPE1]], %[[C1]]) : (tensor<3xindex>, index) -> (tensor<1xindex>, tensor<2xindex>) 2147 // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<1xindex> 2148 // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] 2149 // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) { 2150 // CHECK-NEXT: %[[SHAPE1E:.*]] = shape.to_extent_tensor %[[SHAPE1]] : tensor<3xindex> -> tensor<3xindex> 2151 // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi1>, tensor<3xindex>) -> tensor<?x?x8xi1> 2152 // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCAST]], %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 2153 // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32> 2154 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 2155 func.return %0: tensor<?x?x8xi32> 2156} 2157 2158// ----- 2159 2160// CHECK-LABEL: func @select_batch_dynamic 2161func.func @select_batch_dynamic(%arg0: tensor<?x?x8xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> { 2162 // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?x8xi1> -> tensor<3xindex> 2163 // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex> 2164 // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor<?x?x8xi32> -> tensor<3xindex> 2165 // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> 2166 // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex> 2167 // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] 2168 // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) { 2169 // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 2170 // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32> 2171 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 2172 func.return %0: tensor<?x?x8xi32> 2173} 2174 2175// ----- 2176 2177// CHECK-LABEL: testSelectInvalidUnranked 2178func.func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { 2179 // CHECK-NEXT: tf.Select 2180 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> 2181 func.return %0: tensor<*xf16> 2182} 2183 2184// ----- 2185 2186// CHECK-LABEL: testSelectThenUnranked 2187func.func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { 2188 // CHECK-NEXT: tf.Select 2189 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> 2190 func.return %0: tensor<*xf16> 2191} 2192 2193// ----- 2194 2195// CHECK-LABEL: testSelectElseUnranked 2196func.func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { 2197 // CHECK-NEXT: tf.Select 2198 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> 2199 func.return %0: tensor<*xf16> 2200} 2201 2202// ----- 2203 2204// CHECK-LABEL: func @selectv2_dynamic_ranked 2205func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { 2206 // CHECK: chlo.broadcast_select 2207 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> 2208 func.return %0: tensor<2x?x8xi32> 2209} 2210 2211// ----- 2212 2213// CHECK-LABEL: func @selectv2_unranked 2214func.func @selectv2_unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> { 2215 // CHECK: chlo.broadcast_select 2216 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32> 2217 func.return %0: tensor<*xi32> 2218} 2219 2220//===----------------------------------------------------------------------===// 2221// Fast Fourier Transform op legalization. 2222//===----------------------------------------------------------------------===// 2223 2224// ----- 2225 2226// CHECK-LABEL: func @fft_1D 2227func.func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 2228 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<8xcomplex<f32>> 2229 %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> 2230 func.return %0 : tensor<8xcomplex<f32>> 2231} 2232 2233// ----- 2234 2235// CHECK-LABEL: func @ifft_1D 2236func.func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 2237 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type IFFT>} : (tensor<8xcomplex<f32>> 2238 %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> 2239 func.return %0 : tensor<8xcomplex<f32>> 2240} 2241 2242// ----- 2243 2244// CHECK-LABEL: func @rfft_1D 2245func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex<f32>> { 2246 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2247 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<8xf32> 2248 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>> 2249 func.return %0 : tensor<5xcomplex<f32>> 2250} 2251 2252// ----- 2253 2254// CHECK-LABEL: func @rfft_1D_padded 2255func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex<f32>> { 2256 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2257 // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32> 2258 // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<8xf32> 2259 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>> 2260 func.return %0 : tensor<5xcomplex<f32>> 2261} 2262 2263// ----- 2264 2265// CHECK-LABEL: func @rfft_1D_sliced 2266func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex<f32>> { 2267 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2268 // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32> 2269 // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<2x8xf32> 2270 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex<f32>> 2271 func.return %0 : tensor<2x5xcomplex<f32>> 2272} 2273 2274// ----- 2275 2276// CHECK-LABEL: func @irfft_1D 2277func.func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xf32> { 2278 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2279 // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex<f32>>) -> tensor<5xcomplex<f32>> 2280 // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type IRFFT>} : (tensor<5xcomplex<f32>> 2281 %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex<f32>>, tensor<1xi32>) -> tensor<8xf32> 2282 func.return %0 : tensor<8xf32> 2283} 2284 2285// ----- 2286 2287// CHECK-LABEL: fft_1D_dynamic 2288func.func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 2289 // CHECK: "tf.FFT" 2290 %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> 2291 func.return %0 : tensor<8xcomplex<f32>> 2292} 2293 2294// ----- 2295 2296// CHECK-LABEL: rfft_1D_dynamic 2297func.func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> { 2298 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2299 // CHECK: "tf.RFFT" 2300 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>> 2301 func.return %0 : tensor<8xcomplex<f32>> 2302} 2303 2304//===----------------------------------------------------------------------===// 2305// Shape op legalization. 2306//===----------------------------------------------------------------------===// 2307 2308// ----- 2309 2310// CHECK-LABEL: func @shape_1D 2311func.func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> { 2312 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 2313 // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<1xindex> to tensor<1xi32> 2314 %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> 2315 2316 // CHECK: return [[TENSOR]] 2317 func.return %0 : tensor<1xi32> 2318} 2319 2320// ----- 2321 2322// CHECK-LABEL: func @shape_2D 2323func.func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> { 2324 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 2325 // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<2xindex> to tensor<2xi32> 2326 %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32> 2327 2328 // CHECK: return [[TENSOR]] 2329 func.return %0 : tensor<2xi32> 2330} 2331 2332// ----- 2333 2334// CHECK-LABEL: func @shape_rankless 2335func.func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> { 2336 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 2337 // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<?xindex> to tensor<?xi32> 2338 %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32> 2339 2340 // CHECK: return [[TENSOR]] 2341 func.return %0 : tensor<?xi32> 2342} 2343 2344//===----------------------------------------------------------------------===// 2345// Transpose op legalization. 2346//===----------------------------------------------------------------------===// 2347 2348// ----- 2349 2350// CHECK-LABEL: @transpose_noop 2351func.func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { 2352 %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2353 // CHECK: return %arg0 2354 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32> 2355 func.return %0 : tensor<2x3xf32> 2356} 2357 2358// ----- 2359 2360// CHECK-LABEL: @transpose_2d 2361func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { 2362 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2363 // CHECK: "mhlo.transpose" 2364 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> 2365 func.return %0 : tensor<3x2xf32> 2366} 2367 2368// ----- 2369 2370// CHECK-LABEL: @transpose_3d_int32 2371func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { 2372 %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2373 // CHECK: "mhlo.transpose" 2374 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> 2375 func.return %0 : tensor<3x2x1xf32> 2376} 2377 2378// ----- 2379 2380// CHECK-LABEL: @transpose_3d 2381func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { 2382 %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) 2383 // CHECK: "mhlo.transpose" 2384 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> 2385 func.return %0 : tensor<3x2x1xf32> 2386} 2387 2388// ----- 2389 2390// CHECK-LABEL: @transpose_dynamic_2d 2391func.func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> { 2392 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2393 // CHECK: "mhlo.transpose" 2394 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32> 2395 func.return %0 : tensor<4x?xf32> 2396} 2397 2398// ----- 2399 2400// CHECK-LABEL: @transpose_unranked_2d 2401func.func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2402 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2403 // CHECK: "mhlo.transpose" 2404 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> 2405 func.return %0 : tensor<*xf32> 2406} 2407 2408 2409//===----------------------------------------------------------------------===// 2410// Unary op legalizations. 2411//===----------------------------------------------------------------------===// 2412 2413// ----- 2414 2415// CHECK-LABEL: @abs 2416func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2417 // CHECK: mhlo.abs %arg0 : tensor<2xf32> 2418 %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2419 func.return %0 : tensor<2xf32> 2420} 2421 2422// ----- 2423 2424// CHECK-LABEL: func @abs_dynamic 2425func.func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2426 // CHECK: mhlo.abs %arg0 : tensor<?xf32> 2427 %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2428 func.return %0 : tensor<?xf32> 2429} 2430 2431// ----- 2432 2433// CHECK-LABEL: func @abs_unranked 2434func.func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2435 // CHECK: mhlo.abs %arg0 : tensor<*xf32> 2436 %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2437 func.return %0 : tensor<*xf32> 2438} 2439 2440// ----- 2441 2442// CHECK-LABEL: @acos 2443// CHLO-LABEL: @acos 2444func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2445 // CHECK: chlo.acos %arg0 : tensor<2xf32> 2446// CHLO: %[[VAL_1:.*]] = mhlo.compare NE, {{.*}} 2447// CHLO: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> 2448// CHLO: %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00> 2449// CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 2450// CHLO: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]] 2451// CHLO: %[[VAL_7:.*]] = mhlo.sqrt %[[VAL_6]] 2452// CHLO: %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00> 2453// CHLO: %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0 2454// CHLO: %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]] 2455// CHLO: %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]] 2456// CHLO: %[[VAL_12:.*]] = mhlo.constant dense<3.14159274> 2457// CHLO: %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]]) 2458// CHLO: return %[[VAL_13]] : tensor<2xf32> 2459 %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2460 func.return %0 : tensor<2xf32> 2461} 2462 2463// ----- 2464 2465// CHECK-LABEL: @acos_complex 2466// CHLO-LABEL: @acos_complex 2467func.func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2468 // CHLO: tf.Acos 2469 %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2470 func.return %0 : tensor<2xcomplex<f32>> 2471} 2472 2473// ----- 2474 2475// CHECK-LABEL: @acos_dynamic 2476// CHLO-LABEL: @acos_dynamic 2477func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2478 // CHECK: chlo.acos %arg0 : tensor<*xf32> 2479 // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be 2480 // lowered further on ranked tensors. Unranked CHLO must be transformed to 2481 // ranked code before further lowering. 2482 // CHLO: "tf.Acos" 2483 %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2484 func.return %0 : tensor<*xf32> 2485} 2486 2487// ----- 2488 2489// CHECK-LABEL: @tan 2490// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> 2491// CHLO-LABEL: @tan 2492// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> 2493func.func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> { 2494 // CHECK: chlo.tan %[[ARG]] : tensor<2xf32> 2495 // CHLO: %[[SINE:.*]] = mhlo.sine %[[ARG]] 2496 // CHLO %[[COSINE:.*]] = mhlo.cosine %[[ARG]] 2497 // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) 2498 %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32> 2499 func.return %result : tensor<2xf32> 2500} 2501 2502// ----- 2503 2504// CHECK-LABEL: @tan_unranked 2505// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> 2506// CHLO-LABEL: @tan_unranked 2507// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> 2508func.func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> { 2509 // CHECK: chlo.tan %[[ARG]] : tensor<*xf32> 2510 // CHLO: %[[SINE:.*]] = mhlo.sine %[[ARG]] 2511 // CHLO %[[COSINE:.*]] = mhlo.cosine %[[ARG]] 2512 // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) 2513 %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32> 2514 func.return %result : tensor<*xf32> 2515} 2516 2517// ----- 2518 2519// CHECK-LABEL: func @cast_dynamic_i2f 2520func.func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> { 2521 // CHECK: mhlo.convert(%arg0) : (tensor<?xi32>) -> tensor<?xf32> 2522 %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32> 2523 func.return %0 : tensor<?xf32> 2524} 2525 2526// ----- 2527 2528// CHECK-LABEL: func @cast_i2f 2529func.func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> { 2530 // CHECK: mhlo.convert(%arg0) : (tensor<2xi32>) -> tensor<2xf32> 2531 %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> 2532 func.return %0 : tensor<2xf32> 2533} 2534 2535// ----- 2536 2537// CHECK-LABEL: func @cast_c2f 2538func.func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> { 2539 // CHECK: mhlo.convert(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> 2540 %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> 2541 func.return %0 : tensor<2xf32> 2542} 2543 2544// ----- 2545 2546// CHECK-LABEL: @ceil 2547func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2548 // CHECK: mhlo.ceil %arg0 : tensor<2xf32> 2549 %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2550 func.return %0 : tensor<2xf32> 2551} 2552 2553// ----- 2554 2555// CHECK-LABEL: func @ceil_dynamic 2556func.func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2557 // CHECK: mhlo.ceil %arg0 : tensor<?xf32> 2558 %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2559 func.return %0 : tensor<?xf32> 2560} 2561 2562// ----- 2563 2564// CHECK-LABEL: func @ceil_unranked 2565func.func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2566 // CHECK: mhlo.ceil %arg0 : tensor<*xf32> 2567 %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2568 func.return %0 : tensor<*xf32> 2569} 2570 2571// ----- 2572 2573// CHECK-LABEL: @complex_abs 2574func.func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> { 2575 // CHECK: mhlo.abs(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> 2576 %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> 2577 func.return %0 : tensor<2xf32> 2578} 2579 2580// ----- 2581 2582// CHECK-LABEL: @cos 2583func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2584 // CHECK: mhlo.cosine %arg0 : tensor<2xf32> 2585 %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2586 func.return %0 : tensor<2xf32> 2587} 2588 2589// ----- 2590 2591// CHECK-LABEL: func @cos_dynamic 2592func.func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2593 // CHECK: mhlo.cosine %arg0 : tensor<?xf32> 2594 %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2595 func.return %0 : tensor<?xf32> 2596} 2597 2598// ----- 2599 2600// CHECK-LABEL: func @cos_unranked 2601func.func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2602 // CHECK: mhlo.cosine %arg0 : tensor<*xf32> 2603 %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2604 func.return %0 : tensor<*xf32> 2605} 2606 2607// ----- 2608 2609// CHECK-LABEL: @exp 2610func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2611 // CHECK: mhlo.exponential %arg0 : tensor<2xf32> 2612 %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2613 func.return %0 : tensor<2xf32> 2614} 2615 2616// ----- 2617 2618// CHECK-LABEL: @expm1 2619func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2620 // CHECK: mhlo.exponential_minus_one %arg0 : tensor<2xf32> 2621 %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2622 func.return %0 : tensor<2xf32> 2623} 2624 2625// ----- 2626 2627// CHECK-LABEL: func @exp_dynamic 2628func.func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2629 // CHECK: mhlo.exponential %arg0 : tensor<?xf32> 2630 %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2631 func.return %0 : tensor<?xf32> 2632} 2633 2634// ----- 2635 2636// CHECK-LABEL: func @exp_unranked 2637func.func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2638 // CHECK: mhlo.exponential %arg0 : tensor<*xf32> 2639 %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2640 func.return %0 : tensor<*xf32> 2641} 2642 2643// ----- 2644 2645// CHECK-LABEL: @floor 2646func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2647 // CHECK: mhlo.floor %arg0 : tensor<2xf32> 2648 %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2649 func.return %0 : tensor<2xf32> 2650} 2651 2652// ----- 2653 2654// CHECK-LABEL: func @floor_dynamic 2655func.func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2656 // CHECK: mhlo.floor %arg0 : tensor<?xf32> 2657 %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2658 func.return %0 : tensor<?xf32> 2659} 2660 2661// ----- 2662 2663// CHECK-LABEL: func @floor_unranked 2664func.func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2665 // CHECK: mhlo.floor %arg0 : tensor<*xf32> 2666 %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2667 func.return %0 : tensor<*xf32> 2668} 2669 2670// ----- 2671 2672// CHECK-LABEL: func @invert_op_unranked 2673func.func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { 2674 // CHECK: mhlo.not %arg0 : tensor<*xi32> 2675 %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> 2676 func.return %0 : tensor<*xi32> 2677} 2678 2679// ----- 2680 2681// CHECK-LABEL: @is_finite 2682func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { 2683 // CHECK: mhlo.is_finite(%arg0) : (tensor<2xf32>) -> tensor<2xi1> 2684 %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> 2685 func.return %0 : tensor<2xi1> 2686} 2687 2688// ----- 2689 2690// CHECK-LABEL: func @is_finite_dynamic 2691func.func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> { 2692 // CHECK: mhlo.is_finite(%arg0) : (tensor<?xf32>) -> tensor<?xi1> 2693 %0 = "tf.IsFinite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1> 2694 func.return %0 : tensor<?xi1> 2695} 2696 2697// ----- 2698 2699// CHECK-LABEL: func @is_finite_unranked 2700func.func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { 2701 // CHECK: mhlo.is_finite(%arg0) : (tensor<*xf32>) -> tensor<*xi1> 2702 %0 = "tf.IsFinite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> 2703 func.return %0 : tensor<*xi1> 2704} 2705 2706// ----- 2707 2708// CHECK-LABEL: @log 2709func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2710 // CHECK: mhlo.log %arg0 : tensor<2xf32> 2711 %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2712 func.return %0 : tensor<2xf32> 2713} 2714 2715// ----- 2716 2717// CHECK-LABEL: func @log_dynamic 2718func.func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2719 // CHECK: mhlo.log %arg0 : tensor<?xf32> 2720 %0 = "tf.Log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2721 func.return %0 : tensor<?xf32> 2722} 2723 2724// ----- 2725 2726// CHECK-LABEL: func @log_unranked 2727func.func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2728 // CHECK: mhlo.log %arg0 : tensor<*xf32> 2729 %0 = "tf.Log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2730 func.return %0 : tensor<*xf32> 2731} 2732 2733// ----- 2734 2735// CHECK-LABEL: @log1p 2736func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2737 // CHECK: mhlo.log_plus_one %arg0 : tensor<2xf32> 2738 %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2739 func.return %0 : tensor<2xf32> 2740} 2741 2742// ----- 2743 2744// CHECK-LABEL: func @log1p_dynamic 2745func.func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2746 // CHECK: mhlo.log_plus_one %arg0 : tensor<?xf32> 2747 %0 = "tf.Log1p"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2748 func.return %0 : tensor<?xf32> 2749} 2750 2751// ----- 2752 2753// CHECK-LABEL: func @log1p_unranked 2754func.func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2755 // CHECK: mhlo.log_plus_one %arg0 : tensor<*xf32> 2756 %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2757 func.return %0 : tensor<*xf32> 2758} 2759 2760// ----- 2761 2762// CHECK-LABEL: func @not_op_unranked 2763func.func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { 2764 // CHECK: mhlo.not %arg0 : tensor<*xi1> 2765 %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> 2766 func.return %0 : tensor<*xi1> 2767} 2768 2769// ----- 2770 2771// CHECK-LABEL: @neg 2772func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2773 // CHECK: mhlo.negate %arg0 : tensor<2xf32> 2774 %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2775 func.return %0 : tensor<2xf32> 2776} 2777 2778// ----- 2779 2780// CHECK-LABEL: func @neg_dynamic 2781func.func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2782 // CHECK: mhlo.negate %arg0 : tensor<?xf32> 2783 %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2784 func.return %0 : tensor<?xf32> 2785} 2786 2787// ----- 2788 2789// CHECK-LABEL: func @neg_unranked 2790func.func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2791 // CHECK: mhlo.negate %arg0 : tensor<*xf32> 2792 %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2793 func.return %0 : tensor<*xf32> 2794} 2795 2796// ----- 2797 2798// CHECK-LABEL: @sigmoid 2799func.func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2800 // CHECK: mhlo.logistic 2801 %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2802 func.return %0 : tensor<2xf32> 2803} 2804 2805// ----- 2806 2807// CHECK-LABEL: @sigmoid_complex 2808func.func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2809 // CHECK: mhlo.logistic 2810 %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2811 func.return %0 : tensor<2xcomplex<f32>> 2812} 2813 2814// ----- 2815 2816// CHECK-LABEL: @sigmoid_unranked 2817func.func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2818 // CHECK: mhlo.logistic 2819 %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2820 func.return %0 : tensor<*xf32> 2821} 2822 2823 2824// CHECK-LABEL: @sigmoid_grad 2825func.func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { 2826 // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32> 2827 // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32> 2828 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32> 2829 // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32> 2830 // CHECK: return [[MUL1]] 2831 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> 2832 func.return %0 : tensor<2xf32> 2833} 2834 2835// ----- 2836 2837// CHECK-LABEL: @sigmoid_grad_complex 2838func.func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2839 // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>> 2840 // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>> 2841 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>> 2842 // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>> 2843 // CHECK: return [[MUL1]] 2844 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2845 func.return %0 : tensor<2xcomplex<f32>> 2846} 2847 2848// ----- 2849 2850// CHECK-LABEL: @sigmoid_grad_dynamic 2851func.func @sigmoid_grad_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> { 2852 // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 2853 // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32> 2854 // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 2855 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 2856 func.return %0 : tensor<?xf32> 2857} 2858 2859// ----- 2860 2861// CHECK-LABEL: @sin 2862func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2863 // CHECK: mhlo.sine %arg0 : tensor<2xf32> 2864 %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2865 func.return %0 : tensor<2xf32> 2866} 2867 2868// ----- 2869 2870// CHECK-LABEL: func @sin_dynamic 2871func.func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2872 // CHECK: mhlo.sine %arg0 : tensor<?xf32> 2873 %0 = "tf.Sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2874 func.return %0 : tensor<?xf32> 2875} 2876 2877// ----- 2878 2879// CHECK-LABEL: func @sin_unranked 2880func.func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2881 // CHECK: mhlo.sine %arg0 : tensor<*xf32> 2882 %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2883 func.return %0 : tensor<*xf32> 2884} 2885 2886// ----- 2887 2888// CHECK-LABEL: func @rsqrt 2889func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2890 // CHECK: mhlo.rsqrt %arg0 : tensor<2xf32> 2891 %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2892 func.return %0 : tensor<2xf32> 2893} 2894 2895// ----- 2896 2897// CHECK-LABEL: func @rsqrt_dynamic 2898func.func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2899 // CHECK: mhlo.rsqrt %arg0 : tensor<?xf32> 2900 %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2901 func.return %0 : tensor<?xf32> 2902} 2903 2904// ----- 2905 2906// CHECK-LABEL: func @rsqrt_unranked 2907func.func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2908 // CHECK: mhlo.rsqrt %arg0 : tensor<*xf32> 2909 %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2910 func.return %0 : tensor<*xf32> 2911} 2912 2913// ----- 2914 2915// CHECK-LABEL: func @sqrt 2916func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2917 // CHECK: mhlo.sqrt %arg0 : tensor<2xf32> 2918 %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2919 func.return %0 : tensor<2xf32> 2920} 2921 2922// ----- 2923 2924// CHECK-LABEL: func @sqrt_dynamic 2925func.func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2926 // CHECK: mhlo.sqrt %arg0 : tensor<?xf32> 2927 %0 = "tf.Sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2928 func.return %0 : tensor<?xf32> 2929} 2930 2931// ----- 2932 2933// CHECK-LABEL: func @sqrt_unranked 2934func.func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2935 // CHECK: mhlo.sqrt %arg0 : tensor<*xf32> 2936 %0 = "tf.Sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2937 func.return %0 : tensor<*xf32> 2938} 2939 2940// ----- 2941 2942// CHECK-LABEL: func @tanh 2943func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2944 // CHECK: mhlo.tanh %arg0 : tensor<2xf32> 2945 %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2946 func.return %0 : tensor<2xf32> 2947} 2948 2949// ----- 2950 2951// CHECK-LABEL: func @tanh_dynamic 2952func.func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2953 // CHECK: mhlo.tanh %arg0 : tensor<?xf32> 2954 %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2955 func.return %0 : tensor<?xf32> 2956} 2957 2958// ----- 2959 2960// CHECK-LABEL: func @tanh_unranked 2961func.func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2962 // CHECK: mhlo.tanh %arg0 : tensor<*xf32> 2963 %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2964 func.return %0 : tensor<*xf32> 2965} 2966 2967// ----- 2968 2969// CHECK-LABEL: func @bitcast 2970func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2971 // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xf32> 2972 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2973 func.return %0 : tensor<2xf32> 2974} 2975 2976// ----- 2977 2978// CHECK-LABEL: func @bitcast_dynamic 2979func.func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2980 // CHECK: mhlo.bitcast_convert %arg0 : (tensor<?xf32>) -> tensor<?xf32> 2981 %0 = "tf.Bitcast"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2982 func.return %0 : tensor<?xf32> 2983} 2984 2985// ----- 2986 2987// CHECK-LABEL: func @bitcast_unranked 2988func.func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2989 // CHECK: mhlo.bitcast_convert %arg0 : (tensor<*xf32>) -> tensor<*xf32> 2990 %0 = "tf.Bitcast"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2991 func.return %0 : tensor<*xf32> 2992} 2993 2994// ----- 2995 2996// CHECK-LABEL: func @bitcast_same_widths 2997func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { 2998 // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xi32> 2999 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> 3000 func.return %0 : tensor<2xi32> 3001} 3002 3003// ----- 3004 3005// CHECK-LABEL: func @bitcast_smaller_input_width 3006func.func @bitcast_smaller_input_width(%arg0: tensor<8xi8>) -> tensor<i64> { 3007 // CHECK: mhlo.bitcast_convert %arg0 : (tensor<8xi8>) -> tensor<i64> 3008 %0 = "tf.Bitcast"(%arg0) : (tensor<8xi8>) -> tensor<i64> 3009 func.return %0 : tensor<i64> 3010} 3011 3012// ----- 3013 3014// CHECK-LABEL: func @bitcast_smaller_output_width 3015func.func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2x2xf16> { 3016 // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2x2xf16> 3017 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf16> 3018 func.return %0 : tensor<2x2xf16> 3019} 3020 3021// ----- 3022 3023// CHECK-LABEL: reshape 3024func.func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { 3025 // CHECK: mhlo.reshape 3026 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32> 3027 func.return %0 : tensor<2x1xf32> 3028} 3029 3030// ----- 3031 3032// CHECK-LABEL: not_lowering_reshape 3033func.func @not_lowering_reshape(%arg0: tensor<!tf_type.string>, %arg1: tensor<1xi32>) -> tensor<1x!tf_type.string> { 3034 // CHECK: "tf.Reshape" 3035 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<!tf_type.string>, tensor<1xi32>) -> tensor<1x!tf_type.string> 3036 func.return %0 : tensor<1x!tf_type.string> 3037} 3038 3039// ----- 3040 3041// CHECK-LABEL: reshape_dynamic 3042func.func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> { 3043 // CHECK: "chlo.dynamic_reshape" 3044 // CHLO: mhlo.compute_reshape_shape 3045 // CHLO: mhlo.dynamic_reshape 3046 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32> 3047 func.return %0 : tensor<?x?xf32> 3048} 3049 3050// ----- 3051 3052// CHECK-LABEL: reshape_unranked 3053// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> 3054// CHECK-SAME: %[[TARGET_SHAPE:.*]]: tensor<2xi32> 3055func.func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> { 3056 // CHECK: "chlo.dynamic_reshape" 3057 // CHLO: shape.shape_of 3058 // CHLO: shape.num_elements 3059 // CHLO: mhlo.cstr_reshapable 3060 // CHLO: assuming{{.*}}{ 3061 // CHLO: mhlo.compute_reshape_shape 3062 // CHLO: mhlo.dynamic_reshape 3063 // CHLO: } 3064 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32> 3065 func.return %0 : tensor<?x?xf32> 3066} 3067 3068// ----- 3069 3070// CHECK-LABEL: squeeze 3071func.func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { 3072 // CHECK: mhlo.reshape 3073 %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> 3074 func.return %0 : tensor<1x10xf32> 3075} 3076 3077// ----- 3078 3079// CHECK-LABEL: squeeze_ranked 3080func.func @squeeze_ranked(%arg0: tensor<?x?x?xf32>) -> tensor<?xf32> { 3081 // CHECK: %[[C2:.*]] = arith.constant 2 : index 3082 // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?xf32> 3083 // CHECK: %[[T:.*]] = tensor.from_elements %[[D2]] : tensor<1xindex> 3084 // CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor<?x?x?xf32>, tensor<1xindex>) -> tensor<?xf32> 3085 // CHECK: return %[[R]] : tensor<?xf32> 3086 %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [0, 1] }: (tensor<?x?x?xf32>) -> tensor<?xf32> 3087 func.return %0 : tensor<?xf32> 3088} 3089 3090// ----- 3091 3092// CHECK-LABEL: squeeze_ranked_negative 3093func.func @squeeze_ranked_negative(%arg0: tensor<?x?x10xf32>) -> tensor<?x10xf32> { 3094 // CHECK: %[[C0:.*]] = arith.constant 0 : index 3095 // CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x10xf32> 3096 // CHECK: %[[C2:.*]] = arith.constant 2 : index 3097 // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x10xf32> 3098 // CHECK: %[[T:.*]] = tensor.from_elements %[[D0]], %[[D2]] : tensor<2xindex> 3099 // CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor<?x?x10xf32>, tensor<2xindex>) -> tensor<?x10xf32> 3100 // CHECK: return %[[R]] : tensor<?x10xf32> 3101 %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [-2] }: (tensor<?x?x10xf32>) -> tensor<?x10xf32> 3102 func.return %0 : tensor<?x10xf32> 3103} 3104 3105// ----- 3106 3107// CHECK-LABEL: squeeze_ranked_dynamic 3108func.func @squeeze_ranked_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?xf32> { 3109 // CHECK: "tf.Squeeze" 3110 %0 = "tf.Squeeze"(%arg0) : (tensor<?x?xf32>) -> tensor<?xf32> 3111 func.return %0 : tensor<?xf32> 3112} 3113 3114// ----- 3115 3116// CHECK-LABEL: squeeze_dynamic 3117func.func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> { 3118 // CHECK: "tf.Squeeze" 3119 %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> 3120 func.return %0 : tensor<*xf32> 3121} 3122 3123// ----- 3124 3125// CHECK-LABEL: expand_dims 3126func.func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> { 3127 // CHECK: mhlo.reshape 3128 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32> 3129 func.return %0 : tensor<1x2xf32> 3130} 3131 3132// ----- 3133 3134// CHECK-LABEL: expand_dims_dynamic 3135func.func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> { 3136 %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> (tensor<i32>) 3137 3138 // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 3139 // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 3140 // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 3141 // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] 3142 // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 3143 // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] 3144 // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]] 3145 // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] 3146 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?xf32>, tensor<i32>) -> tensor<?x1x?xf32> 3147 3148 // CHECK: return %[[RESHAPE]] 3149 func.return %0 : tensor<?x1x?xf32> 3150} 3151 3152// ----- 3153 3154// CHECK-LABEL: expand_dynamic_dims_rank1_axis 3155func.func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> { 3156 %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3157 3158 // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 3159 // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 3160 // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 3161 // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] 3162 // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 3163 // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] 3164 // CHECK-DAG: %[[CST2:.+]] = arith.constant 2 3165 // CHECK-DAG: %[[GETEXTENT2:.+]] = tensor.extract %[[SHAPEOF]][%[[CST2]]] 3166 // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]], %[[GETEXTENT2]] 3167 // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] 3168 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32> 3169 3170 // CHECK: return %[[RESHAPE]] 3171 func.return %0 : tensor<?x1x?x4xf32> 3172} 3173 3174// ----- 3175 3176// CHECK-LABEL: func @sign 3177// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> 3178func.func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { 3179 // CHECK: [[SIGN:%.*]] = mhlo.sign [[ARG]] 3180 // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32> 3181 %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) 3182 func.return %0 : tensor<1x2x3x4xf32> 3183} 3184 3185// ----- 3186 3187// CHECK-LABEL: func @sign_dynamic 3188func.func @sign_dynamic(%arg0: tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32> { 3189 // CHECK: [[SIGN:%.*]] = mhlo.sign %arg0 : tensor<?x2x3x?xf32> 3190 // CHECK: return [[SIGN]] : tensor<?x2x3x?xf32> 3191 %0 = "tf.Sign"(%arg0) : (tensor<?x2x3x?xf32>) -> (tensor<?x2x3x?xf32>) 3192 func.return %0 : tensor<?x2x3x?xf32> 3193} 3194 3195// ----- 3196 3197// CHECK-LABEL: slice_constant_start 3198func.func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { 3199 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i64> 3200 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 3201 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 3202 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : 3203 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> 3204 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64> 3205 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) 3206 // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : 3207 // CHECK-DAG-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32> 3208 // CHECK: return %[[RESULT]] : tensor<2xi32> 3209 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 3210 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) 3211 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> 3212 func.return %0 : tensor<2xi32> 3213} 3214 3215// ----- 3216 3217// CHECK-LABEL: slice_i32_consts 3218func.func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { 3219 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i32> 3220 // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i32>) -> tensor<2xi32> 3221 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 3222 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) 3223 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> 3224 func.return %0 : tensor<2xi32> 3225} 3226 3227// ----- 3228 3229// CHECK-LABEL: slice_constant_start_negative_one_size 3230func.func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { 3231 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i64> 3232 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32> 3233 // CHECK: return %[[RESULT]] : tensor<3xi32> 3234 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 3235 %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 3236 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> 3237 func.return %0 : tensor<3xi32> 3238} 3239 3240// ----- 3241 3242// CHECK-LABEL: slice_constant_start_dynamic_shape 3243func.func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 3244 // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor<i64> 3245 // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor<i64> 3246 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" 3247 // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) 3248 // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : 3249 // CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 3250 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 3251 %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 3252 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 3253 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 3254 func.return %0 : tensor<1x4xi32> 3255} 3256 3257// ----- 3258 3259// CHECK-LABEL: slice_variable_start 3260func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 3261 // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) 3262 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 3263 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 3264 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 3265 // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor<i64> 3266 // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) 3267 // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, 3268 // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, 3269 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 3270 // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor<i64> 3271 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 3272 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 3273 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 3274 %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 3275 func.return %0 : tensor<1x4xi32> 3276} 3277 3278// ----- 3279 3280// CHECK-LABEL: slice_mhlo_sizes 3281func.func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { 3282 // CHECK-NOT: "tf.Slice" 3283 %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> 3284 %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> 3285 func.return %1 : tensor<1x512x4xf32> 3286} 3287 3288// ----- 3289 3290// CHECK-LABEL: slice_variable_start_negative_one_size 3291func.func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 3292 // CHECK: %[[RESULT:.*]] = "tf.Slice" 3293 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 3294 %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>) 3295 %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 3296 func.return %0 : tensor<1x4xi32> 3297} 3298 3299// ----- 3300 3301// CHECK-LABEL: slice_real_dynamic_slice 3302func.func @slice_real_dynamic_slice(%arg0: tensor<4xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<*xi32> { 3303 // CHECK: tensor.extract {{.*}} : tensor<1xi64> 3304 // CHECK: tensor.extract {{.*}} : tensor<1xi64> 3305 // CHECK: arith.index_cast {{.*}} : index to i64 3306 // CHECK: arith.cmpi eq, {{.*}} : i64 3307 // CHECK: arith.addi {{.*}} : i64 3308 // CHECK: tensor.dim {{.*}} : tensor<4xi32> 3309 // CHECK: arith.index_cast {{.*}} : index to i64 3310 // CHECK: select {{.*}} : i64 3311 // CHECK: arith.index_cast {{.*}} : i64 to index 3312 // CHECK: arith.index_cast {{.*}} : i64 to index 3313 // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> 3314 // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> 3315 // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> 3316 %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xi32> 3317 func.return %0 : tensor<*xi32> 3318} 3319 3320//===----------------------------------------------------------------------===// 3321// StridedSlice op legalizations. 3322//===----------------------------------------------------------------------===// 3323 3324// ----- 3325 3326// CHECK-LABEL: simple_strided_slice 3327func.func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { 3328 %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3329 %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3330 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3331 3332 // CHECK: mhlo.slice 3333 // CHECK-DAG-SAME: start_indices = dense<[0, 1]> 3334 // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> 3335 // CHECK-DAG-SAME: strides = dense<[1, 3]> 3336 // CHECK-SAME: -> tensor<3x2xf32> 3337 3338 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 3339 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> 3340 func.return %output : tensor<3x2xf32> 3341} 3342 3343// ----- 3344 3345// CHECK-LABEL: dynamic_strided_slice 3346func.func @dynamic_strided_slice(%input: tensor<?x8xf32>) -> tensor<?x2xf32> { 3347 %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3348 %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3349 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3350 3351 // CHECK: "tf.StridedSlice" 3352 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 3353 : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32> 3354 func.return %output : tensor<?x2xf32> 3355} 3356 3357// ----- 3358 3359// CHECK-LABEL: strided_slice_negative_indices 3360func.func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { 3361 %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3362 %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3363 %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3364 3365 // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} 3366 3367 // CHECK: mhlo.slice 3368 // CHECK-DAG-SAME: start_indices = dense<[0, 1]> 3369 // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> 3370 // CHECK-DAG-SAME: strides = dense<[1, 3]> 3371 // CHECK-SAME: -> tensor<3x2xf32> 3372 3373 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 3374 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> 3375 func.return %output : tensor<3x2xf32> 3376} 3377 3378// ----- 3379 3380// CHECK-LABEL: dynamic_strided_slice_negative_indices 3381func.func @dynamic_strided_slice_negative_indices(%input: tensor<?x8xf32>) -> tensor<?x2xf32> { 3382 %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3383 %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3384 %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3385 3386 // CHECK: tf.StridedSlice 3387 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 3388 : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32> 3389 func.return %output : tensor<?x2xf32> 3390} 3391 3392// ----- 3393 3394// CHECK-LABEL: strided_slice_range_clamping 3395func.func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> { 3396 %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3397 %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3398 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 3399 3400 // CHECK: mhlo.slice 3401 // CHECK-DAG-SAME: start_indices = dense<[0, 0]> 3402 // CHECK-DAG-SAME: limit_indices = dense<[1, 8]> 3403 // CHECK-DAG-SAME: strides = dense<[1, 3]> 3404 // CHECK-SAME: -> tensor<1x3xf32> 3405 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 3406 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32> 3407 func.return %output : tensor<1x3xf32> 3408} 3409 3410// ----- 3411 3412// CHECK-LABEL: strided_slice_empty 3413func.func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> { 3414 %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>) 3415 %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 3416 %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 3417 3418 // CHECK: mhlo.constant dense<> : tensor<0xf32> 3419 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 3420 : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32> 3421 func.return %output : tensor<0xf32> 3422} 3423 3424// ----- 3425 3426// CHECK-LABEL: strided_slice_begin_end_mask 3427// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32> 3428func.func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { 3429 3430 // For StridedSlice 3431 // Dim #: 0, 1, 2 3432 // Input shape: [4, 128, 1024] 3433 // Begin: 1, 4, -3 3434 // End: 8, 65, 42 3435 // Stride: 1, 4, -1 3436 // Begin mask: 0, 0, 1 (= 1) 3437 // End mask: 1, 0, 0 (= 4) 3438 3439 // So result shape: 3440 // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 3441 // Dim #1: 4 to 65 stride 4: so 16 3442 // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 3443 // result shape: [4, 16, 1022] 3444 3445 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) 3446 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) 3447 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) 3448 3449 // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]]) 3450 3451 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]]) 3452 // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]> 3453 // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]> 3454 // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> 3455 // CHECK-SAME: -> tensor<4x16x1022xf32> 3456 3457 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32> 3458 3459 // CHECK: mhlo.reshape %[[SLICE]] 3460 // CHECK-SAME: -> tensor<4x16x1022xf32> 3461 3462 func.return 3463} 3464 3465// ----- 3466 3467// CHECK-LABEL: strided_slice_shrink_axis_mask 3468// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32> 3469func.func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) { 3470 3471 // For StridedSlice 3472 // Dim #: 0, 1, 2 3473 // Input shape: [4, 128, 1024] 3474 // Begin: 1, 4, -3 3475 // End: 8, 65, 42 3476 // Stride: 1, 4, -1 3477 // Begin mask: 1, 0, 0 (= 1) 3478 // End mask: 0, 0, 1 (= 4) 3479 // Shrink axis mask: 1, 0, 1 (= 5) 3480 3481 // So result shape: 3482 // Dim #0: shrink axis, take value at [1] 3483 // Dim #1: 4 to 65 stride 4: so 16 3484 // Dim #2: shrink axis, take value at [-3] 3485 // result shape: [16] 3486 3487 // As output shape of StridedSlice differs, a reshape will follow. 3488 3489 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) 3490 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) 3491 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) 3492 3493 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) 3494 // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]> 3495 // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]> 3496 // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> 3497 // CHECK-SAME: -> tensor<1x16x1xf32> 3498 3499 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32> 3500 3501 // CHECK: mhlo.reshape %[[SLICE]] 3502 // CHECK-SAME: -> tensor<16xf32> 3503 3504 func.return 3505} 3506 3507// ----- 3508 3509// CHECK-LABEL: strided_slice_ellipsis_mask 3510// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> 3511func.func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) { 3512 // For StridedSlice input[1, ..., 8:, :10, 2:6:2] 3513 // The ellipsis mask is applied to dim #1, #2, i.e, we get canonicalized 3514 // slice input[1, :, :, 8:, :10, 2:6:2] 3515 3516 // The start, limit indices and strides attributes of mhlo.slice would 3517 // reflect the canonicalized slice. 3518 // As output shape of StridedSlice differs, a reshape will follow. 3519 3520 %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) 3521 %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>) 3522 %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) 3523 3524 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) 3525 // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> 3526 // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> 3527 // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> 3528 // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> 3529 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32> 3530 3531 // CHECK: mhlo.reshape %[[SLICE]] 3532 // CHECK-SAME: -> tensor<4x8x8x10x2xf32> 3533 3534 func.return 3535} 3536 3537// ----- 3538 3539// CHECK-LABEL: strided_slice_new_axis_mask 3540// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> 3541func.func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { 3542 // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] 3543 // New axis mask is at index 1 and 6 of sparse spec, so 3544 // new_axis_mask = 2^1 + 2^6 = 66 3545 // The ellipsis mask is applied to dim #1, #2 of input i.e, we get 3546 // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] 3547 // This is then reshaped to add the new axes. 3548 3549 // The start, limit indices and strides attributes of mhlo.slice would 3550 // reflect the canonicalized slice. 3551 // As output shape of StridedSlice differs, a reshape will follow to reflect 3552 // new axes added. 3553 3554 %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) 3555 %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) 3556 %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) 3557 3558 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) 3559 // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> 3560 // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> 3561 // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> 3562 // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> 3563 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32> 3564 3565 // CHECK: mhlo.reshape %[[SLICE]] 3566 // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32> 3567 3568 func.return 3569} 3570 3571// ----- 3572 3573// CHECK-LABEL: strided_slice_implicit_ellipsis_mask( 3574// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> 3575func.func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { 3576 // StridedSlice gets input[8:10], which is same as input[8:10, ...] 3577 // The start_indices, limit_indices, and strides attribute of mhlo.slice 3578 // reflect the canonicalized slice. 3579 %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> 3580 %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> 3581 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3582 // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]]) 3583 // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> 3584 // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> 3585 // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> 3586 // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[SLICE]] : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> 3587 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> 3588 // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> 3589 func.return %0 : tensor<2x16x2xf32> 3590} 3591 3592// ----- 3593 3594// CHECK-LABEL: strided_slice_nonconstant_begin_end 3595func.func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) { 3596 // In this case, the `begin` and `end` inputs are unknown at compile time -- 3597 // so the StridedSlice needs to slice these vectors and use that as input to 3598 // an HLO dynamic slice. 3599 %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 3600 %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 3601 %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3602 %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32> 3603 %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 3604 // CHECK: %[[A:.*]] = mhlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32> 3605 // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]]) 3606 // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> 3607 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 3608 // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) 3609 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 3610 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 3611 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> 3612 // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor<i32> 3613 // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] 3614 // CHECK-DAG-SAME: {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1> 3615 // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32> 3616 // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32> 3617 // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : 3618 // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3619 // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" 3620 // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) 3621 // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : 3622 // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x1x97xi32> 3623 // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> 3624 %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3625 // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> 3626 func.return %result : tensor<1x97xi32> 3627} 3628 3629// ----- 3630 3631// CHECK-LABEL: strided_slice_nonconstant_begin_end_with_start_end_mask 3632// CHECK-SAME: (%[[INPUT:.*]]: tensor<32x1x97xi32>, %[[BEGIN:.*]]: tensor<3xi32>, %[[END:.*]]: tensor<3xi32>) 3633func.func @strided_slice_nonconstant_begin_end_with_start_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<3xi32>, %end: tensor<3xi32>) -> (tensor<1x97xi32>) { 3634 %strides = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> 3635 3636 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 3637 // CHECK: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) 3638 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> 3639 // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> 3640 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> 3641 // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor<i32> 3642 // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] 3643 // CHECK-DAG-SAME: {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1> 3644 // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32> 3645 // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32> 3646 // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : 3647 // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3648 // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" 3649 // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) 3650 // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : 3651 // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x1x97xi32> 3652 // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> 3653 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x97xi32> 3654 func.return %result : tensor<1x97xi32> 3655} 3656 3657// ----- 3658 3659// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 3660func.func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { 3661 // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, 3662 // `strides` must be known. 3663 // CHECK: tf.StridedSlice 3664 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3665 func.return %result : tensor<1x97xi32> 3666} 3667 3668// ----- 3669 3670// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 3671func.func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3672 // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown 3673 // at compile time, `strides` must be known to have all 1 values. 3674 %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> 3675 // CHECK: tf.StridedSlice 3676 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3677 func.return %result : tensor<1x97xi32> 3678} 3679 3680// ----- 3681 3682// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count 3683func.func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { 3684 %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> 3685 // When begin/end are dynamic, the number of output elements must be equal to 3686 // the number of input elements sliced. 3687 // CHECK: tf.StridedSlice 3688 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> 3689 func.return %0 : tensor<6x10xf32> 3690} 3691 3692// ----- 3693 3694// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask 3695func.func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3696 // New axis mask: When `begin` and `end` inputs are unknown at compile time, 3697 // we can't support a new_axis mask. 3698 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3699 // CHECK: tf.StridedSlice 3700 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3701 func.return %result : tensor<1x97xi32> 3702} 3703 3704// ----- 3705 3706// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask 3707func.func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3708 // This ellipsis mask is not supported because it does not refer to the last 3709 // dimension. 3710 // [0, 1, 0] = 2 3711 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3712 // CHECK: tf.StridedSlice 3713 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3714 func.return %result : tensor<1x97xi32> 3715} 3716 3717// ----- 3718 3719// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask 3720func.func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3721 // This ellipsis mask is supported because it refers to the last dimension. 3722 // [1, 0, 0] = 4 3723 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3724 // CHECK: mhlo.dynamic_slice 3725 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3726 func.return %result : tensor<1x97xi32> 3727} 3728 3729// ----- 3730 3731// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask 3732func.func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3733 // This shrink_axis mask is supported because it refers to a major dimension. 3734 // [1, 1, 1] = 7 3735 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3736 // CHECK: mhlo.dynamic_slice 3737 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3738 func.return %result : tensor<1x97xi32> 3739} 3740 3741// ----- 3742 3743// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask 3744func.func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3745 // This shrink_axis mask is unsupported because it does not refer to a major 3746 // dimension. 3747 // [0, 1, 0] = 2 3748 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3749 // CHECK: tf.StridedSlice 3750 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3751 func.return %result : tensor<1x97xi32> 3752} 3753 3754 3755//===----------------------------------------------------------------------===// 3756// Reduction op legalizations. 3757//===----------------------------------------------------------------------===// 3758 3759// ----- 3760 3761// CHECK-LABEL: func @mean 3762func.func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3763 // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> 3764 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 3765 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> 3766 // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 3767 // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> 3768 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> 3769 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3770 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3771 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3772 func.return %0 : tensor<4x1xf16> 3773} 3774 3775// ----- 3776 3777// CHECK-LABEL: func @mean_scalar_dim 3778func.func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3779 // Verify that tf.Mean op with scalar attributes are lowered successfully. 3780 3781 // CHECK-NOT: tf.Mean 3782 %dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64> 3783 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16> 3784 func.return %0 : tensor<4x1xf16> 3785} 3786 3787// ----- 3788 3789// CHECK-LABEL: func @mean_dynamic 3790func.func @mean_dynamic(%arg0: tensor<?x?xf16>) -> tensor<?x1xf16> { 3791 // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<?x?xf16>) -> tensor<?x?xf32> 3792 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 3793 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> 3794 // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?xf16> -> tensor<2xindex> 3795 // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index 3796 // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index 3797 // CHECK: %[[REDUCED_DIM:.*]] = tensor.extract %[[SHAPE0]][%[[C1_2]]] : tensor<2xindex> 3798 // CHECK: %[[MUL:.*]] = arith.muli %[[C1_1]], %[[REDUCED_DIM]] : index 3799 // CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[MUL]] : index to i64 3800 // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor<i64> 3801 // CHECK: %[[CONVERT:.*]] = mhlo.convert(%[[TENSOR]]) : (tensor<i64>) -> tensor<f32> 3802 // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32> 3803 // CHECK: %[[MEAN_CONVERTED:.*]] = mhlo.convert(%[[MEAN]]) : (tensor<?xf32>) -> tensor<?xf16> 3804 // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor<?xf16> -> tensor<1xindex> 3805 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index 3806 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index 3807 // CHECK: %[[UNREDUCED_DIM:.*]] = tensor.extract %[[SHAPE1]][%[[C0]]] : tensor<1xindex> 3808 // CHECK: %[[RESULT_SHAPE:.*]] = tensor.from_elements %[[UNREDUCED_DIM]], %[[C1]] : tensor<2xindex> 3809 // CHECK: %[[RESULT:.*]] = mhlo.dynamic_reshape %[[MEAN_CONVERTED]], %[[RESULT_SHAPE]] : (tensor<?xf16>, tensor<2xindex>) -> tensor<?x1xf16> 3810 // CHECK: return %[[RESULT]] : tensor<?x1xf16> 3811 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3812 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<?x?xf16>, tensor<1xi64>) -> tensor<?x1xf16> 3813 func.return %0 : tensor<?x1xf16> 3814} 3815 3816// ----- 3817 3818// CHECK-LABEL: func @sum 3819func.func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3820 // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> 3821 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 3822 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> 3823 // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> 3824 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> 3825 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3826 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3827 %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3828 func.return %0 : tensor<4x1xf16> 3829} 3830 3831// ----- 3832 3833// CHECK-LABEL: func @sum_dynamic 3834func.func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { 3835 // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32> 3836 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor<f32> 3837 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x?xf32>, tensor<f32>) -> tensor<4xf32> 3838 // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> 3839 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> 3840 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3841 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3842 %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3843 func.return %0 : tensor<4x1xf16> 3844} 3845 3846// ----- 3847 3848// CHECK-LABEL: func @max 3849func.func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3850 // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> 3851 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16> 3852 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16> 3853 // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> 3854 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> 3855 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3856 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3857 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3858 func.return %0 : tensor<4x1xf16> 3859} 3860 3861// ----- 3862 3863// CHECK-LABEL: func @max_qint 3864// Regression test to ensure we don't crash getting the initial value for 3865// tf.Max when using quantized integer types. 3866func.func @max_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { 3867 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3868 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> 3869 func.return %0 : tensor<4x1x!tf_type.qint8> 3870} 3871 3872// ----- 3873 3874// CHECK-LABEL: func @max_dynamic 3875func.func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { 3876 // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x?xf16> 3877 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16> 3878 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x?xf16>, tensor<f16>) -> tensor<4xf16> 3879 // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> 3880 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> 3881 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3882 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3883 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3884 func.return %0 : tensor<4x1xf16> 3885} 3886 3887// ----- 3888 3889// CHECK-LABEL: func @min 3890func.func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3891 // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> 3892 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor<f16> 3893 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.minimum across dimensions = [1] : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16> 3894 // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> 3895 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> 3896 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3897 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3898 %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3899 func.return %0 : tensor<4x1xf16> 3900} 3901 3902// ----- 3903 3904// CHECK-LABEL: func @min_qint 3905// Regression test to ensure we don't crash getting the initial value for 3906// tf.Min when using quantized integer types. 3907func.func @min_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { 3908 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3909 %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> 3910 func.return %0 : tensor<4x1x!tf_type.qint8> 3911} 3912 3913// ----- 3914 3915// CHECK-LABEL: func @prod 3916func.func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3917 // CHECK: %[[CAST:.*]] = mhlo.convert(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> 3918 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 3919 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.multiply across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> 3920 // CHECK: %[[CAST_BACK:.*]] = mhlo.convert(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> 3921 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> 3922 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3923 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3924 %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3925 func.return %0 : tensor<4x1xf16> 3926} 3927 3928// ----- 3929 3930// CHECK-LABEL: func @prod_qint 3931// Regression test to ensure we don't crash getting the initial value for 3932// tf.Prod when using quantized integer types. 3933func.func @prod_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { 3934 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3935 %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> 3936 func.return %0 : tensor<4x1x!tf_type.qint8> 3937} 3938 3939// ----- 3940 3941// CHECK-LABEL: @all 3942func.func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { 3943 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3944 // CHECK: %[[INIT:.*]] = mhlo.constant dense<true> : tensor<i1> 3945 // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.and across dimensions = [1] : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1> 3946 %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> 3947 func.return %0 : tensor<4xi1> 3948} 3949 3950// ----- 3951 3952// CHECK-LABEL: @all_keep_dim 3953func.func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { 3954 // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> 3955 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3956 %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3957 func.return %0 : tensor<4x1xi1> 3958} 3959 3960// ----- 3961 3962// CHECK-LABEL: @all_dynamic 3963func.func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { 3964 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3965 // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> 3966 // CHECK: mhlo.reduce(%[[ARG]] 3967 %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3968 func.return %0 : tensor<4x1xi1> 3969} 3970 3971// ----- 3972 3973// CHECK-LABEL: @any 3974func.func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { 3975 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3976 // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1> 3977 // CHECK: mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.or across dimensions = [1] : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1> 3978 %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> 3979 func.return %0 : tensor<4xi1> 3980} 3981 3982// ----- 3983 3984// CHECK-LABEL: @any_keep_dim 3985func.func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { 3986 // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> 3987 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3988 %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3989 func.return %0 : tensor<4x1xi1> 3990} 3991 3992// ----- 3993 3994// CHECK-LABEL: @any_dynamic 3995func.func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { 3996 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3997 // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> 3998 // CHECK: mhlo.reduce(%[[ARG]] 3999 %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> 4000 func.return %0 : tensor<4x1xi1> 4001} 4002 4003//===----------------------------------------------------------------------===// 4004// Tile op legalizations. 4005//===----------------------------------------------------------------------===// 4006 4007// ----- 4008 4009// CHECK-LABEL: func @tile_by_reshape 4010func.func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { 4011 // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> 4012 // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[BROADCASTED]] : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> 4013 // CHECK: return %[[RESULT]] : tensor<28x24xf32> 4014 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> 4015 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32> 4016 func.return %0 : tensor<28x24xf32> 4017} 4018 4019// ----- 4020 4021// CHECK-LABEL: func @tile_just_broadcast 4022func.func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { 4023 // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32> 4024 // CHECK: return %[[RESULT]] : tensor<7x3xf32> 4025 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> 4026 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> 4027 func.return %0 : tensor<7x3xf32> 4028} 4029 4030// ----- 4031 4032// CHECK-LABEL: func @tile_dynamic_shape 4033func.func @tile_dynamic_shape(%arg0: tensor<?x8xf32>) -> tensor<?x24xf32> { 4034 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32> 4035 // CHECK: tensor.dim {{.*}} : tensor<?x8xf32> 4036 // CHECK: tensor.from_elements {{.*}} : tensor<4xindex> 4037 // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<?x8xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> 4038 // CHECK: muli {{.*}} : index 4039 // CHECK: tensor.from_elements {{.*}} : tensor<2xindex> 4040 // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor<?x?x?x?xf32>, tensor<2xindex>) -> tensor<?x24xf32> 4041 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<?x8xf32>, tensor<2xi32>) -> tensor<?x24xf32> 4042 func.return %0 : tensor<?x24xf32> 4043} 4044 4045//===----------------------------------------------------------------------===// 4046// ArgMax/ArgMin op legalizations. 4047//===----------------------------------------------------------------------===// 4048 4049// ----- 4050 4051// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0 4052func.func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { 4053 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor<i64> 4054 // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 4055 // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> 4056 // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x7xi32> 4057 // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) 4058 // CHECK: (%[[ARG1:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<i64>) (%[[ARG2:.*]]: tensor<i32>, %[[ARG4:.*]]: tensor<i32>) 4059 // CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1> 4060 // CHECK: %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 4061 // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1> 4062 // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] 4063 // CHECK: %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 4064 // CHECK: %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 4065 // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor<i64>, tensor<i32> 4066 // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> 4067 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 4068 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32> 4069 func.return %0 : tensor<7xi32> 4070} 4071 4072// ----- 4073 4074// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1 4075func.func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> { 4076 // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32> 4077 // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64> 4078 // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xf32> -> tensor<2xindex> 4079 // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x7xi64> 4080 // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) 4081 // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> 4082 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 4083 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor<i32>) -> tensor<3xi64> 4084 func.return %0 : tensor<3xi64> 4085} 4086 4087// ----- 4088 4089// CHECK-LABEL: func @argmax_i1_input_i64_output_axis_1 4090func.func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3xi64> { 4091 // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1> 4092 // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64> 4093 // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi1> -> tensor<2xindex> 4094 // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x7xi64> 4095 // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) 4096 // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> 4097 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 4098 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi1>, tensor<i32>) -> tensor<3xi64> 4099 func.return %0 : tensor<3xi64> 4100} 4101 4102// ----- 4103 4104// CHECK-LABEL: func @argmax_dynamic_shape_input_output 4105func.func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor<?xi32> { 4106 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32> 4107 // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 4108 // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> 4109 // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x?xi32> 4110 // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) 4111 // CHECK: return %[[REDUCE]]#1 : tensor<?xi32> 4112 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 4113 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<?xi32> 4114 func.return %0 : tensor<?xi32> 4115} 4116 4117// ----- 4118 4119// CHECK-LABEL: func @argmax_dynamic_shape_input 4120func.func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { 4121 // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32> 4122 // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 4123 // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> 4124 // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<3x?xi32> 4125 // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) 4126 // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> 4127 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 4128 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32> 4129 func.return %0 : tensor<3xi32> 4130} 4131 4132// ----- 4133 4134// CHECK-LABEL: func @argmin_i64_input_i32_output_axis_0 4135func.func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { 4136 // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor<i64> 4137 // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 4138 // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> 4139 // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<3x7xi32> 4140 // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) 4141 // CHECK: (%[[ARG1:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<i64>) (%[[ARG2:.*]]: tensor<i32>, %[[ARG4:.*]]: tensor<i32>) 4142 // CHECK: %[[COMPARE:.*]] = mhlo.compare LE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1> 4143 // CHECK: %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 4144 // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1> 4145 // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] 4146 // CHECK: %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 4147 // CHECK: %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 4148 // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor<i64>, tensor<i32> 4149 // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> 4150 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 4151 %0 = "tf.ArgMin"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32> 4152 func.return %0 : tensor<7xi32> 4153} 4154 4155//===----------------------------------------------------------------------===// 4156// Random op legalizations. 4157//===----------------------------------------------------------------------===// 4158 4159// ----- 4160 4161// CHECK-LABEL: func @rng_uniform 4162func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { 4163 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4164 // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 4165 // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64> 4166 // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> 4167 %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> 4168 // CHECK: return %[[F32]] 4169 func.return %0 : tensor<12x?x64xf32> 4170} 4171 4172// ----- 4173 4174// CHECK-LABEL: func @rng_std_normal 4175func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { 4176 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4177 // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 4178 // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64> 4179 // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*NORMAL.*}} -> tensor<12x?x64xf32> 4180 %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> 4181 // CHECK: return %[[F32]] 4182 func.return %0 : tensor<12x?x64xf32> 4183} 4184 4185//===----------------------------------------------------------------------===// 4186// Range op legalizations. 4187//===----------------------------------------------------------------------===// 4188 4189// ----- 4190 4191// CHECK-LABEL: func @range 4192// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32> 4193func.func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> { 4194 %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32> 4195 // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" 4196 // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4197 // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4198 %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32> 4199 func.return %3 : tensor<5xf32> 4200} 4201 4202// ----- 4203 4204// CHECK-LABEL: func @range_dynamic 4205// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32> 4206func.func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<?xf32> { 4207 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 4208 // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] 4209 // CHECK-DAG: [[CONVERT1:%.+]] = mhlo.convert [[ABS1]] 4210 // CHECK-DAG: [[CONVERT2:%.+]] = mhlo.convert %arg2 4211 // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] 4212 // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] 4213 // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert([[CEIL]]) 4214 // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT3]] 4215 // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} 4216 // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert %arg0 4217 // CHECK-DAG: [[CONVERT4:%.+]] = mhlo.convert %arg2 4218 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4219 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4220 %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> 4221 4222 // CHECK: return [[ADD]] 4223 func.return %2 : tensor<?xf32> 4224} 4225 4226// ----- 4227 4228// CHECK-LABEL: func @range_int_dynamic 4229// CHECK-SAME: [[START:%.*]]: tensor<i32>, [[DELTA:%.*]]: tensor<i32> 4230func.func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> { 4231 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 4232 // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] 4233 // CHECK-DAG: [[CONVERT1:%.+]] = mhlo.convert([[ABS1]]) 4234 // CHECK-DAG: [[CONVERT2:%.+]] = mhlo.convert(%arg2) 4235 // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] 4236 // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] 4237 // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert([[CEIL]]) 4238 // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT3]] 4239 // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} 4240 // CHECK-DAG: [[CONVERT3:%.+]] = mhlo.convert %arg0 4241 // CHECK-DAG: [[CONVERT4:%.+]] = mhlo.convert %arg2 4242 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4243 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4244 %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32> 4245 4246 // CHECK: return [[ADD]] 4247 func.return %2 : tensor<?xi32> 4248} 4249 4250// ----- 4251 4252// CHECK-LABEL: func @linspace_static 4253// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32> 4254func.func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> { 4255 // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4> 4256 // CHECK-DAG: [[NUM_F32:%.*]] = mhlo.convert([[NUM]]) 4257 // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> 4258 // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] 4259 // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] 4260 // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] 4261 // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} 4262 // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4263 // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} 4264 // CHECK: return [[LINSPACE]] 4265 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32> 4266 %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32> 4267 func.return %1 : tensor<4xf32> 4268} 4269 4270// ----- 4271 4272// CHECK-LABEL: func @linspace_dynamic 4273func.func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>) -> tensor<?xf32> { 4274 // CHECK: "tf.LinSpace" 4275 %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32> 4276 func.return %0 : tensor<?xf32> 4277} 4278 4279// ----- 4280 4281// CHECK-LABEL: func @linspace_invalid_num 4282func.func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> { 4283 // CHECK: mhlo.constant dense<> : tensor<0xi32> 4284 // CHECK: "tf.LinSpace" 4285 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> 4286 %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32> 4287 func.return %1 : tensor<?xf32> 4288} 4289 4290//===----------------------------------------------------------------------===// 4291// LegacyCall op legalizations. 4292//===----------------------------------------------------------------------===// 4293 4294// ----- 4295 4296func.func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { 4297 func.return %arg0: tensor<10x2xf32> 4298} 4299 4300// CHECK-LABEL: testSimpleLegacyCallOp 4301func.func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { 4302 // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32> 4303 %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32> 4304 // CHECK: return %[[RESULT]] 4305 func.return %0: tensor<10x2xf32> 4306} 4307 4308// ----- 4309 4310func.func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { 4311 func.return %arg0: tensor<10x2xf32> 4312} 4313 4314// CHECK-LABEL: testMultiInputLegacyCallOp 4315func.func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { 4316 // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> 4317 %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> 4318 // CHECK: return %[[RESULT]] 4319 func.return %0: tensor<10x2xf32> 4320} 4321 4322//===----------------------------------------------------------------------===// 4323// Conv op legalizations. 4324//===----------------------------------------------------------------------===// 4325 4326// ----- 4327 4328// CHECK-LABEL: conv_simple 4329func.func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> { 4330 4331 // CHECK: mhlo.convolution(%arg0, %arg1) 4332 // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] 4333 // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[0, 1], [2, 3]], rhs_dilate = [2, 3]} 4334 // CHECK-SAME: batch_group_count = 1 4335 // CHECK-SAME: feature_group_count = 2 4336 4337 %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> 4338 func.return %0 : tensor<256x8x7x16xf32> 4339} 4340 4341// ----- 4342 4343// CHECK-LABEL: conv3d_simple 4344func.func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> { 4345 4346 // CHECK: mhlo.convolution(%arg0, %arg1) 4347 // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] 4348 // CHECK-SAME{LITERAL}: window = {stride = [5, 6, 7], pad = [[1, 2], [2, 3], [2, 3]], rhs_dilate = [2, 3, 4]} 4349 // CHECK-SAME: batch_group_count = 1 4350 // CHECK-SAME: feature_group_count = 2 4351 4352 %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> 4353 func.return %0 : tensor<256x7x6x5x16xf32> 4354} 4355 4356// ----- 4357 4358// CHECK-LABEL: depthwiseconv_simple 4359func.func @depthwiseconv_simple(%arg0: tensor<?x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<?x3x4x9xf32> { 4360 // CHECK: %[[RESHAPED_FILTER:.*]] = mhlo.reshape %arg1 : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> 4361 // CHECK: mhlo.convolution(%arg0, %[[RESHAPED_FILTER]]) 4362 // CHECK-SAME: feature_group_count = 3 4363 %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { 4364 data_format = "NHWC", 4365 device = "", 4366 dilations = [1, 1, 1, 1], 4367 explicit_paddings = [], 4368 padding = "VALID", 4369 strides = [1, 1, 1, 1] 4370 } : (tensor<?x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<?x3x4x9xf32> 4371 func.return %0 : tensor<?x3x4x9xf32> 4372} 4373 4374// ----- 4375 4376// CHECK-LABEL: conv_valid_padding 4377func.func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> { 4378 // CHECK: mhlo.convolution(%arg0, %arg1) 4379 4380 %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> 4381 func.return %0 : tensor<1x2x3x1xf32> 4382} 4383 4384// ----- 4385 4386// CHECK-LABEL: conv_explicit_paddings 4387func.func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> { 4388 4389 // CHECK: mhlo.convolution(%arg0, %arg1) 4390 // CHECK-SAME{LITERAL}: pad = [[6, 0], [3, 3]] 4391 4392 %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> 4393 func.return %0 : tensor<256x9x7x16xf32> 4394} 4395 4396// ----- 4397 4398// CHECK-LABEL: @conv2d_backprop_input_dynamic 4399func.func @conv2d_backprop_input_dynamic(%filter: tensor<2x2x1x16xf32>, %out_backprop: tensor<?x256x256x16xf32>) -> tensor<?x512x512x1xf32> { 4400 // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} 4401 // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) 4402 // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] 4403 // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} 4404 // CHECK-SAME: batch_group_count = 1 : i64 4405 // CHECK-SAME: feature_group_count = 1 : i64 4406 // CHECK: return %[[RESULT]] 4407 %cst_0_1d = "tf.Const"() {device = "", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> 4408 %cst_1_0d = "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32> 4409 %cst_1_1d = "tf.Const"() {device = "", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 4410 %cst_512_0d = "tf.Const"() {device = "", value = dense<512> : tensor<i32>} : () -> tensor<i32> 4411 %out_backprop_shape = "tf.Shape"(%out_backprop) {device = ""} : (tensor<?x256x256x16xf32>) -> tensor<4xi32> 4412 %batch_size = "tf.StridedSlice"(%out_backprop_shape, %cst_0_1d, %cst_1_1d, %cst_1_1d) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32> 4413 %input_shape = "tf.Pack"(%batch_size, %cst_512_0d, %cst_512_0d, %cst_1_0d) {axis = 0 : i64, device = ""} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4xi32> 4414 %result = "tf.Conv2DBackpropInput"(%input_shape, %filter, %out_backprop) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<2x2x1x16xf32>, tensor<?x256x256x16xf32>) -> tensor<?x512x512x1xf32> 4415 return %result : tensor<?x512x512x1xf32> 4416} 4417 4418// ----- 4419 4420// CHECK-LABEL: @conv2d_backprop_input 4421func.func @conv2d_backprop_input( 4422 %filter: tensor<3x3x1x32xf32>, 4423 %out_backprop: tensor<100x26x26x32xf32> 4424 ) -> tensor<100x28x28x1xf32> { 4425 // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} 4426 // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) 4427 // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] 4428 // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} 4429 // CHECK-SAME: batch_group_count = 1 : i64 4430 // CHECK-SAME: feature_group_count = 1 : i64 4431 // CHECK: return %[[RESULT]] 4432 %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32> 4433 %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { 4434 data_format = "NHWC", 4435 dilations = [1, 1, 1, 1], 4436 explicit_paddings = [], 4437 padding = "VALID", 4438 strides = [1, 1, 1, 1], 4439 use_cudnn_on_gpu = true 4440 } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32> 4441 func.return %result : tensor<100x28x28x1xf32> 4442} 4443 4444// ----- 4445 4446// CHECK-LABEL: @conv2d_backprop_input_grouped 4447func.func @conv2d_backprop_input_grouped( 4448 %filter: tensor<2x2x5x21xf32>, 4449 %out_backprop: tensor<5x2x2x21xf32> 4450 ) -> tensor<5x3x3x15xf32> { 4451 %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32> 4452 4453 // Verify filter transformation for grouped convolution. 4454 4455 // CHECK: %[[RESHAPE:.*]] = mhlo.reshape %arg0 : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32> 4456 // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]]) 4457 // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]> 4458 // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32> 4459 // CHECK: mhlo.reshape %[[TRANSPOSE]] : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32> 4460 4461 %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { 4462 data_format = "NHWC", 4463 dilations = [1, 1, 1, 1], 4464 explicit_paddings = [], 4465 padding = "VALID", 4466 strides = [1, 1, 1, 1], 4467 use_cudnn_on_gpu = true 4468 } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32> 4469 func.return %result : tensor<5x3x3x15xf32> 4470} 4471 4472 4473// CHECK-LABEL: @conv3d_backprop_input 4474func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { 4475 // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} 4476 // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) 4477 // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f] 4478 // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} 4479 // CHECK-SAME: batch_group_count = 1 : i64, 4480 // CHECK-SAME: feature_group_count = 1 : i64 4481 4482 // CHECK: return %[[RESULT]] 4483 %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32> 4484 %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> 4485 func.return %result : tensor<2x8x8x8x1xf32> 4486} 4487 4488// ----- 4489 4490// CHECK-LABEL: @conv2d_backprop_filter 4491func.func @conv2d_backprop_filter( 4492 %input: tensor<100x28x28x1xf32>, 4493 %out_backprop: tensor<100x26x26x32xf32> 4494 ) -> tensor<3x3x1x32xf32> { 4495 // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) 4496 // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f] 4497 // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} 4498 // CHECK-SAME: batch_group_count = 1 : i64 4499 // CHECK-SAME: feature_group_count = 1 : i64 4500 // CHECK: return %[[RESULT]] 4501 %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32> 4502 %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { 4503 data_format = "NHWC", 4504 dilations = [1, 1, 1, 1], 4505 explicit_paddings = [], 4506 padding = "VALID", 4507 strides = [1, 1, 1, 1], 4508 use_cudnn_on_gpu = true 4509 } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32> 4510 func.return %result : tensor<3x3x1x32xf32> 4511} 4512 4513// ----- 4514 4515// CHECK-LABEL: @conv2d_backprop_filter_grouped 4516func.func @conv2d_backprop_filter_grouped( 4517 %input: tensor<1x2x2x2xf32>, 4518 %out_backprop: tensor<1x1x1x2xf32> 4519 ) -> tensor<2x2x1x2xf32> { 4520 4521 // CHECK: mhlo.convolution(%arg0, %arg1) 4522 // CHECK-SAME: batch_group_count = 2 : i64 4523 // CHECK-SAME: feature_group_count = 1 : i64 4524 4525 %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32> 4526 %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { 4527 data_format = "NHWC", 4528 dilations = [1, 1, 1, 1], 4529 explicit_paddings = [], 4530 padding = "VALID", 4531 strides = [1, 1, 1, 1], 4532 use_cudnn_on_gpu = true 4533 } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32> 4534 func.return %result : tensor<2x2x1x2xf32> 4535} 4536 4537 4538// CHECK-LABEL: @conv3d_backprop_filter 4539func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> { 4540 // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) 4541 // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f] 4542 // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} 4543 // CHECK-SAME: batch_group_count = 1 : i64 4544 // CHECK-SAME: feature_group_count = 1 : i64 4545 // CHECK: return %[[RESULT]] 4546 %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32> 4547 %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> 4548 func.return %result : tensor<3x3x3x1x6xf32> 4549} 4550 4551// ----- 4552 4553// CHECK-LABEL: @collective_permute 4554func.func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { 4555 %source_target_pairs = "tf.Const" () { 4556 value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32> 4557 } : () -> tensor<3x2xi32> 4558 4559 // CHECK: "mhlo.collective_permute" 4560 // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> 4561 %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) { 4562 } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32> 4563 4564 func.return %0 : tensor<128x32xf32> 4565} 4566 4567// ----- 4568 4569// CHECK-LABEL: @cross_replica_sum 4570func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { 4571 %replica_groups = "tf.Const" () { 4572 value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> 4573 } : () -> tensor<2x4xi32> 4574 4575 // CHECK: mhlo.cross-replica-sum 4576 // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> 4577 %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32> 4578 func.return %result : tensor<10xf32> 4579} 4580 4581// ----- 4582 4583// CHECK-LABEL: conv_dynamic 4584func.func @conv_dynamic(%arg0: tensor<?x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32> { 4585 // CHECK: "mhlo.dynamic_conv" 4586 // CHECK-SAME: {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<?x8x7x16xf32> 4587 %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32> 4588 func.return %0 : tensor<?x8x7x16xf32> 4589} 4590 4591//===----------------------------------------------------------------------===// 4592// tf.Split legalization 4593//===----------------------------------------------------------------------===// 4594 4595// ----- 4596 4597// CHECK-LABEL: @split_not_match_non_const_split_dim 4598func.func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) { 4599 // CHECK: tf.Split 4600 %0:2 = "tf.Split"(%split_dim, %input) : (tensor<i32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) 4601 func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> 4602} 4603 4604// ----- 4605 4606// CHECK-LABEL: @split_not_match_unknown_input_dim 4607func.func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { 4608 %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 4609 // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> 4610 // CHECK: arith.divsi {{.*}} : index 4611 // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> 4612 // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> 4613 // CHECK: muli {{.*}} : index 4614 // CHECK: muli {{.*}} : index 4615 // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> 4616 // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> 4617 // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> 4618 // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> 4619 %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) 4620 func.return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32> 4621} 4622 4623// ----- 4624 4625// CHECK-LABEL: @split_match_and_split_into_two 4626func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { 4627 %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> 4628 // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> 4629 // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32> 4630 %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) 4631 // CHECK: return %[[ONE]], %[[TWO]] 4632 func.return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32> 4633} 4634 4635// ----- 4636 4637// CHECK-LABEL: @split_match_and_split_into_two_dynamic 4638func.func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) { 4639 %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> 4640 // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> 4641 // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> 4642 %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) 4643 // CHECK: return %[[ONE]], %[[TWO]] 4644 func.return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32> 4645} 4646 4647// ----- 4648 4649// CHECK-LABEL: @split_match_and_split_into_three 4650// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) 4651func.func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { 4652 %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 4653 // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> 4654 // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> 4655 // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> 4656 %0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) 4657 // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] 4658 func.return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> 4659} 4660 4661//===----------------------------------------------------------------------===// 4662// tf.TopKV2 legalization 4663//===----------------------------------------------------------------------===// 4664 4665// ----- 4666 4667// CHECK-LABEL: topk_v2_non_const_k 4668func.func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) { 4669 // CHECK: tf.TopKV2 4670 %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) 4671 func.return %0#0, %0#1: tensor<?xf32>, tensor<?xi32> 4672} 4673 4674// ----- 4675 4676// CHECK-LABEL: topk_v2_unknown_input_last_dim 4677func.func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) { 4678 %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32> 4679 // CHECK: tf.TopKV2 4680 %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) 4681 func.return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32> 4682} 4683 4684// ----- 4685 4686// CHECK-LABEL: topk_v2 4687// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32> 4688func.func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { 4689 %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32> 4690 4691 // CHECK: chlo.top_k(%[[INPUT]], k = 8) 4692 %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) 4693 func.return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> 4694} 4695 4696//===----------------------------------------------------------------------===// 4697// tf.SplitV legalization 4698//===----------------------------------------------------------------------===// 4699 4700// ----- 4701 4702// CHECK-LABEL: @splitv_match_and_split_into_three 4703// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) 4704func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { 4705 %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> 4706 %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 4707 // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32> 4708 // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> 4709 // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32> 4710 %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) 4711 // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] 4712 func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> 4713} 4714 4715// ----- 4716 4717// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic 4718func.func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) { 4719 %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> 4720 %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 4721 // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32> 4722 // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32> 4723 // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32> 4724 %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) 4725 func.return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32> 4726} 4727 4728// ----- 4729 4730// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes 4731func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { 4732 %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32> 4733 %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 4734 // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64> 4735 // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64> 4736 // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64> 4737 %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) 4738 func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> 4739} 4740 4741//===----------------------------------------------------------------------===// 4742// tf.Assert legalization 4743//===----------------------------------------------------------------------===// 4744 4745// ----- 4746 4747// CHECK-LABEL: @assert 4748func.func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) { 4749 // CHECK-NOT: tf.Assert 4750 "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> () 4751 func.return 4752} 4753 4754//===----------------------------------------------------------------------===// 4755// tf.Unpack legalization 4756//===----------------------------------------------------------------------===// 4757 4758// ----- 4759 4760// CHECK-LABEL: @unpack 4761func.func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { 4762 // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> 4763 // CHECK: %[[RES1:.*]] = mhlo.reshape %[[SLICE1]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> 4764 // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> 4765 // CHECK: %[[RES2:.*]] = mhlo.reshape %[[SLICE2]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> 4766 // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> 4767 // CHECK: %[[RES3:.*]] = mhlo.reshape %[[SLICE3]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> 4768 4769 %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) 4770 // return %[[RES1]], %[[RES2]], %[[RES3]] 4771 func.return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32> 4772} 4773 4774// ----- 4775 4776// CHECK-LABEL: func @unpack_dynamic 4777func.func @unpack_dynamic(%arg0: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { 4778 // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32> 4779 // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> 4780 // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32> 4781 // CHECK: tensor.from_elements {{.*}} : tensor<3xi32> 4782 // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32> 4783 // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> 4784 // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32> 4785 // CHECK: return {{.*}} : tensor<?x?xf32>, tensor<?x?xf32> 4786 %0:2 = "tf.Unpack"(%arg0) {axis = -1 : i64} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) 4787 func.return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> 4788} 4789 4790// ----- 4791 4792// CHECK-LABEL: @unpack_unranked 4793func.func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { 4794 4795 // CHECK: tf.Unpack 4796 %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) 4797 func.return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> 4798} 4799 4800//===----------------------------------------------------------------------===// 4801// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization 4802//===----------------------------------------------------------------------===// 4803 4804// ----- 4805 4806// CHECK-LABEL: @unsorted_segment_sum 4807// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32> 4808// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32> 4809func.func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) { 4810 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4811 // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4812 // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32> 4813 // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ({ 4814 // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>): 4815 // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor<f32> 4816 // CHECK: mhlo.return [[ADD]] 4817 // CHECK: indices_are_sorted = false, 4818 // CHECK-SAME: scatter_dimension_numbers = 4819 // CHECK-SAME: update_window_dims = [2] 4820 // CHECK-SAME: inserted_window_dims = [0] 4821 // CHECK-SAME: scatter_dims_to_operand_dims = [0] 4822 // CHECK-SAME: index_vector_dim = 2 4823 // CHECK-SAME: unique_indices = false 4824 // CHECK-SAME: (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> 4825 // CHECK: return [[SCATTER]] 4826 %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor<i32>) -> (tensor<4x64xf32>) 4827 func.return %0: tensor<4x64xf32> 4828} 4829 4830// ----- 4831 4832// CHECK-LABEL: @unsorted_segment_prod 4833// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32> 4834// CHECK-SAME: [[SI:%.*]]: tensor<?x16xi32> 4835func.func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) { 4836 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4837 // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 4838 // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32> 4839 // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ({ 4840 // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>): 4841 // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor<f32> 4842 // CHECK: mhlo.return [[MUL]] 4843 // CHECK: indices_are_sorted = false 4844 // CHECK-SAME: scatter_dimension_numbers = 4845 // CHECK-SAME: update_window_dims = [2] 4846 // CHECK-SAME: inserted_window_dims = [0] 4847 // CHECK-SAME: scatter_dims_to_operand_dims = [0] 4848 // CHECK-SAME: index_vector_dim = 2 4849 // CHECK-SAME: unique_indices = false 4850 // CHECK-SAME: (tensor<4x64xf32>, tensor<?x16xi32>, tensor<8x?x64xf32>) -> tensor<4x?xf32> 4851 // CHECK: return [[SCATTER]] 4852 %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>) 4853 func.return %0: tensor<4x?xf32> 4854} 4855 4856// ----- 4857 4858// CHECK-LABEL: @unsorted_segment_min 4859func.func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) { 4860 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4861 // CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32> 4862 // CHECK: mhlo.scatter 4863 // CHECK: mhlo.minimum 4864 %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>) 4865 func.return %0: tensor<4x?xf32> 4866} 4867 4868// ----- 4869 4870// CHECK-LABEL: @unsorted_segment_max 4871func.func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) { 4872 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4873 // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32> 4874 // CHECK: mhlo.scatter 4875 // CHECK: mhlo.maximum 4876 %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>) 4877 func.return %0: tensor<4x?xf32> 4878} 4879 4880//===----------------------------------------------------------------------===// 4881// tf.GatherNd legalization 4882//===----------------------------------------------------------------------===// 4883// CHECK-LABEL: func @gatherNd_dynamic 4884func.func @gatherNd_dynamic(%arg0: tensor<?x?x?xi32>, %arg1: tensor<?x6x2xi32>) -> tensor<?x6x?xi32> { 4885 // CHECK: tensor.dim 4886 // CHECK: index_cast 4887 // CHECK: tensor.from_elements 4888 // CHECK: mhlo.dynamic_gather 4889 // CHECK-SAME: dimension_numbers = 4890 // CHECK-SAME: offset_dims = [2] 4891 // CHECK-SAME: collapsed_slice_dims = [0, 1] 4892 // CHECK-SAME: start_index_map = [0, 1] 4893 // CHECK-SAME: index_vector_dim = 2 4894 // CHECK-SAME: indices_are_sorted = false 4895 %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<?x?x?xi32>, tensor<?x6x2xi32>) -> tensor<?x6x?xi32> 4896 func.return %0 : tensor<?x6x?xi32> 4897} 4898 4899// ----- 4900 4901// CHECK-LABEL: func @gatherNd_static 4902func.func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> { 4903 // CHECK: "mhlo.gather"({{.*}}) { 4904 // CHECK-SAME: dimension_numbers = 4905 // CHECK-SAME: offset_dims = [1, 2] 4906 // CHECK-SAME: collapsed_slice_dims = [0] 4907 // CHECK-SAME: start_index_map = [0] 4908 // CHECK-SAME: index_vector_dim = 1 4909 // CHECK-SAME: indices_are_sorted = false 4910 // CHECK-SAME: slice_sizes = dense<[1, 4, 128]> 4911 // CHECK-SAME: (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> 4912 %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> 4913 func.return %0 : tensor<2x4x128xf32> 4914} 4915 4916//===----------------------------------------------------------------------===// 4917// tf.GatherV2 legalization 4918//===----------------------------------------------------------------------===// 4919 4920// ----- 4921 4922// CHECK-LABEL: @gather_v2 4923// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] 4924// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] 4925func.func @gather_v2(%params: tensor<16x2x3xf32>, %indices: tensor<16x5xi32>) -> tensor<16x2x5xf32> { 4926 // CHECK: mhlo.torch_index_select 4927 // CHECK-SAME: %[[PARAMS]], %[[INDICES]] 4928 // CHECK-SAME: batch_dims = 1 4929 // CHECK-SAME: dim = 2 4930 %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4931 %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> 4932 func.return %1 : tensor<16x2x5xf32> 4933} 4934 4935// ----- 4936 4937// CHECK-LABEL: @gather_v2_dynamic 4938// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] 4939// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] 4940func.func @gather_v2_dynamic(%params: tensor<?x?x?xf32>, %indices: tensor<?x?xi32>) -> tensor<*xf32> { 4941 // CHECK: mhlo.torch_index_select 4942 // CHECK-SAME: %[[PARAMS]], %[[INDICES]] 4943 // CHECK-SAME: batch_dims = 1 4944 // CHECK-SAME: dim = 2 4945 %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4946 %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<1xi32>) -> tensor<*xf32> 4947 func.return %1 : tensor<*xf32> 4948} 4949 4950// ----- 4951 4952// CHECK-LABEL: @gather_v2_dynamic_index_i64 4953// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] 4954// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] 4955func.func @gather_v2_dynamic_index_i64(%params: tensor<?x?x?xf32>, %indices: tensor<?x?xi64>) -> tensor<*xf32> { 4956 // CHECK: mhlo.torch_index_select 4957 // CHECK-SAME: %[[PARAMS]], %[[INDICES]] 4958 // CHECK-SAME: batch_dims = 1 4959 // CHECK-SAME: dim = 2 4960 %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4961 %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi64>, tensor<1xi32>) -> tensor<*xf32> 4962 func.return %1 : tensor<*xf32> 4963} 4964 4965// ----- 4966 4967// CHECK-LABEL: @gather_v2_unranked 4968func.func @gather_v2_unranked(%params: tensor<*xf32>, %indices: tensor<*xi32>) -> tensor<*xf32> { 4969 // CHECK: tf.GatherV2 4970 %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4971 %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<*xf32>, tensor<*xi32>, tensor<1xi32>) -> tensor<*xf32> 4972 func.return %1 : tensor<*xf32> 4973} 4974 4975// ----- 4976 4977// CHECK-LABEL: @gather_v2_dynamic_shape 4978// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] 4979// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] 4980func.func @gather_v2_dynamic_shape(%params: tensor<?x2x3xf32>, %indices: tensor<?x5xi32>) -> tensor<?x2x5xf32> { 4981 // CHECK: mhlo.torch_index_select 4982 // CHECK-SAME: %[[PARAMS]], %[[INDICES]] 4983 // CHECK-SAME: batch_dims = 1 4984 // CHECK-SAME: dim = 2 4985 %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4986 %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<?x2x3xf32>, tensor<?x5xi32>, tensor<1xi32>) -> tensor<?x2x5xf32> 4987 func.return %1 : tensor<?x2x5xf32> 4988} 4989 4990//===----------------------------------------------------------------------===// 4991// tf.StridedSliceGrad legalization 4992//===----------------------------------------------------------------------===// 4993 4994// ----- 4995 4996// CHECK-LABEL: strided_slice_grad 4997// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32> 4998func.func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> { 4999 5000 // For StridedSlice 5001 // Dim #: 0, 1, 2 5002 // Input shape: [4, 128, 1024] 5003 // Begin: 1, 4, -3 5004 // End: 8, 65, 42 5005 // Stride: 1, 4, -1 5006 // Begin mask: 1, 0, 0 (= 1) 5007 // End mask: 0, 0, 1 (= 4) 5008 5009 // So result shape: 5010 // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 5011 // Dim #1: 4 to 65 stride 4: so 16 5012 // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 5013 // result shape: [4, 16, 1022] 5014 5015 // To pad back: 5016 // Dim #: 0, 1, 2 5017 // Pad low: 0, 4, 0 5018 // Pad interm: 0, 3, 0 5019 // Pad high: 0, 63, 2 5020 5021 %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>) 5022 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) 5023 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) 5024 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) 5025 5026 // CHECK: [[RESHAPE:%.*]] = mhlo.reshape %arg0 : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> 5027 // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> 5028 // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5029 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor<f32>) -> tensor<4x128x1024xf32> 5030 5031 %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> 5032 // CHECK: return [[PAD]] 5033 func.return %0: tensor<4x128x1024xf32> 5034} 5035 5036// ----- 5037 5038// CHECK-LABEL: strided_slice_grad_shrink_axis_mask 5039// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32> 5040func.func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> { 5041 // Input to StridedSlice was of shape 4x8xf32 5042 // Strided slice gets input[2:3, 0:8] 5043 // shrink_axis_mask is 1 denoting that dim#0 is shrunk. So the output is 8xf32 5044 // which is the shape of gradient. 5045 // StridedSliceGrad would reshape the gradient to 1x8xf32 and 5046 // then pad to match the shape of input 4x8xf32. 5047 5048 %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 5049 %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>) 5050 %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 5051 %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) 5052 5053 // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<8xf32>) -> tensor<1x8xf32> 5054 // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5055 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) 5056 // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64> 5057 // CHECK-DAG-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64> 5058 // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<2xi64> 5059 %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32> 5060 5061 // CHECK: return [[PAD]] : tensor<4x8xf32> 5062 func.return %0 : tensor<4x8xf32> 5063} 5064 5065// ----- 5066 5067// CHECK-LABEL: strided_slice_grad_new_axis_mask 5068// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32> 5069func.func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> { 5070 // Input to StridedSlice was of shape 8xf32 5071 // Strided slice gets input[tf.new_axis, 2:4] 5072 // new_axis_mask is 1 denoting new axis is inserted at dim#0. So the output is 5073 // 1x2xf32 which is the shape of gradient. 5074 // StridedSliceGrad would reshape the gradient to 2xf32 and 5075 // then pad to match the shape of input 4x8xf32. 5076 5077 %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 5078 %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>) 5079 %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>) 5080 %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) 5081 5082 // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x2xf32>) -> tensor<2xf32> 5083 // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5084 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) 5085 // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64> 5086 // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64> 5087 // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64> 5088 %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32> 5089 5090 // CHECK: return [[PAD]] : tensor<8xf32> 5091 func.return %0 : tensor<8xf32> 5092} 5093 5094// ----- 5095 5096// CHECK-LABEL: strided_slice_grad_ellipsis_mask 5097// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32> 5098func.func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> { 5099 // Input to StridedSlice was of shape 4x4x8xf32 5100 // Strided slice gets input[2:4, ...] 5101 // ellipsis_mask is 2 denoting that slice contains all elements in dim#1 and 5102 // dim#2, ignoring begin and end indices for these dimensions. So the output 5103 // is 2x4x8xf32 which is the shape of gradient. 5104 // StridedSliceGrad would pad the gradient to match the shape of 5105 // input 4x4x8xf32. 5106 5107 %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>) 5108 %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 5109 %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>) 5110 %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) 5111 5112 // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32> 5113 // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5114 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) 5115 // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64> 5116 // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64> 5117 // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64> 5118 %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32> 5119 5120 // CHECK: return [[PAD]] : tensor<4x4x8xf32> 5121 func.return %0 : tensor<4x4x8xf32> 5122} 5123 5124 5125// CHECK-LABEL: strided_slice_grad_all_masks 5126// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32> 5127func.func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> { 5128 // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] 5129 // New axis mask is at index 1 and 6 of sparse spec, so 5130 // new_axis_mask = 2^1 + 2^6 = 66 5131 // The ellipsis mask is applied to dim #1, #2 of input i.e, we get 5132 // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] 5133 // The StridedSliceGrad op would propogate the gradient for the sliced tensor 5134 // to the original input tensor by padding with zeroes. 5135 5136 %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>) 5137 %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) 5138 %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) 5139 %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) 5140 5141 // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0) 5142 // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32> 5143 // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5144 // The edge_padding_low, edge_padding_high and interior_padding attributes of 5145 // mhlo.pad would reflect the padding required to get the shape of the 5146 // input of StridedSlice op. 5147 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]]) 5148 // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> 5149 // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64> 5150 // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64> 5151 %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> 5152 5153 // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32> 5154 func.return %0 : tensor<2x4x8x16x32x64xf32> 5155} 5156 5157// ----- 5158 5159// CHECK-LABEL: @tensor_scatter_update 5160func.func @tensor_scatter_update(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> { 5161 // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ 5162 // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): 5163 // CHECK: mhlo.return %arg4 : tensor<f32> 5164 // CHECK: }) 5165 // CHECK-SAME: indices_are_sorted = false 5166 // CHECK-SAME: scatter_dimension_numbers 5167 // CHECK-SAME: update_window_dims = [1] 5168 // CHECK-SAME: inserted_window_dims = [0, 1] 5169 // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] 5170 // CHECK-SAME: index_vector_dim = 1 5171 // CHECK-SAME: unique_indices = false 5172 %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32> 5173 func.return %0 : tensor<?x?x?xf32> 5174} 5175 5176// ----- 5177 5178// CHECK-LABEL: @tensor_scatter_update_scalar_update 5179func.func @tensor_scatter_update_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor<i32>) -> tensor<4x3xi32> { 5180 // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> 5181 // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<2xi64>) -> tensor<2x3xi32> 5182 // CHECK: "mhlo.scatter" 5183 %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<4x3xi32> 5184 func.return %0 : tensor<4x3xi32> 5185} 5186 5187// ----- 5188 5189// CHECK-LABEL: @tensor_scatter_add 5190func.func @tensor_scatter_add(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> { 5191 // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ 5192 // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): 5193 // CHECK: %1 = mhlo.add %arg3, %arg4 : tensor<f32> 5194 // CHECK: mhlo.return %1 : tensor<f32> 5195 // CHECK: }) 5196 // CHECK-SAME: indices_are_sorted = false 5197 // CHECK-SAME: scatter_dimension_numbers 5198 // CHECK-SAME: update_window_dims = [1] 5199 // CHECK-SAME: inserted_window_dims = [0, 1] 5200 // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] 5201 // CHECK-SAME: index_vector_dim = 1 5202 // CHECK-SAME: unique_indices = false 5203 %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32> 5204 func.return %0 : tensor<?x?x?xf32> 5205} 5206 5207// ----- 5208 5209// CHECK-LABEL: @tensor_scatter_add_scalar_update 5210func.func @tensor_scatter_add_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor<i32>) -> tensor<4x3xi32> { 5211 // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> 5212 // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<2xi64>) -> tensor<2x3xi32> 5213 // CHECK: "mhlo.scatter" 5214 %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<4x3xi32> 5215 func.return %0 : tensor<4x3xi32> 5216} 5217 5218// ----- 5219 5220// CHECK-LABEL: @tensor_scatter_sub 5221func.func @tensor_scatter_sub(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> { 5222 // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ 5223 // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): 5224 // CHECK: %1 = mhlo.subtract %arg3, %arg4 : tensor<f32> 5225 // CHECK: mhlo.return %1 : tensor<f32> 5226 // CHECK: }) 5227 // CHECK-SAME: indices_are_sorted = false 5228 // CHECK-SAME: scatter_dimension_numbers 5229 // CHECK-SAME: update_window_dims = [1] 5230 // CHECK-SAME: inserted_window_dims = [0, 1] 5231 // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] 5232 // CHECK-SAME: index_vector_dim = 1 5233 // CHECK-SAME: unique_indices = false 5234 %0 = "tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32> 5235 func.return %0 : tensor<?x?x?xf32> 5236} 5237 5238// ----- 5239 5240// CHECK-LABEL: @tensor_scatter_min 5241func.func @tensor_scatter_min(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> { 5242 // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ 5243 // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): 5244 // CHECK: %1 = mhlo.minimum %arg3, %arg4 : tensor<f32> 5245 // CHECK: mhlo.return %1 : tensor<f32> 5246 // CHECK: }) 5247 // CHECK-SAME: indices_are_sorted = false 5248 // CHECK-SAME: scatter_dimension_numbers 5249 // CHECK-SAME: update_window_dims = [1] 5250 // CHECK-SAME: inserted_window_dims = [0, 1] 5251 // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] 5252 // CHECK-SAME: index_vector_dim = 1 5253 // CHECK-SAME: unique_indices = false 5254 %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32> 5255 func.return %0 : tensor<?x?x?xf32> 5256} 5257 5258// ----- 5259 5260// CHECK-LABEL: @tensor_scatter_max 5261func.func @tensor_scatter_max(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> { 5262 // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ 5263 // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): 5264 // CHECK: %1 = mhlo.maximum %arg3, %arg4 : tensor<f32> 5265 // CHECK: mhlo.return %1 : tensor<f32> 5266 // CHECK: }) 5267 // CHECK-SAME: indices_are_sorted = false 5268 // CHECK-SAME: scatter_dimension_numbers 5269 // CHECK-SAME: update_window_dims = [1] 5270 // CHECK-SAME: inserted_window_dims = [0, 1] 5271 // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] 5272 // CHECK-SAME: index_vector_dim = 1 5273 // CHECK-SAME: unique_indices = false 5274 %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32> 5275 func.return %0 : tensor<?x?x?xf32> 5276} 5277 5278//===----------------------------------------------------------------------===// 5279// tf.RandomShuffle legalization 5280//===----------------------------------------------------------------------===// 5281 5282// ----- 5283 5284// CHECK-LABEL: @random_shuffle_first_dim_1 5285// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32> 5286func.func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> { 5287 %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>) 5288 // CHECK-NEXT: return [[INPUT]] 5289 func.return %0: tensor<1x?xf32> 5290} 5291 5292// ----- 5293 5294// CHECK-LABEL: @random_shuffle_1D_16 5295// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32> 5296func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { 5297 // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> 5298 // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32> 5299 // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor<i32> 5300 // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} 5301 // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ({ 5302 // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor<i32>, [[ARG2:%.*]]: tensor<i32>, {{.*}}: tensor<f32>, {{.*}}: tensor<f32>): 5303 // CHECK: mhlo.compare LT, [[ARG1]], [[ARG2]], TOTALORDER 5304 // CHECK: }) {dimension = -1 : i64, is_stable = {{.*}}} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) 5305 // CHECK: return [[SORT]]#1 5306 %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) 5307 func.return %0: tensor<16xf32> 5308} 5309 5310// ----- 5311 5312// CHECK-LABEL: @random_shuffle_1D_10240 5313func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { 5314 // CHECK: mhlo.rng{{.*UNIFORM.*}} 5315 // CHECK: mhlo.sort 5316 // CHECK: mhlo.rng{{.*UNIFORM.*}} 5317 // CHECK: mhlo.sort 5318 %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) 5319 func.return %0: tensor<10240xf32> 5320} 5321 5322// ----- 5323 5324// CHECK-LABEL: @random_shuffle_3D 5325// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32> 5326func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { 5327 // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> 5328 5329 // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> 5330 // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32> 5331 // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor<i32> 5332 // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} 5333 5334 // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor<i32> 5335 5336 // CHECK: [[WHILE_OUT:%.*]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[IV_INIT]], [[ITER_ARG1:.*]] = [[SWAPS]], [[ITER_ARG2:.*]] = [[INDICES]]) 5337 // CHECK: [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor<i32> 5338 // CHECK: [[CMP:%.*]] = mhlo.compare LT, [[ITER_ARG0]], [[LIMIT]], NOTYPE 5339 // CHECK: mhlo.return [[CMP]] 5340 // CHECK: } do { 5341 // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[ITER_ARG0]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32> 5342 // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32> 5343 // CHECK: [[SWP:%.*]] = mhlo.reshape [[SWP_IDX]] : (tensor<1xi32>) -> tensor<i32> 5344 // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) {slice_sizes = dense<1> : tensor<i64>} 5345 // CHECK: [[INDICES1:%.*]] = mhlo.dynamic_update_slice [[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]] : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32> 5346 // CHECK: [[INDICES2:%.*]] = mhlo.dynamic_update_slice [[INDICES1]], [[SRC_IDX]], [[SWP]] : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32> 5347 // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32> 5348 // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG0]], [[ONE]] 5349 // CHECK: mhlo.return [[NEW_IV]], [[ITER_ARG1]], [[INDICES2]] 5350 // CHECK: } 5351 5352 // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[WHILE_OUT]]#2) 5353 // CHECK-SAME: dimension_numbers = 5354 // CHECK-SAME: offset_dims = [1, 2] 5355 // CHECK-SAME: collapsed_slice_dims = [0] 5356 // CHECK-SAME: start_index_map = [0] 5357 // CHECK-SAME: index_vector_dim = 1 5358 // CHECK-SAME: indices_are_sorted = false 5359 // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> 5360 // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32> 5361 5362 // CHECK: return [[GATHER]] 5363 5364 %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) 5365 func.return %0: tensor<4x?x16xf32> 5366} 5367 5368//===----------------------------------------------------------------------===// 5369// tf.AvgPool legalization 5370//===----------------------------------------------------------------------===// 5371 5372// ----- 5373 5374// CHECK-LABEL: @avgpool_valid_padding 5375// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> 5376// CHECK: [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> 5377// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5378// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ 5379// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 5380// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 5381// CHECK: mhlo.return [[ADD]] 5382// CHECK: }) 5383// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> 5384// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> 5385// CHECK-SAME: -> tensor<2x3x5x7xf32> 5386// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 5387// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 5388// CHECK-SAME: broadcast_dimensions = dense<> 5389// CHECK-SAME: -> tensor<2x3x5x7xf32> 5390// CHECK: [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]]) 5391// CHECK-SAME: -> tensor<2x3x5x7xf16> 5392// CHECK: return [[CONV16]] 5393func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> { 5394 %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> 5395 func.return %0 : tensor<2x3x5x7xf16> 5396} 5397 5398// ----- 5399 5400// CHECK-LABEL: @avgpool_3d_valid_padding 5401// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> 5402// CHECK: [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> 5403// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5404// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ 5405// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 5406// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 5407// CHECK: mhlo.return [[ADD]] 5408// CHECK: }) 5409// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> 5410// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> 5411// CHECK-SAME: -> tensor<2x4x3x5x7xf32> 5412// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 5413// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 5414// CHECK-SAME: broadcast_dimensions = dense<> 5415// CHECK-SAME: -> tensor<2x4x3x5x7xf32> 5416// CHECK: [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]]) 5417// CHECK-SAME: -> tensor<2x4x3x5x7xf16> 5418// CHECK: return [[CONV16]] 5419func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> { 5420 %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> 5421 func.return %0 : tensor<2x4x3x5x7xf16> 5422} 5423 5424// ----- 5425 5426// CHECK-LABEL: @avgpool_nchw_format 5427// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16> 5428// CHECK: [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32> 5429// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5430// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ 5431// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 5432// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 5433// CHECK: mhlo.return [[ADD]] 5434// CHECK: }) 5435// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> 5436// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> 5437// CHECK-SAME: -> tensor<2x7x3x5xf32> 5438// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 5439// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 5440// CHECK-SAME: broadcast_dimensions = dense<> 5441// CHECK-SAME: -> tensor<2x7x3x5xf32> 5442// CHECK: [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]]) 5443// CHECK-SAME: -> tensor<2x7x3x5xf16> 5444// CHECK: return [[CONV16]] 5445func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> { 5446 %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> 5447 func.return %0 : tensor<2x7x3x5xf16> 5448} 5449 5450// ----- 5451 5452// CHECK-LABEL: @avgpool_3d_ncdhw_format 5453// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16> 5454// CHECK: [[CONV32:%.+]] = mhlo.convert(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32> 5455// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5456// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ({ 5457// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 5458// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 5459// CHECK: mhlo.return [[ADD]] 5460// CHECK: }) 5461// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> 5462// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> 5463// CHECK-SAME: -> tensor<2x7x4x3x5xf32> 5464// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 5465// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 5466// CHECK-SAME: broadcast_dimensions = dense<> 5467// CHECK-SAME: -> tensor<2x7x4x3x5xf32> 5468// CHECK: [[CONV16:%.+]] = mhlo.convert([[DIV_RESULT]]) 5469// CHECK-SAME: -> tensor<2x7x4x3x5xf16> 5470// CHECK: return [[CONV16]] 5471func.func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> { 5472 %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> 5473 func.return %0 : tensor<2x7x4x3x5xf16> 5474} 5475 5476// ----- 5477 5478// CHECK-LABEL: @avgpool_same_padding( 5479// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> 5480// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5481// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ({ 5482// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5483// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5484// CHECK: mhlo.return %[[SUM1]] : tensor<f32> 5485// CHECK: }) 5486// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> 5487// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> 5488// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> 5489// CHECK-SAME: -> tensor<2x4x6x7xf32> 5490// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> 5491// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ({ 5492// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 5493// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 5494// CHECK: mhlo.return %[[SUM2]] : tensor<f32> 5495// CHECK: }) 5496// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> 5497// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> 5498// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> 5499// CHECK-SAME: -> tensor<2x4x6x7xf32> 5500// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32> 5501// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32> 5502// CHECK: } 5503func.func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> { 5504 %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> 5505 func.return %0 : tensor<2x4x6x7xf32> 5506} 5507 5508// ----- 5509 5510// CHECK-LABEL: @avgpool_3d_same_padding( 5511// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> 5512// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5513// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ({ 5514// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5515// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5516// CHECK: mhlo.return %[[SUM1]] : tensor<f32> 5517// CHECK: }) 5518// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> 5519// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> 5520// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> 5521// CHECK-SAME: -> tensor<2x4x4x6x7xf32> 5522// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> 5523// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ({ 5524// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 5525// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 5526// CHECK: mhlo.return %[[SUM2]] : tensor<f32> 5527// CHECK: }) 5528// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> 5529// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> 5530// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> 5531// CHECK-SAME: -> tensor<2x4x4x6x7xf32> 5532// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] 5533// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32> 5534// CHECK: } 5535func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> { 5536 %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> 5537 func.return %0 : tensor<2x4x4x6x7xf32> 5538} 5539 5540//===----------------------------------------------------------------------===// 5541// AvgPoolGrad op legalizations. 5542//===----------------------------------------------------------------------===// 5543 5544// ----- 5545 5546// CHECK-LABEL: @avgpool_grad_valid_padding( 5547// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { 5548// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5549// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 5550// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] 5551// CHECK-SAME: broadcast_dimensions = dense<> 5552// CHECK-SAME: -> tensor<10x12x16x64xf32> 5553// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 5554// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> 5555// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> 5556// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> 5557// CHECK-SAME: -> tensor<10x25x33x64xf32> 5558// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ 5559// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5560// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5561// CHECK: mhlo.return %[[SUM]] : tensor<f32> 5562// CHECK: }) 5563// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> 5564// CHECK-SAME: window_strides = dense<1> 5565// CHECK-SAME: -> tensor<10x24x32x64xf32> 5566// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32> 5567func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { 5568 %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) 5569 %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { 5570 data_format = "NHWC", 5571 ksize = [1, 2, 2, 1], 5572 padding = "VALID", 5573 strides = [1, 2, 2, 1] 5574 } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> 5575 func.return %result : tensor<10x24x32x64xf32> 5576} 5577 5578// ----- 5579 5580// CHECK-LABEL: @avgpool_3d_grad_valid_padding( 5581// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { 5582// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5583// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 5584// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32> 5585// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 5586// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> 5587// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> 5588// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]> 5589// CHECK-SAME: -> tensor<10x8x25x33x64xf32> 5590// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ 5591// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5592// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5593// CHECK: mhlo.return %[[SUM]] : tensor<f32> 5594// CHECK: }) 5595// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> 5596// CHECK-SAME: window_strides = dense<1> 5597// CHECK-SAME: -> tensor<10x8x24x32x64xf32> 5598// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32> 5599func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { 5600 %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>) 5601 %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { 5602 data_format = "NDHWC", 5603 ksize = [1, 1, 2, 2, 1], 5604 padding = "VALID", 5605 strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> 5606 func.return %result : tensor<10x8x24x32x64xf32> 5607} 5608 5609// ----- 5610 5611// CHECK-LABEL: @avgpool_grad_same_padding( 5612// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { 5613// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5614// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> 5615// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ 5616// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5617// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5618// CHECK: mhlo.return %[[SUM1]] : tensor<f32> 5619// CHECK: }) 5620// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> 5621// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> 5622// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> 5623// CHECK-SAME: -> tensor<2x4x7x9xf32> 5624// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> 5625// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 5626// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]> 5627// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> 5628// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> 5629// CHECK-SAME: -> tensor<2x14x27x9xf32> 5630// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ 5631// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 5632// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 5633// CHECK: mhlo.return %[[SUM2]] : tensor<f32> 5634// CHECK: }) 5635// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> 5636// CHECK-SAME: window_strides = dense<1> 5637// CHECK-SAME: -> tensor<2x13x25x9xf32> 5638// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32> 5639func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { 5640 %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>) 5641 %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { 5642 data_format = "NHWC", 5643 ksize = [1, 2, 3, 1], 5644 padding = "SAME", 5645 strides = [1, 4, 4, 1] 5646 } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> 5647 func.return %result : tensor<2x13x25x9xf32> 5648} 5649 5650// ----- 5651 5652// CHECK-LABEL: @avgpool_3d_grad_same_padding( 5653// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { 5654// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5655// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> 5656// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ 5657// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5658// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5659// CHECK: mhlo.return %[[SUM1]] : tensor<f32> 5660// CHECK: }) 5661// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> 5662// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> 5663// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> 5664// CHECK-SAME: -> tensor<2x8x4x7x9xf32> 5665// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> 5666// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 5667// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]> 5668// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> 5669// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]> 5670// CHECK-SAME: -> tensor<2x8x14x27x9xf32> 5671// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ 5672// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 5673// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 5674// CHECK: mhlo.return %[[SUM2]] : tensor<f32> 5675// CHECK: }) 5676// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> 5677// CHECK-SAME: window_strides = dense<1> 5678// CHECK-SAME: -> tensor<2x8x13x25x9xf32> 5679// CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32> 5680func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { 5681 %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>) 5682 %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { 5683 data_format = "NDHWC", 5684 ksize = [1, 1, 2, 3, 1], 5685 padding = "SAME", 5686 strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> 5687 func.return %result : tensor<2x8x13x25x9xf32> 5688} 5689 5690// ----- 5691 5692// CHECK-LABEL: @avgpool_grad_nchw_format( 5693// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { 5694// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5695// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> 5696// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ 5697// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5698// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5699// CHECK: mhlo.return %[[SUM1]] : tensor<f32> 5700// CHECK: }) 5701// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]> 5702// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> 5703// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> 5704// CHECK-SAME: -> tensor<2x9x4x7xf32> 5705// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32> 5706// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 5707// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]> 5708// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]> 5709// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]> 5710// CHECK-SAME: -> tensor<2x9x14x27xf32> 5711// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ 5712// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 5713// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 5714// CHECK: mhlo.return %[[SUM2]] : tensor<f32> 5715// CHECK: }) 5716// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> 5717// CHECK-SAME: window_strides = dense<1> 5718// CHECK-SAME: -> tensor<2x9x13x25xf32> 5719// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32> 5720func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { 5721 %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>) 5722 %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { 5723 data_format = "NCHW", 5724 ksize = [1, 1, 2, 3], 5725 padding = "SAME", 5726 strides = [1, 1, 4, 4] 5727 } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> 5728 func.return %result : tensor<2x9x13x25xf32> 5729} 5730 5731// ----- 5732 5733// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format( 5734// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { 5735// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5736// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> 5737// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ({ 5738// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5739// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5740// CHECK: mhlo.return %[[SUM1]] : tensor<f32> 5741// CHECK: }) 5742// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]> 5743// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> 5744// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> 5745// CHECK-SAME: -> tensor<2x9x8x4x7xf32> 5746// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32> 5747// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 5748// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]> 5749// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]> 5750// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]> 5751// CHECK-SAME: -> tensor<2x9x8x14x27xf32> 5752// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ({ 5753// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 5754// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 5755// CHECK: mhlo.return %[[SUM2]] : tensor<f32> 5756// CHECK: }) 5757// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> 5758// CHECK-SAME: window_strides = dense<1> : tensor<5xi64> 5759// CHECK-SAME: -> tensor<2x9x8x13x25xf32> 5760// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32> 5761func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { 5762 %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>) 5763 %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { 5764 data_format = "NCDHW", 5765 ksize = [1, 1, 1, 2, 3], 5766 padding = "SAME", 5767 strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> 5768 func.return %result : tensor<2x9x8x13x25xf32> 5769} 5770 5771// ----- 5772 5773// CHECK-LABEL: @avgpool_grad_bf16( 5774// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { 5775// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16> 5776// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16> 5777// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] 5778// CHECK-SAME: broadcast_dimensions = dense<> 5779// CHECK-SAME: -> tensor<10x12x16x64xbf16> 5780// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 5781// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> 5782// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> 5783// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> 5784// CHECK-SAME: -> tensor<10x25x33x64xbf16> 5785// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = mhlo.convert(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32> 5786// CHECK: %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5787// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ({ 5788// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 5789// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 5790// CHECK: mhlo.return %[[SUM]] : tensor<f32> 5791// CHECK: }) 5792// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> 5793// CHECK-SAME: window_strides = dense<1> 5794// CHECK-SAME: -> tensor<10x24x32x64xf32> 5795// CHECK: %[[RESULT_CONVERTED:.*]] = mhlo.convert(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16> 5796// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16> 5797func.func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { 5798 %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) 5799 %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { 5800 data_format = "NHWC", 5801 ksize = [1, 2, 2, 1], 5802 padding = "VALID", 5803 strides = [1, 2, 2, 1] 5804 } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> 5805 func.return %result : tensor<10x24x32x64xbf16> 5806} 5807 5808// ----- 5809 5810// CHECK-LABEL: xla_sharding 5811func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { 5812 // CHECK-NEXT: "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", mhlo.sharding = ""} 5813 %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32> 5814 func.return %0 : tensor<4x16xf32> 5815} 5816 5817// ----- 5818 5819// CHECK-LABEL: inplace_update_one 5820func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { 5821 // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> 5822 // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} 5823 // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} 5824 // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] 5825 // CHECK-DAG: [[UPDATE:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE2]], [[RESHAPE1]], [[CST]] 5826 %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> 5827 5828 // CHECK: return [[UPDATE]] 5829 func.return %0 : tensor<8x4xf32> 5830} 5831 5832// ----- 5833 5834// CHECK-LABEL: inplace_update_three 5835func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { 5836 // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> 5837 // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} 5838 // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} 5839 // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} 5840 // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} 5841 // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} 5842 // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} 5843 // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] 5844 // CHECK-DAG: [[RESHAPE2:%.+]] = mhlo.reshape [[SLICE2]] 5845 // CHECK-DAG: [[RESHAPE3:%.+]] = mhlo.reshape [[SLICE3]] 5846 // CHECK-DAG: [[UPDATE1:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]] 5847 // CHECK-DAG: [[UPDATE2:%.+]] = mhlo.dynamic_update_slice [[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]] 5848 // CHECK-DAG: [[UPDATE3:%.+]] = mhlo.dynamic_update_slice [[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]] 5849 %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> 5850 5851 // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> 5852 func.return %0 : tensor<8x8x4xf32> 5853} 5854 5855// ----- 5856 5857// CHECK-LABEL: xla_dynamic_update_slice 5858func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { 5859 // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> 5860 // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor<i32> 5861 // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> 5862 // CHECK: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] : (tensor<1xi32>) -> tensor<i32> 5863 // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]] : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x16xf32> 5864 // CHECK: return [[DUS]] 5865 %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> 5866 func.return %0 : tensor<4x16xf32> 5867} 5868 5869// ----- 5870 5871// CHECK-LABEL: xla_dynamic_update_slice2 5872func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { 5873 // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> 5874 // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor<i32> 5875 // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]] : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32> 5876 // CHECK: return [[DUS]] 5877 %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> 5878 func.return %0 : tensor<4xf32> 5879} 5880 5881//===----------------------------------------------------------------------===// 5882// AllToAll op legalizations. 5883//===----------------------------------------------------------------------===// 5884 5885// ----- 5886 5887// CHECK-LABEL: func @alltoall_basic 5888// See https://www.tensorflow.org/api_docs/python/tf/raw_ops/AllToAll 5889func.func @alltoall_basic(%input: tensor<1x2xf32>) -> tensor<2x1xf32> { 5890 %group_assignment = "tf.Const" () { 5891 value = dense<[[0, 1]]> : tensor<1x2xi32> 5892 } : () -> tensor<1x2xi32> 5893 %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 0 : i64, split_count = 2 : i64, split_dimension = 1 : i64} : (tensor<1x2xf32>, tensor<1x2xi32>) -> tensor<2x1xf32> 5894 // CHECK: mhlo.all_to_all 5895 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> 5896 func.return %result : tensor<2x1xf32> 5897} 5898 5899 5900//===----------------------------------------------------------------------===// 5901// Cumsum op legalizations. 5902//===----------------------------------------------------------------------===// 5903 5904// ----- 5905 5906// CHECK-LABEL: func @cumsum_static 5907// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 5908func.func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { 5909 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 5910 // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> 5911 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5912 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ 5913 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 5914 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 5915 // CHECK: mhlo.return [[SUM]] : tensor<f32> 5916 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5917 // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> 5918 // CHECK: return [[CONVERT_REDUCE]] 5919 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5920 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5921 func.return %1 : tensor<4xf32> 5922} 5923 5924// ----- 5925 5926// CHECK-LABEL: func @cumsum_exclusive 5927// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 5928func.func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { 5929 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 5930 // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> 5931 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5932 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ 5933 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 5934 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 5935 // CHECK: mhlo.return [[SUM]] : tensor<f32> 5936 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5937 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5938 // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> 5939 // CHECK: return [[CONVERT_REDUCE]] 5940 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5941 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5942 func.return %1 : tensor<4xf32> 5943} 5944 5945// ----- 5946 5947// CHECK-LABEL: func @cumsum_reverse 5948// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 5949func.func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { 5950 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 5951 // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5952 // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> 5953 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5954 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ 5955 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 5956 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 5957 // CHECK: mhlo.return [[SUM]] : tensor<f32> 5958 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5959 // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> 5960 // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5961 // CHECK: return [[REVERSE_BACK]] 5962 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5963 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5964 func.return %1 : tensor<4xf32> 5965} 5966 5967// ----- 5968 5969// CHECK-LABEL: func @cumsum_exclusive_reverse 5970// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 5971func.func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { 5972 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 5973 // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5974 // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> 5975 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5976 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ({ 5977 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 5978 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 5979 // CHECK: mhlo.return [[SUM]] : tensor<f32> 5980 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5981 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5982 // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> 5983 // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5984 // CHECK: return [[REVERSE_BACK]] 5985 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5986 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5987 func.return %1 : tensor<4xf32> 5988} 5989 5990// ----- 5991 5992// CHECK-LABEL: func @cumsum_empty 5993func.func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> { 5994 %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> 5995 5996 // CHECK: mhlo.constant dense<> : tensor<0xf32> 5997 %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor<i32>) -> tensor<0xf32> 5998 func.return %1 : tensor<0xf32> 5999} 6000 6001// ----- 6002 6003// CHECK-LABEL: func @cumsum_dynamic 6004func.func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> { 6005 // CHECK: "tf.Cumsum" 6006 %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor<?xf32>, tensor<i32>) -> tensor<?xf32> 6007 func.return %0 : tensor<?xf32> 6008} 6009 6010//===----------------------------------------------------------------------===// 6011// Cumprod op legalizations. 6012//===----------------------------------------------------------------------===// 6013 6014// ----- 6015 6016// CHECK-LABEL: func @cumprod 6017func.func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { 6018 // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 6019 // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ({ 6020 // CHECK: mhlo.mul 6021 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 6022 %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 6023 func.return %1 : tensor<4xf32> 6024} 6025 6026//===----------------------------------------------------------------------===// 6027// Qr op legalization 6028//===----------------------------------------------------------------------===// 6029 6030// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) 6031func.func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { 6032 // The tf.Qr lowering is a full algorithm that is not effective to verify with 6033 // FileCheck. Just verify that it converted. 6034 // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is 6035 // really only applicable to certain legacy uses. 6036 // CHECK-NOT: "tf.Qr" 6037 %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) 6038 func.return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32> 6039} 6040 6041//===----------------------------------------------------------------------===// 6042// tf.Softplus legalization 6043//===----------------------------------------------------------------------===// 6044 6045// ----- 6046 6047// CHECK-LABEL: func @softplus_f16 6048// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>) 6049func.func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> { 6050 // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] 6051 // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16> 6052 // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] 6053 // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16> 6054 // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] 6055 // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] 6056 // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>} 6057 // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>} 6058 // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] 6059 // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) 6060 // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) 6061 %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16> 6062 6063 // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16> 6064 func.return %0 : tensor<8x16xf16> 6065} 6066 6067// ----- 6068 6069// CHECK-LABEL: func @softplus_bf16 6070// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>) 6071func.func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> { 6072 // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] 6073 // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16> 6074 // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] 6075 // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16> 6076 // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] 6077 // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] 6078 // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>} 6079 // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>} 6080 // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] 6081 // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) 6082 // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) 6083 %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16> 6084 6085 // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16> 6086 func.return %0 : tensor<8x16xbf16> 6087} 6088 6089// ----- 6090 6091// CHECK-LABEL: func @softplus_f32 6092// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>) 6093func.func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { 6094 // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] 6095 // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32> 6096 // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] 6097 // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32> 6098 // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] 6099 // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] 6100 // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>} 6101 // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>} 6102 // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] 6103 // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) 6104 // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) 6105 %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> 6106 6107 // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32> 6108 func.return %0 : tensor<8x16xf32> 6109} 6110 6111// ----- 6112 6113// CHECK-LABEL: func @softplus_f64 6114// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>) 6115func.func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { 6116 // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] 6117 // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64> 6118 // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] 6119 // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64> 6120 // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] 6121 // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] 6122 // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #mhlo<comparison_direction GT>} 6123 // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #mhlo<comparison_direction LT>} 6124 // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] 6125 // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) 6126 // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) 6127 %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64> 6128 6129 // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64> 6130 func.return %0 : tensor<8x16xf64> 6131} 6132 6133// ----- 6134 6135// CHECK-LABEL: @xla_gather 6136func.func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { 6137 %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64> 6138 6139 // CHECK: "mhlo.gather" 6140 // CHECK-SAME: dimension_numbers = 6141 // CHECK-SAME: offset_dims = [0, 1] 6142 // CHECK-SAME: collapsed_slice_dims = [0] 6143 // CHECK-SAME: start_index_map = [0, 1] 6144 // CHECK-SAME: index_vector_dim = 1 6145 // CHECK-SAME: indices_are_sorted = true 6146 // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> 6147 6148 %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> 6149 func.return %0 : tensor<1x300x10xf32> 6150} 6151 6152// ----- 6153 6154// CHECK-LABEL: @xla_gather_i32 6155func.func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { 6156 %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32> 6157 6158 // CHECK: "mhlo.gather" 6159 // CHECK-SAME: dimension_numbers = 6160 // CHECK-SAME: offset_dims = [0, 1] 6161 // CHECK-SAME: collapsed_slice_dims = [0] 6162 // CHECK-SAME: start_index_map = [0, 1] 6163 // CHECK-SAME: index_vector_dim = 1 6164 // CHECK-SAME: indices_are_sorted = true 6165 // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> 6166 6167 %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<1x300x10xf32> 6168 func.return %0 : tensor<1x300x10xf32> 6169} 6170 6171 6172// CHECK: func @stridedslice_with_i32 6173func.func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} { 6174// CHECK-NOT: tf.StridedSlice 6175// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic_slice 6176// CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[DYNSLICE]] 6177// CHECK: return [[RESHAPE]] 6178 %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32> 6179 %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 6180 %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 6181 %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf_type.shape<>], device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32> 6182 %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 6183 %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 6184 %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf_type.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32> 6185 func.return %6 : tensor<4xf32> 6186} 6187 6188func.func @replica_id() -> tensor<i32> { 6189 // CHECK: %[[ID:.*]] = mhlo.replica_id : tensor<ui32> 6190 // CHECK: %[[RESULT:.*]] = mhlo.convert(%0) : (tensor<ui32>) -> tensor<i32> 6191 %0 = "tf.XlaReplicaId"() : () -> tensor<i32> 6192 func.return %0 : tensor<i32> 6193} 6194 6195// CHECK: func @angle_c64 6196// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>) 6197func.func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> { 6198// CHECK: [[IMAG:%.*]] = mhlo.imag([[ARG0]]) 6199// CHECK: [[REAL:%.*]] = mhlo.real([[ARG0]]) 6200// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]] 6201 %0 = "tf.Angle"(%arg0): (tensor<complex<f32>>) -> tensor<f32> 6202 func.return %0 : tensor<f32> 6203} 6204 6205//===----------------------------------------------------------------------===// 6206// tf.ApproximateEqual legalization 6207//===----------------------------------------------------------------------===// 6208 6209// CHECK-LABEL: func @approximateequal_f64 6210func.func @approximateequal_f64(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) -> tensor<?xi1> { 6211 // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor<?xf64> 6212 // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor<?xf64> 6213 // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32> 6214 // CHECK: %[[CONVERT:.*]] = mhlo.convert(%[[CST]]) : (tensor<f32>) -> tensor<f64> 6215 // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?xf64>, tensor<f64>) -> tensor<?xi1> 6216 // CHECK: return %[[LE]] : tensor<?xi1> 6217 %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xi1> 6218 func.return %equal : tensor<?xi1> 6219} 6220 6221// CHECK-LABEL: func @approximateequal_i32 6222func.func @approximateequal_i32(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> { 6223 // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor<?xi32> 6224 // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor<?xi32> 6225 // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32> 6226 // CHECK: %[[CONVERT:.*]] = mhlo.convert(%[[CST]]) : (tensor<f32>) -> tensor<i32> 6227 // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi1> 6228 // CHECK: return %[[LE]] : tensor<?xi1> 6229 %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1> 6230 func.return %equal : tensor<?xi1> 6231} 6232 6233// CHECK-LABEL: func @approximateequal_complex64 6234func.func @approximateequal_complex64(%arg0: tensor<?xcomplex<f32>>, %arg1: tensor<?xcomplex<f32>>) -> tensor<?xi1> { 6235 // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor<?xcomplex<f32>> 6236 // CHECK: %[[ABS:.*]] = mhlo.abs(%[[SUB]]) : (tensor<?xcomplex<f32>>) -> tensor<?xf32> 6237 // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32> 6238 // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : tensor<f32> 6239 // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1> 6240 // CHECK: return %[[LE]] : tensor<?xi1> 6241 %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor<?xcomplex<f32>>, tensor<?xcomplex<f32>>) -> tensor<?xi1> 6242 func.return %equal : tensor<?xi1> 6243} 6244 6245//===----------------------------------------------------------------------===// 6246// tf.XlaConvV2 legalization 6247//===----------------------------------------------------------------------===// 6248 6249// ----- 6250 6251// CHECK-LABEL: xla_conv_v2 6252func.func @xla_conv_v2(%lhs: tensor<8x4x16x16x16xf32>, %rhs: tensor<4x3x3x16x16xf32>) -> (tensor<4x4x14x14x16xf32>) { 6253 %feature_group_count = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 6254 %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> 6255 %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> 6256 %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> 6257 %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> 6258 // CHECK: mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], window = {stride = [3, 1, 1], pad = {{\[\[}}0, 0], {{\[}}0, 0], {{\[}}0, 0]], lhs_dilate = [4, 1, 1], rhs_dilate = [1, 1, 1]} {batch_group_count = 2 : i64, feature_group_count = 1 : i64, precision_config = []} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>) -> tensor<4x4x14x14x16xf32> 6259 %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {batch_group_count = 2 : i64, dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<4x4x14x14x16xf32> 6260 func.return %0 : tensor<4x4x14x14x16xf32> 6261} 6262 6263//===----------------------------------------------------------------------===// 6264// tf.XlaDot legalization 6265//===----------------------------------------------------------------------===// 6266 6267// ----- 6268 6269// CHECK-LABEL: @xladot_matmul( 6270// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> 6271func.func @xladot_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { 6272 // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) { 6273 // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< 6274 // CHECK-NOT: lhs_batching_dimensions = 6275 // CHECK-NOT: rhs_batching_dimensions = 6276 // CHECK-SAME: lhs_contracting_dimensions = [1] 6277 // CHECK-SAME: rhs_contracting_dimensions = [0] 6278 // CHECK-SAME: precision_config = [] 6279 %res = "tf.XlaDot"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> 6280 func.return %res : tensor<64x16xi32> 6281} 6282 6283//===----------------------------------------------------------------------===// 6284// tf.XlaDotV2 legalization 6285//===----------------------------------------------------------------------===// 6286 6287// ----- 6288 6289// CHECK-LABEL: @xladotv2_matmul( 6290// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> 6291func.func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { 6292 // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) { 6293 // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< 6294 // CHECK-NOT: lhs_batching_dimensions = 6295 // CHECK-NOT: rhs_batching_dimensions = 6296 // CHECK-SAME: lhs_contracting_dimensions = [1] 6297 // CHECK-SAME: rhs_contracting_dimensions = [0] 6298 // CHECK-SAME: precision_config = [] 6299 %res = "tf.XlaDotV2"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> 6300 func.return %res : tensor<64x16xi32> 6301} 6302 6303//===----------------------------------------------------------------------===// 6304// tf.XlaDynamicSlice legalization 6305//===----------------------------------------------------------------------===// 6306// ----- 6307 6308// CHECK-LABEL: xla_dynamic_slice_constant_start 6309func.func @xla_dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { 6310 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i64> 6311 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 6312 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 6313 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : 6314 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> 6315 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64> 6316 // CHECK-NEXT: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) 6317 // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : 6318 // CHECK-DAG-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32> 6319 // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> 6320 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 6321 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) 6322 %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> 6323 func.return %0 : tensor<2xi32> 6324} 6325 6326// ----- 6327 6328// CHECK-LABEL: xla_dynamic_slice_i32_consts 6329func.func @xla_dynamic_slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { 6330 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<i32> 6331 // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i32>) -> tensor<2xi32> 6332 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 6333 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) 6334 %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> 6335 func.return %0 : tensor<2xi32> 6336} 6337 6338// ----- 6339 6340// CHECK-LABEL: xla_dynamic_slice_constant_start_dynamic_shape 6341func.func @xla_dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 6342 // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor<i64> 6343 // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor<i64> 6344 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" 6345 // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) 6346 // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : 6347 // CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 6348 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 6349 %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 6350 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 6351 %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 6352 func.return %0 : tensor<1x4xi32> 6353} 6354 6355// ----- 6356 6357// CHECK-LABEL: xla_dynamic_slice_variable_start 6358func.func @xla_dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 6359 // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) 6360 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 6361 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 6362 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 6363 // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor<i64> 6364 // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) 6365 // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, 6366 // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, 6367 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 6368 // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor<i64> 6369 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 6370 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 6371 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 6372 %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 6373 func.return %0 : tensor<1x4xi32> 6374} 6375 6376// ----- 6377 6378// CHECK-LABEL: xla_dynamic_slice_mhlo_sizes 6379func.func @xla_dynamic_slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { 6380 // CHECK-NOT: "tf.XlaDynamicSlice" 6381 %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> 6382 %1 = "tf.XlaDynamicSlice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> 6383 func.return %1 : tensor<1x512x4xf32> 6384} 6385 6386//===----------------------------------------------------------------------===// 6387// tf.XlaEinsum legalization 6388//===----------------------------------------------------------------------===// 6389 6390// ----- 6391 6392// CHECK-LABEL: func @xlaeinsum 6393func.func @xlaeinsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { 6394 // CHECK-NEXT: mhlo.einsum 6395 %0 = "tf.XlaEinsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> 6396 func.return %0: tensor<2x4xf32> 6397} 6398 6399 6400//===----------------------------------------------------------------------===// 6401// tf.XlaReduceWindow legalization 6402//===----------------------------------------------------------------------===// 6403// ----- 6404// CHECK-LABEL: @test_xla_reduce_window 6405func.func @test_xla_reduce_window(%arg0: tensor<7xf32>, %arg1: tensor<f32>) -> tensor<10xf32> { 6406 %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32> 6407 %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 6408 %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> 6409 %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> 6410 %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> 6411 // CHECK: %[[REDUCE:.*]] = "mhlo.reduce_window"(%arg0, %arg1) ({ 6412 // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) 6413 // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer3(%[[ARG0]], %[[ARG1]]){{.*}} 6414 // CHECK-NEXT: mhlo.return %[[SUM]] : tensor<*xf32> 6415 // CHECK-NEXT: }) {base_dilations = dense<3> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<4> : tensor<1xi64>, window_dimensions = dense<1> : tensor<1xi64>, window_strides = dense<2> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<10xf32> 6416 // CHECK-NEXT: return %[[REDUCE]] 6417 %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor<f32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32> 6418 func.return %0 : tensor<10xf32> 6419} 6420 6421func.func private @sum_reducer3(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { 6422 %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 6423 func.return %0 : tensor<*xf32> 6424} 6425 6426//===----------------------------------------------------------------------===// 6427// tf.XlaSort legalization 6428//===----------------------------------------------------------------------===// 6429 6430// ----- 6431 6432// CHECK-LABEL: @xlasort_int 6433// CHECK-SAME: %[[INPUT:.*]]: tensor<16xi32> 6434func.func @xlasort_int(%input: tensor<16xi32>) -> (tensor<16xi32>) { 6435 // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({ 6436 // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<i32>, %[[RHS:.*]]: tensor<i32>) 6437 // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], NOTYPE 6438 // CHECK-NEXT: mhlo.return %[[CMP]] 6439 // CHECK-NEXT: }) {dimension = -1 : i64, is_stable = false} : (tensor<16xi32>) -> tensor<16xi32> 6440 // CHECK-NEXT: return %[[SORT]] 6441 %output = "tf.XlaSort"(%input) : (tensor<16xi32>) -> (tensor<16xi32>) 6442 func.return %output : tensor<16xi32> 6443} 6444 6445// ----- 6446 6447// CHECK-LABEL: @xlasort_float 6448// CHECK-SAME: %[[INPUT:.*]]: tensor<8xf64> 6449func.func @xlasort_float(%input: tensor<8xf64>) -> (tensor<8xf64>) { 6450 // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({ 6451 // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f64>, %[[RHS:.*]]: tensor<f64>) 6452 // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], TOTALORDER 6453 // CHECK-NEXT: mhlo.return %[[CMP]] 6454 // CHECK-NEXT: }) {dimension = -1 : i64, is_stable = false} : (tensor<8xf64>) -> tensor<8xf64> 6455 // CHECK-NEXT: return %[[SORT]] 6456 %output = "tf.XlaSort"(%input) : (tensor<8xf64>) -> (tensor<8xf64>) 6457 func.return %output : tensor<8xf64> 6458} 6459 6460// ----- 6461 6462// CHECK-LABEL: @xlasort_const 6463func.func @xlasort_const() -> (tensor<2x3xi64>) { 6464 // CHECK: [2, 4, 3], [6, 5, 1] 6465 %input = "tf.Const"() {value = dense<[[2, 4, 3], [6, 5, 1]]> : tensor<2x3xi64>} : () -> (tensor<2x3xi64>) 6466 // CHECK-NEXT: [2, 3, 4], [1, 5, 6] 6467 %output = "tf.XlaSort"(%input): (tensor<2x3xi64>) -> (tensor<2x3xi64>) 6468 func.return %output : tensor<2x3xi64> 6469} 6470 6471//===----------------------------------------------------------------------===// 6472// tf.XlaRngBitGenerator legalization 6473//===----------------------------------------------------------------------===// 6474 6475// CHECK-LABEL: @xla_rng_bit_generator 6476// CHECK-SAME: %[[STATE:.*]]: tensor<2xui64> 6477func.func @xla_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0,_retval1"}} { 6478 // CHECK-NEXT: %0 = mhlo.constant dense<[10, 12]> : tensor<2xi32> 6479 %cst = "tf.Const"() {value = dense<[10, 12]> : tensor<2xi32>} : () -> tensor<2xi32> 6480 // CHECK-NEXT: %1 = mhlo.constant dense<3> : tensor<i32> 6481 %cst_0 = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> 6482 // CHECK-NEXT: %[[OUTPUT_STATE:.*]], %[[OUTPUT:.*]] = "mhlo.rng_bit_generator"(%[[STATE]]) {rng_algorithm = #mhlo.rng_algorithm<DEFAULT>} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) 6483 // CHECK-NEXT: return %[[OUTPUT_STATE]], %[[OUTPUT]] : tensor<2xui64>, tensor<10x12xui32> 6484 %output_key, %output = "tf.XlaRngBitGenerator"(%cst_0, %arg0, %cst) : (tensor<i32>, tensor<2xui64>, tensor<2xi32>) -> (tensor<2xui64>, tensor<10x12xui32>) 6485 func.return %output_key, %output : tensor<2xui64>, tensor<10x12xui32> 6486} 6487 6488//===----------------------------------------------------------------------===// 6489// tf.XlaVariadicV2 legalization 6490//===----------------------------------------------------------------------===// 6491 6492// ----- 6493// CHECK-LABEL: @xla_variadic_reduce_v2 6494func.func @xla_variadic_reduce_v2(%arg0: tensor<2x3xcomplex<f64>>, %arg1: tensor<complex<f64>>) -> tensor<3xcomplex<f64>> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { 6495 // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) 6496 // CHECK-SAME: dimensions = [0] 6497 // CHECK-NEXT: (%[[ARG0:.*]]: tensor<complex<f64>>, %[[ARG1:.*]]: tensor<complex<f64>>) 6498 // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer(%[[ARG0]], %[[ARG1]]){{.*}} 6499 // CHECK-NEXT: mhlo.return %[[SUM]] : tensor<complex<f64>> 6500 // CHECK: return %[[REDUCE]] 6501 %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operand_segment_sizes = array<i32: 1, 1>, reducer = @sum_reducer} : (tensor<2x3xcomplex<f64>>, tensor<complex<f64>>) -> tensor<3xcomplex<f64>> 6502 func.return %0 : tensor<3xcomplex<f64>> 6503} 6504 6505func.func private @sum_reducer(%arg0: tensor<complex<f64>>, %arg1: tensor<complex<f64>>) -> tensor<complex<f64>> { 6506 %0 = "tf.AddV2"(%arg1, %arg0) : (tensor<complex<f64>>, tensor<complex<f64>>) -> tensor<complex<f64>> 6507 func.return %0 : tensor<complex<f64>> 6508} 6509 6510// ----- 6511 6512// CHECK-LABEL: @xla_variadic_reduce_v2_dynamic 6513func.func @xla_variadic_reduce_v2_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { 6514 // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) 6515 // CHECK-SAME: dimensions = [0] 6516 // CHECK-NEXT: (%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<i32>) 6517 // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer2(%[[ARG0]], %[[ARG1]]){{.*}} 6518 // CHECK-NEXT: mhlo.return %[[SUM]] : tensor<i32> 6519 // CHECK: return %[[REDUCE]] 6520 %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operand_segment_sizes = array<i32: 1, 1>, reducer = @sum_reducer2} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> 6521 func.return %0 : tensor<*xi32> 6522} 6523 6524func.func private @sum_reducer2(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> { 6525 %0 = "tf.AddV2"(%arg1, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32> 6526 func.return %0 : tensor<i32> 6527} 6528 6529//===----------------------------------------------------------------------===// 6530// tf.XlaVariadicSort legalization 6531//===----------------------------------------------------------------------===// 6532 6533// CHECK-LABEL: @xla_variadic_sort 6534// CHECK-SAME: %[[INPUT:.*]]: tensor<2x3x4xui8> 6535func.func @xla_variadic_sort(%arg0: tensor<2x3x4xui8>) -> tensor<2x3x4xui8> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { 6536 // CHECK-NEXT: {{.*}} = mhlo.constant dense<0> : tensor<i32> 6537 %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> 6538 // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) ({ 6539 // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<ui8>, %[[RHS:.*]]: tensor<ui8>) 6540 // CHECK-NEXT: %[[CMP:.*]] = func.call @compare_lt(%[[LHS]], %[[RHS]]) : (tensor<ui8>, tensor<ui8>) -> tensor<i1> 6541 // CHECK-NEXT: mhlo.return %[[CMP]] 6542 // CHECK-NEXT: }) {dimension = 0 : i64, is_stable = false} : (tensor<2x3x4xui8>) -> tensor<2x3x4xui8> 6543 // CHECK-NEXT: return %[[SORT]] 6544 %0 = "tf.XlaVariadicSort"(%arg0, %cst) {_XlaHasReferenceVars = false, comparator = @compare_lt, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", is_stable = false} : (tensor<2x3x4xui8>, tensor<i32>) -> tensor<2x3x4xui8> 6545 func.return %0 : tensor<2x3x4xui8> 6546} 6547 6548func.func private @compare_lt(%arg0: tensor<ui8>, %arg1: tensor<ui8>) -> tensor<i1> attributes {tf._disable_call_shape_inference = true} { 6549 %0 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor<ui8>, tensor<ui8>) -> tensor<i1> 6550 func.return %0 : tensor<i1> 6551} 6552 6553//===----------------------------------------------------------------------===// 6554// tf.NextAfter legalization 6555//===----------------------------------------------------------------------===// 6556// CHECK-LABEL: func @nextafter 6557func.func @nextafter(%arg0: tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> { 6558 // CHECK-NEXT: %0 = chlo.broadcast_next_after %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> 6559 // CHECK-NEXT: return %0 : tensor<2xf32> 6560 %0 = "tf.NextAfter"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> 6561 func.return %0: tensor<2xf32> 6562} 6563 6564//===----------------------------------------------------------------------===// 6565// tf.XlaReduceScatter legalization 6566//===----------------------------------------------------------------------===// 6567// CHECK-LABEL: func @xla_reduce_scatter 6568func.func @xla_reduce_scatter(%arg0: tensor<128x128xf32>) -> tensor<64x128xf32> { 6569 %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> 6570 %cst_0 = "tf.Const"() {value = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> 6571 // CHECK: "mhlo.reduce_scatter"(%arg0) 6572 // CHECK{LITERAL}: replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> 6573 // CHECK-SAME: scatter_dimension = 0 6574 // 6575 %1 = "tf.XlaReduceScatter"(%arg0, %cst_0, %cst) {reduce_op = "Add"} : (tensor<128x128xf32>, tensor<4x2xi32>, tensor<i32>) -> tensor<64x128xf32> 6576 func.return %1 : tensor<64x128xf32> 6577} 6578 6579 6580//===----------------------------------------------------------------------===// 6581// tf.XlaSelectAndScatter legalization 6582//===----------------------------------------------------------------------===// 6583func.func @test_xla_select_and_scatter(%arg0: tensor<4x5x1x1xbf16>, %arg1: tensor<2x2x1x1xbf16>, %arg2: tensor<bf16>) -> tensor<?x?x?x?xbf16> { 6584 %cst = "tf.Const"() {value = dense<0> : tensor<4x2xi32>} : () -> tensor<4x2xi32> 6585 %cst_0 = "tf.Const"() {value = dense<[2, 2, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> 6586 %cst_1 = "tf.Const"() {value = dense<[2, 3, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> 6587 // CHECK: %[[SELECT_AND_SCATTER:.*]] = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ 6588 // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor<*xbf16>, %[[ARG1:.*]]: tensor<*xbf16>) 6589 // CHECK-NEXT: %[[RES:.*]] = func.call @ge_select(%[[ARG0]], %[[ARG1]]){{.*}} 6590 // CHECK-NEXT: mhlo.return %[[RES]] : tensor<*xi1> 6591 // CHECK-NEXT: }, { 6592 // CHECK-NEXT: ^{{.*}}(%[[ARG2:.*]]: tensor<*xbf16>, %[[ARG3:.*]]: tensor<*xbf16>) 6593 // CHECK-NEXT: %[[RES:.*]] = func.call @add_scatter(%[[ARG2]], %[[ARG3]]){{.*}} 6594 // CHECK-NEXT: mhlo.return %[[RES]] : tensor<*xbf16> 6595 // CHECK-NEXT: }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>} : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor<bf16>) -> tensor<?x?x?x?xbf16> 6596 // CHECK-NEXT: return %[[SELECT_AND_SCATTER]] 6597 %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor<bf16>) -> tensor<?x?x?x?xbf16> 6598 func.return %0 : tensor<?x?x?x?xbf16> 6599} 6600 6601func.func private @add_scatter(%arg0: tensor<*xbf16>, %arg1: tensor<*xbf16>) -> tensor<*xbf16> { 6602 %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xbf16> 6603 func.return %0 : tensor<*xbf16> 6604} 6605 6606func.func private @ge_select(%arg0: tensor<*xbf16>, %arg1: tensor<*xbf16>) -> tensor<*xi1> { 6607 %0 = "tf.GreaterEqual"(%arg0, %arg1) {device = ""} : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xi1> 6608 func.return %0 : tensor<*xi1> 6609} 6610 6611//===----------------------------------------------------------------------===// 6612// tf.XlaOptimizationBarrier legalization 6613//===----------------------------------------------------------------------===// 6614 6615func.func @test_xla_optimization_barrier(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) { 6616 // CHECK: %[[OPT_BARRIER:.*]]:2 = mhlo.optimization_barrier %arg0, %arg1 6617 // CHECK-NEXT: return %[[OPT_BARRIER]]#0, %[[OPT_BARRIER]]#1 6618 %0, %1 = "tf.XlaOptimizationBarrier"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) 6619 func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xi32> 6620} 6621