// RUN: lhlo-tfrt-opt %s \
// RUN: -lmhlo-gpu-async-conversion \
// RUN: -gpu-async-region \
// RUN: -async-gpu-tfrt-conversion \
// RUN: | FileCheck %s

// CHECK: func @gemm(
// CHECK-SAME: %arg0: !tfrt.chain,
// CHECK-SAME: %arg1: !tfrt_gpu.stream,
// CHECK-SAME: %arg2: !tfrt_gpu.buffer,
// CHECK-SAME: %arg3: !tfrt_gpu.buffer,
// CHECK-SAME: %arg4: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<100xi8>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute
  // CHECK-NOT: memref.view

  %c0 = constant 0 : index
  %view = memref.view %output[%c0][] : memref<100xi8> to memref<5x5xf32>

  // CHECK: [[M:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[N:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[K:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[ALPHA:%[0-9]+]] = tfrt.constant.f32 5.000000e-01
  // CHECK: [[LDA:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[LDB:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[BETA:%[0-9]+]] = tfrt.constant.f32 0.000000e+00
  // CHECK: [[LDC:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[ALGO:%[0-9]+]] = tfrt_gpu.blas.gemm.algo CUBLAS_GEMM_DEFAULT
  // CHECK: [[HANDLE:%[0-9]+]] = tfrt_gpu.blas.create %arg1

  // CHECK: [[CHAIN:%[0-9]+]] = tfrt_gpu.blas.gemm [[HANDLE]],
  // CHECK-SAME: CUBLAS_OP_N, CUBLAS_OP_N, [[M]], [[N]], [[K]], [[ALPHA]],
  // CHECK-SAME: %arg3, CUDA_R_32F, [[LDA]],
  // CHECK-SAME: %arg2, CUDA_R_32F, [[LDB]], [[BETA]],
  // CHECK-SAME: %arg4, CUDA_R_32F, [[LDC]],
  // CHECK-SAME: CUDA_R_32F, [[ALGO]], %arg0

  "lmhlo_gpu.gemm"(%lhs, %rhs, %view) { dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
      alpha_real = 0.5,
      alpha_imag = 0.0,
      batch_size = 1,
      lhs_stride = 20,
      rhs_stride = 20}
    : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

// CHECK: func @gemm_batch(
// CHECK-SAME: %arg0: !tfrt.chain,
// CHECK-SAME: %arg1: !tfrt_gpu.stream,
// CHECK-SAME: %arg2: !tfrt_gpu.buffer,
// CHECK-SAME: %arg3: !tfrt_gpu.buffer,
// CHECK-SAME: %arg4: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @gemm_batch(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: [[M:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[N:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[K:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[ALPHA:%[0-9]+]] = tfrt.constant.f32 5.000000e-01
  // CHECK: [[LDA:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[LDB:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[BETA:%[0-9]+]] = tfrt.constant.f32 0.000000e+00
  // CHECK: [[LDC:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[ALGO:%[0-9]+]] = tfrt_gpu.blas.gemm.algo CUBLAS_GEMM_DEFAULT
  // CHECK: [[HANDLE:%[0-9]+]] = tfrt_gpu.blas.create %arg1
  // CHECK: [[STRIDEA:%[0-9]+]] = tfrt.constant.i64 20
  // CHECK: [[STRIDEB:%[0-9]+]] = tfrt.constant.i64 20
  // CHECK: [[STRIDEC:%[0-9]+]] = tfrt.constant.i64 25
  // CHECK: [[BATCH:%[0-9]+]] = tfrt.constant.i32 42

  // CHECK: [[CHAIN:%[0-9]+]] = tfrt_gpu.blas.gemm.batch [[HANDLE]],
  // CHECK-SAME: CUBLAS_OP_N, CUBLAS_OP_N, [[M]], [[N]], [[K]], [[ALPHA]],
  // CHECK-SAME: %arg3, CUDA_R_32F, [[LDA]], [[STRIDEA]],
  // CHECK-SAME: %arg2, CUDA_R_32F, [[LDB]], [[STRIDEB]], [[BETA]],
  // CHECK-SAME: %arg4, CUDA_R_32F, [[LDC]], [[STRIDEC]], [[BATCH]],
  // CHECK-SAME: CUDA_R_32F, [[ALGO]], %arg0

  "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
      alpha_real = 0.5,
      alpha_imag = 0.0,
      batch_size = 42,
      lhs_stride = 20,
      rhs_stride = 20}
    : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

