// RUN: lhlo-tfrt-opt %s              \
// RUN:   -lmhlo-gpu-async-conversion \
// RUN:   -gpu-async-region           \
// RUN:   -async-gpu-tfrt-conversion  \
// RUN: | FileCheck %s

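// Lowers lmhlo and lmhlo_gpu ops to the tfrt and tfrt_gpu dialects via the
// three passes in the RUN pipeline above.

// Verifies that lmhlo_gpu.gemm, with a memref.view reshaping the raw i8
// output buffer, lowers to a single tfrt_gpu.blas.gemm call: M/N/K and the
// leading dimensions come from the memref shapes, alpha/beta from the op
// attributes. The lhs/rhs buffers (%arg2/%arg3) appear swapped in the cuBLAS
// call, presumably to account for cuBLAS's column-major layout. No cast,
// async.execute, or memref.view ops may remain.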
// CHECK:      func @gemm(
// CHECK-SAME:   %arg0: !tfrt.chain,
// CHECK-SAME:   %arg1: !tfrt_gpu.stream,
// CHECK-SAME:   %arg2: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg3: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg4: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output: memref<100xi8>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute
  // CHECK-NOT: memref.view

  %c0 = constant 0 : index
  %view = memref.view %output[%c0][] : memref<100xi8> to memref<5x5xf32>

  // CHECK: [[M:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[N:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[K:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[ALPHA:%[0-9]+]] = tfrt.constant.f32 5.000000e-01
  // CHECK: [[LDA:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[LDB:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[BETA:%[0-9]+]] = tfrt.constant.f32 0.000000e+00
  // CHECK: [[LDC:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[ALGO:%[0-9]+]] = tfrt_gpu.blas.gemm.algo CUBLAS_GEMM_DEFAULT
  // CHECK: [[HANDLE:%[0-9]+]] = tfrt_gpu.blas.create %arg1

  // CHECK: [[CHAIN:%[0-9]+]] = tfrt_gpu.blas.gemm [[HANDLE]],
  // CHECK-SAME: CUBLAS_OP_N, CUBLAS_OP_N, [[M]], [[N]], [[K]], [[ALPHA]],
  // CHECK-SAME: %arg3, CUDA_R_32F, [[LDA]],
  // CHECK-SAME: %arg2, CUDA_R_32F, [[LDB]], [[BETA]],
  // CHECK-SAME: %arg4, CUDA_R_32F, [[LDC]],
  // CHECK-SAME: CUDA_R_32F, [[ALGO]], %arg0

  "lmhlo_gpu.gemm"(%lhs, %rhs, %view) { dot_dimension_numbers = {
       lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
       rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
       alpha_real = 0.5,
       alpha_imag = 0.0,
       batch_size = 1,
       lhs_stride = 20,
       rhs_stride = 20}
    : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

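// Verifies that a gemm with batch_size > 1 lowers to tfrt_gpu.blas.gemm.batch,
// which additionally takes the per-matrix strides for A, B, and C plus the
// batch count.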
// CHECK:      func @gemm_batch(
// CHECK-SAME:   %arg0: !tfrt.chain,
// CHECK-SAME:   %arg1: !tfrt_gpu.stream,
// CHECK-SAME:   %arg2: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg3: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg4: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @gemm_batch(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output: memref<5x5xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: [[M:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[N:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[K:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[ALPHA:%[0-9]+]] = tfrt.constant.f32 5.000000e-01
  // CHECK: [[LDA:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[LDB:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[BETA:%[0-9]+]] = tfrt.constant.f32 0.000000e+00
  // CHECK: [[LDC:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[ALGO:%[0-9]+]] = tfrt_gpu.blas.gemm.algo CUBLAS_GEMM_DEFAULT
  // CHECK: [[HANDLE:%[0-9]+]] = tfrt_gpu.blas.create %arg1
  // CHECK: [[STRIDEA:%[0-9]+]] = tfrt.constant.i64 20
  // CHECK: [[STRIDEB:%[0-9]+]] = tfrt.constant.i64 20
  // CHECK: [[STRIDEC:%[0-9]+]] = tfrt.constant.i64 25
  // CHECK: [[BATCH:%[0-9]+]] = tfrt.constant.i32 42

  // CHECK: [[CHAIN:%[0-9]+]] = tfrt_gpu.blas.gemm.batch [[HANDLE]],
  // CHECK-SAME: CUBLAS_OP_N, CUBLAS_OP_N, [[M]], [[N]], [[K]], [[ALPHA]],
  // CHECK-SAME: %arg3, CUDA_R_32F, [[LDA]], [[STRIDEA]],
  // CHECK-SAME: %arg2, CUDA_R_32F, [[LDB]], [[STRIDEB]], [[BETA]],
  // CHECK-SAME: %arg4, CUDA_R_32F, [[LDC]], [[STRIDEC]], [[BATCH]],
  // CHECK-SAME: CUDA_R_32F, [[ALGO]], %arg0

  "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
       lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
       rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
       alpha_real = 0.5,
       alpha_imag = 0.0,
       batch_size = 42,
       lhs_stride = 20,
       rhs_stride = 20}
    : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

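// Verifies that lmhlo_gpu.gemm_bias lowers to a plain tfrt_gpu.blas.gemm with
// beta = 1.0 accumulating into the output buffer %arg5. The bias buffer %arg4
// does not appear in the call; presumably the lowering relies on the bias
// having been copied into the output beforehand, which is not checked here.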
// CHECK:      func @gemm_bias(
// CHECK-SAME:   %arg0: !tfrt.chain,
// CHECK-SAME:   %arg1: !tfrt_gpu.stream,
// CHECK-SAME:   %arg2: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg3: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg4: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg5: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
                %bias: memref<5x5xf32>, %output: memref<5x5xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: [[M:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[N:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[K:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[ALPHA:%[0-9]+]] = tfrt.constant.f32 5.000000e-01
  // CHECK: [[LDA:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[LDB:%[0-9]+]] = tfrt.constant.i32 4
  // CHECK: [[BETA:%[0-9]+]] = tfrt.constant.f32 1.000000e+00
  // CHECK: [[LDC:%[0-9]+]] = tfrt.constant.i32 5
  // CHECK: [[ALGO:%[0-9]+]] = tfrt_gpu.blas.gemm.algo CUBLAS_GEMM_DEFAULT
  // CHECK: [[HANDLE:%[0-9]+]] = tfrt_gpu.blas.create %arg1

  // CHECK: [[CHAIN:%[0-9]+]] = tfrt_gpu.blas.gemm [[HANDLE]],
  // CHECK-SAME: CUBLAS_OP_N, CUBLAS_OP_N, [[M]], [[N]], [[K]], [[ALPHA]],
  // CHECK-SAME: %arg3, CUDA_R_32F, [[LDA]],
  // CHECK-SAME: %arg2, CUDA_R_32F, [[LDB]], [[BETA]],
  // CHECK-SAME: %arg5, CUDA_R_32F, [[LDC]],
  // CHECK-SAME: CUDA_R_32F, [[ALGO]], %arg0

  "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = {
       lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
       rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
       alpha_real = 0.5,
       alpha_imag = 0.0,
       beta = 1.0,
       batch_size = 1,
       lhs_stride = 20,
       rhs_stride = 20}
    : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

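// Verifies that a two-operand lmhlo.all_reduce lowers to a pair of chained
// tfrt_gpu.ccl.all_reduce ops on a single handle from xlir.ccl.create,
// followed by one tfrt_gpu.ccl.execute on the stream. The literals 7 and 0
// are presumably the NCCL codes for ncclFloat32 and ncclSum.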
// CHECK:      func @all_reduce(
// CHECK-SAME:   %arg0: !tfrt.chain,
// CHECK-SAME:   %arg1: !tfrt_gpu.stream,
// CHECK-SAME:   %arg2: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg3: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg4: !tfrt_gpu.buffer,
// CHECK-SAME:   %arg5: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @all_reduce(%operand0: memref<2x2xf32>, %operand1: memref<2x2xf32>,
                 %result0: memref<2x2xf32>, %result1: memref<2x2xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: [[CONTEXT:%[0-9]+]] = tfrt_gpu.stream.get_context %arg1
  // CHECK: [[HANDLE:%[0-9]+]] = xlir.ccl.create [[CONTEXT]]
  // CHECK: [[CHAIN1:%[0-9]+]] = tfrt_gpu.ccl.all_reduce [[HANDLE]],
  // CHECK-SAME: %arg2, %arg4, 7, 0, %arg0
  // CHECK: [[CHAIN2:%[0-9]+]] = tfrt_gpu.ccl.all_reduce [[HANDLE]],
  // CHECK-SAME: %arg3, %arg5, 7, 0, [[CHAIN1]]
  // CHECK: [[CHAIN3:%[0-9]+]] = tfrt_gpu.ccl.execute %arg1, [[HANDLE]],
  // CHECK-SAME: [[CHAIN2]]

  "lmhlo.all_reduce"(%operand0, %operand1, %result0, %result1) ( {
      ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
          %0 = mhlo.add %lhs, %rhs : tensor<f32>
          "mhlo.return"(%0) : (tensor<f32>) -> ()
      }) {
          replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
          channel_id = { handle = 5 : i64, type = 2 : i64 },
          constrain_layout = true,
          use_global_device_ids = true
      } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return [[CHAIN3]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

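// Verifies that two consecutive gemms in one function both lower to
// tfrt_gpu.blas.gemm on the same stream; the alpha constants 3.14159274 and
// 2.71828175 distinguish the two calls in the output.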
// CHECK:      func @two_ops(
// CHECK-SAME:   %arg0: !tfrt.chain,
// CHECK-SAME:   %arg1: !tfrt_gpu.stream,
// CHECK-SAME:   %arg2: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @two_ops(%memref: memref<4x4xf32>) {
  // CHECK-NOT: cast
  // CHECK-NOT: async.execute

  // CHECK: tfrt.constant.f32 3.14159274
  // CHECK: tfrt_gpu.blas.gemm
  "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
       lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
       rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
       alpha_real = 3.14159274,
       alpha_imag = 0.0,
       batch_size = 1,
       lhs_stride = 16,
       rhs_stride = 16}
    : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()

  // CHECK: tfrt.constant.f32 2.71828175
  // CHECK: tfrt_gpu.blas.gemm
  "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
       lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
       rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
       alpha_real = 2.71828175,
       alpha_imag = 0.0,
       batch_size = 1,
       lhs_stride = 16,
       rhs_stride = 16}
    : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()

  // CHECK-NOT: cast
  // CHECK: tfrt.return {{.*}} : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}

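// Verifies the stream and event plumbing for pre-existing async.execute
// regions: each region gets its own stream, which first waits on an event
// recorded on its dependency's stream, records a fresh event when its work is
// enqueued, and yields that event as an !async.value. The host-side
// async.await becomes a tfrt_gpu.stream.wait on the function's stream before
// the final gemm.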
// CHECK:      func @async(
// CHECK-SAME:   %arg0: !tfrt.chain,
// CHECK-SAME:   %arg1: !tfrt_gpu.stream,
// CHECK-SAME:   %arg2: !tfrt_gpu.buffer
// CHECK-SAME: ) -> !tfrt.chain
func @async(%memref: memref<4x4xf32>) {
  // CHECK-NOT: cast

  // CHECK: %[[a0:.*]], %[[t0:.*]] = async.execute
  // CHECK-SAME: -> !async.value<!tfrt_gpu.event>
  %a0 = async.execute () {
    // CHECK: %[[e0:.*]] = tfrt_gpu.event.create
    // CHECK: %[[ch0:.*]] = tfrt_gpu.event.record %[[e0]], %arg1
    // CHECK: %[[s0:.*]] = tfrt_gpu.stream.create
    // CHECK: %[[ch1:.*]] = tfrt_gpu.stream.wait %[[s0]], %[[e0]]
    // CHECK: %[[h0:.*]] = tfrt_gpu.blas.create %[[s0]]
    // CHECK: %[[ch2:.*]] = tfrt_gpu.blas.gemm %[[h0]]
    // CHECK-SAME: %[[ch1]]
    "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
         lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
         rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
         lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
         rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
         alpha_real = 0.5,
         alpha_imag = 0.0,
         batch_size = 1,
         lhs_stride = 16,
         rhs_stride = 16}
      : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()
    // CHECK: %[[e1:.*]] = tfrt_gpu.event.create
    // CHECK: %[[ch3:.*]] = tfrt_gpu.event.record %[[e1]], %[[s0]], %[[ch2]]
    // CHECK: async.yield %[[e1]] : !tfrt_gpu.event
    async.yield
  }

  // CHECK: %[[a1:.*]], %[[t1:.*]] = async.execute [%[[a0]]]
  // CHECK-SAME: (%[[t0]] as %[[e2:.*]]:
  // CHECK-SAME: !async.value<!tfrt_gpu.event>) -> !async.value<!tfrt_gpu.event>
  %a1 = async.execute [%a0] () {
    // CHECK: %[[s1:.*]] = tfrt_gpu.stream.create
    // CHECK: %[[ch4:.*]] = tfrt_gpu.stream.wait %[[s1]], %[[e2]]
    // CHECK: %[[h1:.*]] = tfrt_gpu.blas.create %[[s1]]
    // CHECK: %[[ch5:.*]] = tfrt_gpu.blas.gemm %[[h1]]
    // CHECK-SAME: %[[ch4]]
    "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
         lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
         rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
         lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
         rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
         alpha_real = 0.5,
         alpha_imag = 0.0,
         batch_size = 1,
         lhs_stride = 16,
         rhs_stride = 16}
      : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()
    // CHECK: %[[e3:.*]] = tfrt_gpu.event.create
    // CHECK: %[[ch6:.*]] = tfrt_gpu.event.record %[[e3]], %[[s1]], %[[ch5]]
    // CHECK: async.yield %[[e3]] : !tfrt_gpu.event
    async.yield
  }

  // CHECK: async.await %[[a1]] : !async.token
  // CHECK: %[[e4:.*]] = async.await %[[t1]] : !async.value<!tfrt_gpu.event>
  // CHECK: %[[ch7:.*]] = tfrt_gpu.stream.wait %arg1, %[[e4]]
  async.await %a1 : !async.token

  // CHECK: %[[h2:.*]] = tfrt_gpu.blas.create %arg1
  // CHECK: %[[ch8:.*]] = tfrt_gpu.blas.gemm %[[h2]]
  // CHECK-SAME: %[[ch7]]
  "lmhlo_gpu.gemm"(%memref, %memref, %memref) { dot_dimension_numbers = {
       lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
       lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
       rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
       alpha_real = 0.5,
       alpha_imag = 0.0,
       batch_size = 1,
       lhs_stride = 16,
       rhs_stride = 16}
    : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()

  // CHECK: tfrt.return %[[ch8]] : !tfrt.chain
  "lmhlo.terminator"() : () -> ()
}