1// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s
2
// A dot_general that contracts lhs dim 1 with rhs dim 0 and has no batch
// dimensions is a plain matrix product, so it is expected to be rewritten to
// mhlo.dot while keeping the original precision config.
// CHECK-LABEL: @dot_general_is_dot
func.func @dot_general_is_dot(%arg0: tensor<5x6xf32>, %arg1: tensor<6x?xf32>) -> tensor<5x?xf32> {
  // CHECK: %[[DOT:.+]] = "mhlo.dot"(%arg0, %arg1)
  // CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x6xf32>, tensor<6x?xf32>) -> tensor<5x?xf32>
  // The rewritten op must be what the function returns.
  // CHECK: return %[[DOT]]
  return %0 : tensor<5x?xf32>
}
11
12// -----
13
// A convolution with no spatial dimensions, an empty window and a single
// feature/batch group only contracts the lhs feature dim with the kernel
// input-feature dim, so it is expected to become a dot_general.
// CHECK-LABEL: @convolution_is_dot_general
func.func @convolution_is_dot_general(%arg0: tensor<5x6xf32>, %arg1: tensor<?x6xf32>) -> tensor<5x?xf32> {
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%arg0, %arg1)
  // CHECK-SAME: dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>,
  // CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x6xf32>, tensor<?x6xf32>) -> tensor<5x?xf32>
  // The rewritten op must be what the function returns.
  // CHECK: return %[[DOT]]
  return %0 : tensor<5x?xf32>
}
23
24// -----
25
// Ungrouped, window-less convolution whose operands are passed as
// (%arg1, %arg0) and whose output layout is [f, b]: output dim 0 comes from
// the kernel (%arg0), so the rewrite is expected to emit
// dot_general(%arg0, %arg1) with the operands back in result order.
// CHECK-LABEL: @convolution_is_dot_general_swap
func.func @convolution_is_dot_general_swap(%arg0: tensor<5x6xf32>, %arg1: tensor<?x6xf32>) -> tensor<5x?xf32> {
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%arg0, %arg1)
  // CHECK-SAME: dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>,
  // CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, f]x[o, i]->[f, b], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<?x6xf32>, tensor<5x6xf32>) -> tensor<5x?xf32>
  // The rewritten op must be what the function returns.
  // CHECK: return %[[DOT]]
  return %0 : tensor<5x?xf32>
}
35
36// -----
37
// Grouped convolution (feature_group_count = 6) with no spatial dims: expected
// to lower to reshapes of both operands, a dot_general batched over the group
// dimension, a transpose and a final reshape back to [b, f].
// CHECK-LABEL: @conv_grouped_is_dot
func.func @conv_grouped_is_dot(%arg0: tensor<5x12xf32>, %arg1: tensor<2x6xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[RES0:.+]] = mhlo.reshape %arg0 : (tensor<5x12xf32>) -> tensor<5x6x2xf32>
  // NOTE(review): the kernel is laid out [i, o] (2x6), yet the expected
  // reshape is 6x1x2 with dim 0 batched and dim 2 contracted, i.e. it reads
  // the kernel as [group, o/group, i]. A layout-preserving reshape of an
  // [i, o] kernel would give [i, group, o/group] = 2x6x1. Confirm the
  // pattern's kernel-dimension handling before relying on these numbers.
  // CHECK: %[[RES1:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<6x1x2xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[RES0]], %[[RES1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [2]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // Use the captured names rather than literal SSA numbers so the test does
  // not depend on how values are numbered in the printed output.
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x6x1xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 6 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x12xf32>, tensor<2x6xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}
49
50// -----
51
// Grouped convolution with 2 groups and several output features per group
// (6 outputs / 2 groups = 3): expected reshape + batched dot_general +
// transpose + reshape.
// CHECK-LABEL: @conv_grouped_is_dot_multi
func.func @conv_grouped_is_dot_multi(%arg0: tensor<5x4xf32>, %arg1: tensor<2x6xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32>
  // NOTE(review): the kernel is laid out [i, o] (2x6), yet the expected
  // reshape is 2x3x2 with dim 0 batched and dim 2 contracted, i.e. it reads
  // the kernel as [group, o/group, i]. A layout-preserving reshape of an
  // [i, o] kernel would give [i, group, o/group] = 2x2x3. Confirm the
  // pattern's kernel-dimension handling before relying on these numbers.
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<2x3x2xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [2]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x4xf32>, tensor<2x6xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}
63
64// -----
65
// Same grouped lowering, but with the kernel laid out [o, i] instead of
// [i, o], which changes the rhs batching/contracting dims of the dot_general.
// CHECK-LABEL: @conv_grouped_is_dot_transpose_rhs
func.func @conv_grouped_is_dot_transpose_rhs(%arg0: tensor<5x4xf32>, %arg1: tensor<6x2xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32>
  // NOTE(review): the kernel is laid out [o, i] (6x2), yet the expected
  // reshape is 2x2x3 with dim 1 batched and dim 0 contracted, i.e. it reads
  // the kernel as [i, group, o/group]. A layout-preserving reshape of an
  // [o, i] kernel would give [group, o/group, i] = 2x3x2. Confirm the
  // pattern's kernel-dimension handling before relying on these numbers.
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<6x2xf32>) -> tensor<2x2x3xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [1], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x4xf32>, tensor<6x2xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}
77
78// -----
79
// Grouped lowering with both inputs transposed: the lhs is laid out [f, b]
// and the kernel [o, i], so the lhs group/contracting dims move to 0/1.
// CHECK-LABEL: @conv_grouped_is_dot_transpose_ins
func.func @conv_grouped_is_dot_transpose_ins(%arg0: tensor<4x5xf32>, %arg1: tensor<6x2xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<4x5xf32>) -> tensor<2x2x5xf32>
  // NOTE(review): the kernel is laid out [o, i] (6x2), yet the expected
  // reshape is 2x2x3 with dim 1 batched and dim 0 contracted, i.e. it reads
  // the kernel as [i, group, o/group]. A layout-preserving reshape of an
  // [o, i] kernel would give [group, o/group, i] = 2x3x2. Confirm the
  // pattern's kernel-dimension handling before relying on these numbers.
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<6x2xf32>) -> tensor<2x2x3xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [1], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [f, b]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<4x5xf32>, tensor<6x2xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}
91
92// -----
93
// Grouped lowering with the output laid out [f, b]: the dot_general result
// [group, b, o/group] = 2x5x3 is transposed with [0, 2, 1] to 2x3x5 and then
// reshaped to the 6x5 result.
// CHECK-LABEL: @conv_grouped_is_dot_transpose_out
func.func @conv_grouped_is_dot_transpose_out(%arg0: tensor<5x4xf32>, %arg1: tensor<2x6xf32>) -> tensor<6x5xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32>
  // NOTE(review): the kernel is laid out [i, o] (2x6), yet the expected
  // reshape is 2x3x2 with dim 0 batched and dim 2 contracted, i.e. it reads
  // the kernel as [group, o/group, i]. A layout-preserving reshape of an
  // [i, o] kernel would give [i, group, o/group] = 2x2x3. Confirm the
  // pattern's kernel-dimension handling before relying on these numbers.
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<2x3x2xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [2]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>}
  // Pin the reshape types as the sibling grouped tests do.
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<2x3x5xf32>) -> tensor<6x5xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[f, b], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x4xf32>, tensor<2x6xf32>) -> tensor<6x5xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<6x5xf32>
}