// CHECK: func @gemm_bias(
// CHECK-SAME: %arg0: !tfrt.chain,
// CHECK-SAME: %arg1: !tfrt_gpu.stream,
// CHECK-SAME: %arg2: !tfrt_gpu.buffer,
// CHECK-SAME: %arg3: !tfrt_gpu.buffer,
// CHECK-SAME: %arg4: !tfrt_gpu.buffer,
// CHECK-SAME: %arg5: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
                %bias: memref<5x5xf32>, %output:memref<5x5xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: [[M:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[N:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[K:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[ALPHA:%[0-9]+]] = tfrt.constant.f32 5.000000e-01
  // CHECK: [[LDA:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[LDB:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[BETA:%[0-9]+]] = tfrt.constant.f32 1.000000e+00
  // CHECK: [[LDC:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[ALGO:%[0-9]+]] = tfrt_gpu.blas.gemm.algo CUBLAS_GEMM_DEFAULT
  // CHECK: [[HANDLE:%[0-9]+]] = tfrt_gpu.blas.create %arg1

  // CHECK: [[CHAIN:%[0-9]+]] = tfrt_gpu.blas.gemm [[HANDLE]],
  // CHECK-SAME: CUBLAS_OP_N, CUBLAS_OP_N, [[M]], [[N]], [[K]], [[ALPHA]],
  // CHECK-SAME: %arg3, CUDA_R_32F, [[LDA]],
  // CHECK-SAME: %arg2, CUDA_R_32F, [[LDB]], [[BETA]],
  // CHECK-SAME: %arg5, CUDA_R_32F, [[LDC]],
  // CHECK-SAME: CUDA_R_32F, [[ALGO]], %arg0

  "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
      alpha_real = 0.5,
      alpha_imag = 0.0,
      beta = 1.0,
      batch_size = 1,
      lhs_stride = 20,
      rhs_stride = 20}
    : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

// CHECK: func @all_reduce(
// CHECK-SAME: %arg0: !tfrt.chain,
// CHECK-SAME: %arg1: !tfrt_gpu.stream,
// CHECK-SAME: %arg2: !tfrt_gpu.buffer,
// CHECK-SAME: %arg3: !tfrt_gpu.buffer,
// CHECK-SAME: %arg4: !tfrt_gpu.buffer,
// CHECK-SAME: %arg5: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @all_reduce(%operand0: memref<2x2xf32>, %operand1: memref<2x2xf32>, %result0: memref<2x2xf32>, %result1: memref<2x2xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: [[CONTEXT:%[0-9]+]] = tfrt_gpu.stream.get_context %arg1
  // CHECK: [[HANDLE:%[0-9]+]] = xlir.ccl.create [[CONTEXT]]
  // CHECK: [[CHAIN1:%[0-9]+]] = tfrt_gpu.ccl.all_reduce [[HANDLE]],
  // CHECK-SAME: %arg2, %arg4, 7, 0, %arg0
  // CHECK: [[CHAIN2:%[0-9]+]] = tfrt_gpu.ccl.all_reduce [[HANDLE]],
  // CHECK-SAME: %arg3, %arg5, 7, 0, [[CHAIN1]]
  // CHECK: [[CHAIN3:%[0-9]+]] = tfrt_gpu.ccl.execute %arg1, [[HANDLE]],
  // CHECK-SAME: [[CHAIN2]]

  "lmhlo.all_reduce"(%operand0, %operand1, %result0, %result1) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %0 = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%0) : (tensor<f32>) -> ()
  }) {
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
    channel_id = { handle = 5 : i64, type = 2 : i64 },
    constrain_layout = true,
    use_global_device_ids = true
  } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN3]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

