// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s

// Canonicalization tests: dot_general ops that are plain matrix products, and
// convolutions without spatial dimensions, are rewritten into mhlo.dot /
// mhlo.dot_general based forms. Each case below is a separate split-input-file
// chunk.

// A dot_general with no batching dimensions and a single lhs-dim-1 /
// rhs-dim-0 contraction is expected to become a plain mhlo.dot that keeps the
// original precision_config.
// CHECK-LABEL: @dot_general_is_dot
func.func @dot_general_is_dot(%arg0: tensor<5x6xf32>, %arg1: tensor<6x?xf32>) -> tensor<5x?xf32> {
  // CHECK: %[[DOT:.+]] = "mhlo.dot"(%arg0, %arg1)
  // CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x6xf32>, tensor<6x?xf32>) -> tensor<5x?xf32>
  // CHECK: return %[[DOT]]
  return %0 : tensor<5x?xf32>
}

// -----

// A convolution with no spatial dimensions and feature_group_count = 1 is
// expected to become a dot_general contracting lhs dim 1 with rhs dim 1.
// CHECK-LABEL: @convolution_is_dot_general
func.func @convolution_is_dot_general(%arg0: tensor<5x6xf32>, %arg1: tensor<?x6xf32>) -> tensor<5x?xf32> {
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%arg0, %arg1)
  // CHECK-SAME: dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>,
  // CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x6xf32>, tensor<?x6xf32>) -> tensor<5x?xf32>
  // CHECK: return %[[DOT]]
  return %0 : tensor<5x?xf32>
}

// -----

// Same as the previous case, but the convolution operands are swapped and the
// result layout is [f, b]. The expected dot_general still takes
// (%arg0, %arg1) in that order.
// CHECK-LABEL: @convolution_is_dot_general_swap
func.func @convolution_is_dot_general_swap(%arg0: tensor<5x6xf32>, %arg1: tensor<?x6xf32>) -> tensor<5x?xf32> {
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%arg0, %arg1)
  // CHECK-SAME: dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>,
  // CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, f]x[o, i]->[f, b], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<?x6xf32>, tensor<5x6xf32>) -> tensor<5x?xf32>
  // CHECK: return %[[DOT]]
  return %0 : tensor<5x?xf32>
}

// -----

// A grouped convolution (feature_group_count = 6) is expected to become
// reshape(lhs), reshape(rhs), a batched dot_general, a transpose, and a final
// reshape back to the result type. Intermediate values are tracked through
// pattern variables so the test does not depend on SSA numbering.
// CHECK-LABEL: @conv_grouped_is_dot
func.func @conv_grouped_is_dot(%arg0: tensor<5x12xf32>, %arg1: tensor<2x6xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[RES0:.+]] = mhlo.reshape %arg0 : (tensor<5x12xf32>) -> tensor<5x6x2xf32>
  // CHECK: %[[RES1:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<6x1x2xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[RES0]], %[[RES1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [2]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x6x1xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 6 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x12xf32>, tensor<2x6xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}

// -----

// Grouped convolution with feature_group_count = 2 and several output features
// per group (6 / 2 = 3).
// CHECK-LABEL: @conv_grouped_is_dot_multi
func.func @conv_grouped_is_dot_multi(%arg0: tensor<5x4xf32>, %arg1: tensor<2x6xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32>
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<2x3x2xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [2]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x4xf32>, tensor<2x6xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}

// -----

// Grouped convolution whose kernel uses the [o, i] layout instead of [i, o].
// CHECK-LABEL: @conv_grouped_is_dot_transpose_rhs
func.func @conv_grouped_is_dot_transpose_rhs(%arg0: tensor<5x4xf32>, %arg1: tensor<6x2xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32>
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<6x2xf32>) -> tensor<2x2x3xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [1], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x4xf32>, tensor<6x2xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}

// -----

// Grouped convolution whose input uses the [f, b] layout. The batching
// dimension of the reshaped lhs is expected at position 0.
// CHECK-LABEL: @conv_grouped_is_dot_transpose_ins
func.func @conv_grouped_is_dot_transpose_ins(%arg0: tensor<4x5xf32>, %arg1: tensor<6x2xf32>) -> tensor<5x6xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<4x5xf32>) -> tensor<2x2x5xf32>
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<6x2xf32>) -> tensor<2x2x3xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [1], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32>
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [f, b]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<4x5xf32>, tensor<6x2xf32>) -> tensor<5x6xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<5x6xf32>
}

// -----

// Grouped convolution whose result uses the [f, b] layout. The trailing
// transpose is expected to use permutation [0, 2, 1] rather than [1, 0, 2].
// CHECK-LABEL: @conv_grouped_is_dot_transpose_out
func.func @conv_grouped_is_dot_transpose_out(%arg0: tensor<5x4xf32>, %arg1: tensor<2x6xf32>) -> tensor<6x5xf32> {
  // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32>
  // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<2x3x2xf32>
  // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [1], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [2]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
  // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>}
  // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]]
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[f, b], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<5x4xf32>, tensor<2x6xf32>) -> tensor<6x5xf32>
  // CHECK: return %[[OUT]]
  return %0 : tensor<6x5xf32>
}