// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s

// Canonicalization tests for mhlo.transpose. The cases are divided by split
// markers and the tool is invoked with -split-input-file (see the command
// above), so each case is handled as its own chunk of input.

// Transposing a splat constant: the expected output is a single splat
// constant that already has the transposed result type.
// CHECK-LABEL: func @transpose_splat_constant
func.func @transpose_splat_constant() -> tensor<5x10xf32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<5x10xf32>
  %cst = mhlo.constant dense<1.000000e+00> : tensor<10x5xf32>
  %0 = "mhlo.transpose"(%cst) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<10x5xf32>) -> tensor<5x10xf32>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<5x10xf32>
}

// -----

// Identity permutation [0, 1, 2, 3]: the transpose is expected to disappear
// and the function argument to be returned directly.
// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
  %0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
  // CHECK-NEXT: return [[ARG]]
  func.return %0 : tensor<2x3x9x5xi32>
}

// -----

// Non-identity permutation that changes the shape: the transpose is expected
// to remain in the output.
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
  // CHECK-NEXT: "mhlo.transpose"([[ARG]])
  %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
  func.return %0 : tensor<3x2x5x9xi32>
}

// -----

// Non-identity permutation on a square tensor: operand and result types are
// equal, yet the transpose is still expected to remain in the output.
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
  // CHECK-NEXT: "mhlo.transpose"([[ARG]])
  %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
  func.return %0 : tensor<4x4xi32>
}

// -----

// Two chained transposes ([0, 3, 1, 2] followed by [0, 1, 3, 2]): the expected
// output is one transpose of the argument with permutation [0, 3, 2, 1].
// CHECK-LABEL: func @eliminate_redundant_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @eliminate_redundant_transpose(%arg : tensor<3x4x16x2xf32>) -> tensor<3x2x16x4xf32> {
  %0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}: (tensor<3x4x16x2xf32>) -> tensor<3x2x4x16xf32>
  %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>}: (tensor<3x2x4x16xf32>) -> tensor<3x2x16x4xf32>
  // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.transpose"([[ARG]])
  // CHECK-SAME: dense<[0, 3, 2, 1]
  // CHECK-NEXT: return [[RET]]
  func.return %1 : tensor<3x2x16x4xf32>
}

// -----

// Permutation [1, 0, 2] on 10x1x512 only moves the size-1 dimension: the
// expected output is an mhlo.reshape of the argument.
// CHECK-LABEL: func @simplify_transpose_case1
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @simplify_transpose_case1(%arg : tensor<10x1x512xf32>) -> tensor<1x10x512xf32> {
  %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x1x512xf32>) -> tensor<1x10x512xf32>
  // CHECK-NEXT: mhlo.reshape [[ARG]]
  func.return %0 : tensor<1x10x512xf32>
}

// -----

// Permutation [1, 3, 0, 2] on 10x1x512x1 moves both size-1 dimensions to the
// front while 10 and 512 keep their relative order: the expected output is an
// mhlo.reshape of the argument.
// CHECK-LABEL: func @simplify_transpose_case2
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @simplify_transpose_case2(%arg : tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32> {
  %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 3, 0, 2]> : tensor<4xi64>}: (tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32>
  // CHECK-NEXT: mhlo.reshape [[ARG]]
  func.return %0 : tensor<1x1x10x512xf32>
}

// -----

// Same permutation as case1, but the moved dimension is dynamic (?) instead
// of 1: the transpose is expected to remain in the output.
// CHECK-LABEL: func @not_simplify_transpose_dynamic_shape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @not_simplify_transpose_dynamic_shape(%arg : tensor<10x?x512xf32>) -> tensor<?x10x512xf32> {
  %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x?x512xf32>) -> tensor<?x10x512xf32>
  // CHECK-NEXT: "mhlo.transpose"([[ARG]])
  func.return %0 : tensor<?x10x512xf32>
}

// -----

// Transpose of a rank-1 broadcast_in_dim: the expected output is a single
// broadcast_in_dim of the argument whose broadcast dimension is 1 rather
// than 3.
// CHECK-LABEL: func @broadcast_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @broadcast_transpose(%arg0 : tensor<64xf32>) -> tensor<5x64x31x95xf32> {
  %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<5x31x95x64xf32>
  %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
  // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]])
  // CHECK-SAME: dense<1>
  // CHECK-NEXT: return [[RET]]
  func.return %1 : tensor<5x64x31x95xf32>
}

// -----

// Transpose of a scalar broadcast_in_dim (empty broadcast_dimensions): the
// expected output is a single broadcast_in_dim of the argument, still with an
// empty dimension list.
// CHECK-LABEL: func @broadcast_transpose_non_dim
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @broadcast_transpose_non_dim(%arg0 : tensor<f32>) -> tensor<5x64x31x95xf32> {
  %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x31x95x64xf32>
  %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
  // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]])
  // CHECK-SAME: dense<>
  // CHECK-NEXT: return [[RET]]
  func.return %1 : tensor<5x64x31x95xf32>
}

// -----

// Transpose of a rank-2 broadcast_in_dim with broadcast dimensions [2, 3]:
// the expected output is a single broadcast_in_dim of the argument with
// broadcast dimensions [3, 1].
// CHECK-LABEL: func @broadcast_transpose_multi_dim
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @broadcast_transpose_multi_dim(%arg0 : tensor<95x64xf32>) -> tensor<5x64x31x95xf32> {
  %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<95x64xf32>) -> tensor<5x31x95x64xf32>
  %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
  // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]])
  // CHECK-SAME: dense<[3, 1]>
  // CHECK-NEXT: return [[RET]]
  func.return %1 : tensor<5x64x31x95xf32>
}