// CHECK: func @two_ops(
// CHECK-SAME: %arg0: !tfrt.chain,
// CHECK-SAME: %arg1: !tfrt_gpu.stream,
// CHECK-SAME: %arg2: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @two_ops(%memref: memref<4x4xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: tfrt.constant.f32 3.14159274
  // CHECK: tfrt_gpu.blas.gemm
  "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
      alpha_real = 3.14159274,
      alpha_imag = 0.0,
      batch_size = 1,
      lhs_stride = 16,
      rhs_stride = 16}
    : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()

  // CHECK: tfrt.constant.f32 2.71828175
  // CHECK: tfrt_gpu.blas.gemm
  "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
      alpha_real = 2.71828175,
      alpha_imag = 0.0,
      batch_size = 1,
      lhs_stride = 16,
      rhs_stride = 16}
    : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return {{.*}} : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

// CHECK: func @async(
// CHECK-SAME: %arg0: !tfrt.chain,
// CHECK-SAME: %arg1: !tfrt_gpu.stream,
// CHECK-SAME: %arg2: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @async(%memref: memref<4x4xf32>) {
  // CHECK-NOT: cast

  // CHECK: %[[a0:.*]], %[[t0:.*]] = async.execute
  // CHECK-SAME: -> !async.value<!tfrt_gpu.event>
  %a0 = async.execute () {
    // CHECK: %[[e0:.*]] = tfrt_gpu.event.create
    // CHECK: %[[ch0:.*]] = tfrt_gpu.event.record %[[e0]], %arg1
    // CHECK: %[[s0:.*]] = tfrt_gpu.stream.create
    // CHECK: %[[ch1:.*]] = tfrt_gpu.stream.wait %[[s0]], %[[e0]]
    // CHECK: %[[h0:.*]] = tfrt_gpu.blas.create %[[s0]]
    // CHECK: %[[ch2:.*]] = tfrt_gpu.blas.gemm %[[h0]]
    // CHECK-SAME: %[[ch1]]
    "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
        lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
        rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
        lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
        rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
        alpha_real = 0.5,
        alpha_imag = 0.0,
        batch_size = 1,
        lhs_stride = 16,
        rhs_stride = 16}
      : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()
    // CHECK: %[[e1:.*]] = tfrt_gpu.event.create
    // CHECK: %[[ch3:.*]] = tfrt_gpu.event.record %[[e1]], %[[s0]], %[[ch2]]
    // CHECK: async.yield %[[e1]] : !tfrt_gpu.event
    async.yield
  }

  // CHECK: %[[a1:.*]], %[[t1:.*]] = async.execute [%[[a0]]]
  // CHECK-SAME: (%[[t0]] as %[[e2:.*]]:
  // CHECK-SAME: !async.value<!tfrt_gpu.event>) -> !async.value<!tfrt_gpu.event>
  %a1 = async.execute [%a0] () {
    // CHECK: %[[s1:.*]] = tfrt_gpu.stream.create
    // CHECK: %[[ch4:.*]] = tfrt_gpu.stream.wait %[[s1]], %[[e2]]
    // CHECK: %[[h1:.*]] = tfrt_gpu.blas.create %[[s1]]
    // CHECK: %[[ch5:.*]] = tfrt_gpu.blas.gemm %[[h1]]
    // CHECK-SAME: %[[ch4]]
    "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
        lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
        rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
        lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
        rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
        alpha_real = 0.5,
        alpha_imag = 0.0,
        batch_size = 1,
        lhs_stride = 16,
        rhs_stride = 16}
      : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()
    // CHECK: %[[e3:.*]] = tfrt_gpu.event.create
    // CHECK: %[[ch6:.*]] = tfrt_gpu.event.record %[[e3]], %[[s1]], %[[ch5]]
    // CHECK: async.yield %[[e3]] : !tfrt_gpu.event
    async.yield
  }

  // CHECK: async.await %[[a1]] : !async.token
  // CHECK: %[[e4:.*]] = async.await %[[t1]] : !async.value<!tfrt_gpu.event>
  // CHECK: %[[ch7:.*]] = tfrt_gpu.stream.wait %arg1, %[[e4]]
  async.await %a1 : !async.token

  // CHECK: %[[h2:.*]] = tfrt_gpu.blas.create %arg1
  // CHECK: %[[ch8:.*]] = tfrt_gpu.blas.gemm %[[h2]]
  // CHECK-SAME: %[[ch7]]
  "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
      alpha_real = 0.5,
      alpha_imag = 0.0,
      batch_size = 1,
      lhs_stride = 16,
      rhs_stride = 16}
    : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()

  // CHECK: tfrt.return %[[ch8]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}