1// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s
2
// A transpose applied to a splat constant is expected to fold away, leaving
// a splat constant that already has the transposed shape.
// CHECK-LABEL: func @transpose_splat_constant
func.func @transpose_splat_constant() -> tensor<5x10xf32> {
  // CHECK-NEXT: [[SPLAT:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<5x10xf32>
  // CHECK-NEXT: return [[SPLAT]]
  %splat = mhlo.constant dense<1.000000e+00> : tensor<10x5xf32>
  %transposed = "mhlo.transpose"(%splat) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<10x5xf32>) -> tensor<5x10xf32>
  func.return %transposed : tensor<5x10xf32>
}
11
12// -----
13
// A transpose whose permutation is the identity [0, 1, 2, 3] is expected to
// be removed, so the function returns its argument directly.
// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
func.func @remove_noop(%input : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
  // CHECK-NEXT: return [[INPUT]]
  %identity = "mhlo.transpose"(%input) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
  func.return %identity : tensor<2x3x9x5xi32>
}
21
22// -----
23
// A transpose that actually reorders non-unit dimensions must survive
// canonicalization and still consume the function argument.
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-NEXT: "mhlo.transpose"([[INPUT]])
func.func @keep_real_transpose(%input : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
  %swapped = "mhlo.transpose"(%input) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
  func.return %swapped : tensor<3x2x5x9xi32>
}
31
32// -----
33
// Operand and result types are both 4x4, but the permutation [1, 0] is not
// the identity, so the transpose is expected to be kept.
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-NEXT: "mhlo.transpose"([[INPUT]])
func.func @keep_same_shape_real_transpose(%input : tensor<4x4xi32>) -> tensor<4x4xi32> {
  %swapped = "mhlo.transpose"(%input) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
  func.return %swapped : tensor<4x4xi32>
}
41
42// -----
43
// Two back-to-back transposes are expected to be merged into one transpose
// of the original argument. Applying [0, 3, 1, 2] and then [0, 1, 3, 2]
// composes to the single permutation [0, 3, 2, 1].
// CHECK-LABEL: @eliminate_redundant_transpose
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK: [[MERGED:%[a-zA-Z0-9]+]] = "mhlo.transpose"([[INPUT]])
// CHECK-SAME: dense<[0, 3, 2, 1]
// CHECK-NEXT: return [[MERGED]]
func.func @eliminate_redundant_transpose(%input : tensor<3x4x16x2xf32>) -> tensor<3x2x16x4xf32> {
  %first = "mhlo.transpose"(%input) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}: (tensor<3x4x16x2xf32>) -> tensor<3x2x4x16xf32>
  %second = "mhlo.transpose"(%first) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>}: (tensor<3x2x4x16xf32>) -> tensor<3x2x16x4xf32>
  func.return %second : tensor<3x2x16x4xf32>
}
54
55// -----
56
// The permutation [1, 0, 2] only moves the size-1 dimension ahead of the
// size-10 one, so the transpose is expected to be rewritten as a reshape of
// the argument.
// CHECK-LABEL: @simplify_transpose_case1
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-NEXT: mhlo.reshape [[INPUT]]
func.func @simplify_transpose_case1(%input : tensor<10x1x512xf32>) -> tensor<1x10x512xf32> {
  %moved = "mhlo.transpose"(%input) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x1x512xf32>) -> tensor<1x10x512xf32>
  func.return %moved : tensor<1x10x512xf32>
}
64
65// -----
66
// Same idea as a single unit-dimension move, but with two size-1 dimensions:
// [1, 3, 0, 2] hoists both of them to the front while the non-unit
// dimensions (10, 512) keep their relative order, so a reshape is expected.
// CHECK-LABEL: @simplify_transpose_case2
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-NEXT: mhlo.reshape [[INPUT]]
func.func @simplify_transpose_case2(%input : tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32> {
  %moved = "mhlo.transpose"(%input) {permutation = dense<[1, 3, 0, 2]> : tensor<4xi64>}: (tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32>
  func.return %moved : tensor<1x1x10x512xf32>
}
74
75// -----
76
// Negative case: the dimension being moved has a dynamic size (`?`), so the
// transpose is expected to stay in place rather than become a reshape.
// CHECK-LABEL: @not_simplify_transpose_dynamic_shape
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-NEXT: "mhlo.transpose"([[INPUT]])
func.func @not_simplify_transpose_dynamic_shape(%input : tensor<10x?x512xf32>) -> tensor<?x10x512xf32> {
  %moved = "mhlo.transpose"(%input) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x?x512xf32>) -> tensor<?x10x512xf32>
  func.return %moved : tensor<?x10x512xf32>
}
84
85// -----
86
// A transpose of a broadcast_in_dim is expected to collapse into a single
// broadcast_in_dim on the original argument. The operand was broadcast along
// dimension 3, which the permutation [0, 3, 1, 2] moves to position 1.
// CHECK-LABEL: func @broadcast_transpose
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK: [[BCAST:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[INPUT]])
// CHECK-SAME: dense<1>
// CHECK-NEXT: return [[BCAST]]
func.func @broadcast_transpose(%arg0 : tensor<64xf32>) -> tensor<5x64x31x95xf32> {
  %expanded = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<5x31x95x64xf32>
  %permuted = "mhlo.transpose"(%expanded) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
  func.return %permuted : tensor<5x64x31x95xf32>
}
97
98// -----
99
// Scalar variant: the broadcast has an empty broadcast_dimensions list, so
// after the transpose is absorbed the resulting broadcast_in_dim is expected
// to keep an empty list as well.
// CHECK-LABEL: func @broadcast_transpose_non_dim
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK: [[BCAST:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[INPUT]])
// CHECK-SAME: dense<>
// CHECK-NEXT: return [[BCAST]]
func.func @broadcast_transpose_non_dim(%arg0 : tensor<f32>) -> tensor<5x64x31x95xf32> {
  %expanded = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x31x95x64xf32>
  %permuted = "mhlo.transpose"(%expanded) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
  func.return %permuted : tensor<5x64x31x95xf32>
}
110
111// -----
112
// Rank-2 operand variant: the operand dimensions are broadcast to result
// dimensions [2, 3]. Under the permutation [0, 3, 1, 2] those land at
// positions 3 and 1, so the merged broadcast_in_dim is expected to carry
// broadcast_dimensions [3, 1].
// CHECK-LABEL: func @broadcast_transpose_multi_dim
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK: [[BCAST:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[INPUT]])
// CHECK-SAME: dense<[3, 1]>
// CHECK-NEXT: return [[BCAST]]
func.func @broadcast_transpose_multi_dim(%arg0 : tensor<95x64xf32>) -> tensor<5x64x31x95xf32> {
  %expanded = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<95x64xf32>) -> tensor<5x31x95x64xf32>
  %permuted = "mhlo.transpose"(%expanded) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32>
  func.return %permuted : tensor<5x64x31x95xf32>
}
123