/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {

namespace {

namespace m = ::xla::match;

class GemmRewriteTest : public GpuCodegenTest {
 public:
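  // Compiles `hlo` down to a GpuExecutable and asserts that buffer assignment
  // produced exactly `expected_number_of_allocations` allocations.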
  void CheckNumberOfAllocations(const std::string& hlo,
                                int expected_number_of_allocations) {
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
                            GetOptimizedModule(hlo));
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<Executable> executable,
        backend().compiler()->RunBackend(
            std::move(optimized_module), backend().default_stream_executor(),
            backend().default_stream_executor()->GetAllocator()));
    GpuExecutable* gpu_executable =
        static_cast<GpuExecutable*>(executable.get());
    absl::Span<const BufferAllocation> allocations =
        gpu_executable->GetAllocations();
    CHECK_EQ(allocations.size(), expected_number_of_allocations);
  }

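  // Returns the CUDA compute capability of the default device; several tests
  // below branch on it (e.g. Volta vs. Ampere code paths).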
  se::CudaComputeCapability GetCudaComputeCapability() {
    return backend()
        .default_stream_executor()
        ->GetDeviceDescription()
        .cuda_compute_capability();
  }
};

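// A plain f32 dot should be rewritten into a single cuBLAS GEMM custom call
// (__cublas$gemm) with alpha=1 and beta=0.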
TEST_F(GemmRewriteTest, SimpleRewrite) {
  const char* hlo_text = R"(
HloModule SimpleGemm

ENTRY AddDotsFunc {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

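// A batched dot with multiple non-contracting dimensions should still go
// through gemm autotuning, i.e. end up with a selected_algorithm in the
// backend config.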
TEST_F(GemmRewriteTest, TestBatchedAutotuning) {
  const char* hlo_text = R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
}

)";

  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK: selected_algorithm
)");
}

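// With --xla_gpu_deterministic_ops, autotuning is skipped and a fixed
// algorithm is chosen: the cuBLAS default (-1), or se::blas::kNoAlgorithm
// (-4) when cublasLt is enabled.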
TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) {
  const char* hlo_text = R"(
HloModule SimpleGemm

ENTRY AddDotsFunc {
  x = f32[128,128] parameter(0)
  y = f32[128,128] parameter(1)
  ROOT dot_a = f32[128,128] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  auto get_module = [&]() -> StatusOr<std::unique_ptr<HloModule>> {
    HloModuleConfig config;
    DebugOptions debug_options = GetDebugOptionsForTest();
    debug_options.set_xla_gpu_deterministic_ops(true);
    config.set_debug_options(debug_options);
    return ParseAndReturnVerifiedModule(hlo_text, config);
  };

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> optimized_module,
      backend().compiler()->RunHloPasses(
          *get_module(), backend().default_stream_executor(),
          backend().default_stream_executor()->GetAllocator()));
  DebugOptions debug_options = GetDebugOptionsForTest();
  if (!debug_options.xla_gpu_enable_cublaslt()) {
    StatusOr<bool> filecheck_result_cublas =
        RunFileCheck(optimized_module->ToString(),
                     R"(
; CHECK: \"selected_algorithm\":\"-1\"
)");
    TF_ASSERT_OK(filecheck_result_cublas.status());
    EXPECT_TRUE(filecheck_result_cublas.ValueOrDie());
    EXPECT_TRUE(RunAndCompare(*get_module(), ErrorSpec{1e-5, 1e-5}));
  } else {
    // With cublasLt enabled, selected_algorithm is se::blas::kNoAlgorithm (-4).
    StatusOr<bool> filecheck_result_cublas =
        RunFileCheck(optimized_module->ToString(),
                     R"(
; CHECK: \"selected_algorithm\":\"-4\"
)");
    TF_ASSERT_OK(filecheck_result_cublas.status());
    EXPECT_TRUE(filecheck_result_cublas.ValueOrDie());
    EXPECT_TRUE(RunAndCompare(*get_module(), ErrorSpec{1e-3, 1e-5}));
  }
}

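// When the contracting dimensions of each operand are adjacent, the rewriter
// collapses them with bitcasts instead of inserting copies, so the gemm
// contracts a single folded dimension.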
TEST_F(GemmRewriteTest, MultipleContractingDims) {
  const char* hlo_text = R"(
HloModule MultipleContractingCheckGemm

ENTRY AddDotsFunc {
  x = f32[3,4,2] parameter(0)
  y = f32[3,4,5] parameter(1)
  ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK-NOT: copy
;
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,4,2], y: f32[3,4,5]) -> f32[2,5] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0)
; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1)
; CHECK-DAG: [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]])
; CHECK-DAG: [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]])
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,5]{1,0} custom-call([[BITCAST0]], [[BITCAST1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

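// A transpose feeding the dot should be folded into the gemm's dimension
// numbers rather than materialized; here the lhs contracting dimension flips
// from 1 to 0.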
TEST_F(GemmRewriteTest, ArgTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule ArgTransposeFoldGemm

ENTRY AddDotsFunc {
  x = f32[3,2] parameter(0)
  y = f32[3,4] parameter(1)
  x_transposed = f32[2,3] transpose(x), dimensions={1, 0}
  ROOT dot_a = f32[2,4] dot(x_transposed, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2], y: f32[3,4]) -> f32[2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, BatchedArgRowColTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule BatchedArgRowColTransposeFoldGemm

ENTRY AddDotsFunc {
  x = f32[5,3,2] parameter(0)
  y = f32[5,3,4] parameter(1)
  x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1}
  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,3,2], y: f32[5,3,4]) -> f32[5,2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, BatchRowTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule BatchRowTransposeFoldCheck

ENTRY AddDotsFunc {
  x = f32[2,5,3] parameter(0)
  y = f32[5,3,4] parameter(1)
  x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2}
  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"1\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

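// If folding the transpose would require taking the batch dimension from the
// most-minor input dimension, the transpose is kept (here as a fusion)
// instead of being folded into the gemm's dimension numbers.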
TEST_F(GemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) {
  const char* hlo_text = R"(
HloModule BatchFromMinorDimTransposeDoesntFold

ENTRY AddDotsFunc {
  x = f32[3,2,5] parameter(0)
  y = f32[5,3,4] parameter(1)
  x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0}
  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
; CHECK-DAG: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} fusion([[P0]])
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[FUSION]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, InstrTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule InstrTransposeFoldGemm

ENTRY AddDotsFunc {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[4,2] {
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} custom-call([[P1]], [[P0]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, BatchedInstrLayoutTransposed) {
  const char* hlo_text = R"(
HloModule BatchedInstrLayoutCheck

ENTRY AddDotsFunc {
  x = f32[5,2,3] parameter(0)
  y = f32[5,3,4] parameter(1)
  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
  ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[5,2,4]{2,0,1} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast([[GEMM]])
)");
}

TEST_F(GemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) {
  const char* hlo_text = R"(
HloModule BatchedInstrLayoutBatchNotInMinorDim

ENTRY AddDotsFunc {
  x = f32[5,2,3] parameter(0)
  y = f32[5,3,4] parameter(1)
  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
  ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]([[GEMM]])
)");
}

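// Multiplying the dot result by a broadcast scalar constant should be folded
// into the gemm's alpha_real instead of remaining a separate multiply.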
TEST_F(GemmRewriteTest, AlphaSimpleRewrite) {
  const char* hlo_text = R"(
HloModule AlphaSimpleRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":3
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, ComplexAlphaSimpleRewrite) {
  const char* hlo_text = R"(
HloModule ComplexAlphaSimpleRewrite

ENTRY AddDotsFunc {
  x = c64[2,2] parameter(0)
  y = c64[2,2] parameter(1)
  k = c64[] constant((3.0, 3.0))
  k_broadcast = c64[2, 2] broadcast(k), dimensions={}
  dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
}

)";

  DebugOptions debug_options = GetDebugOptionsForTest();
  if (!debug_options.xla_gpu_enable_cublaslt()) {
    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  } else {
    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-2}));
  }
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: c64[2,2], y: c64[2,2]) -> c64[2,2] {
; CHECK-NEXT: [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = c64[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":3
; CHECK-DAG: \"alpha_imag\":3
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

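// If the dot has another user besides the multiply, the scalar cannot be
// folded into alpha; the gemm keeps alpha_real=1.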
TEST_F(GemmRewriteTest, AlphaMultipleUsersNoRewrite) {
  const char* hlo_text = R"(
HloModule AlphaMultipleUsersNoRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
  ROOT out = f32[2,2] add(dot_a_multiplied, dot_a)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK: [[C:%[^ ]+]] = f32[2,2]{1,0} custom-call([[A:%[^ ]+]], [[B:%[^ ]+]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, AlphaVectorNoRewrite) {
  const char* hlo_text = R"(
HloModule AlphaVectorNoRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  alpha = f32[2] constant({1, 2})
  alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1}
  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT: [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

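// With a scalar multiplier and a matrix bias, both are folded into the gemm:
// alpha_real=3, beta=1, and the output aliases a copy of the bias operand via
// output_to_operand_aliasing.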
TEST_F(GemmRewriteTest, AlphaBetaRewrite) {
  const char* hlo_text = R"(
HloModule NonZeroAlphaBeta

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] parameter(2)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
; CHECK-NEXT: [[P2_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[P2]])
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[P2_COPY]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}},
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":3
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":1
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, BiasMultipleUsersNoRewrite) {
  const char* hlo_text = R"(
HloModule BiasMultipleUsersNoRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] parameter(2)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
  biased_out = f32[2,2] add(dot_a_multiplied, bias)
  ROOT out = f32[2,2] add(biased_out, bias)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT: [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":3
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(GemmRewriteTest, SharedBufferAssignment) {
  const char* hlo_text = R"(
HloModule SharedBufferAssignment

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] add(x, y)
  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[2,2] add(dot, bias)
}

)";

  // The bias add should be fused into the gemm custom call, with the output
  // aliasing the bias buffer, so only three allocations remain: the two
  // parameters plus one shared bias/output buffer.
  CheckNumberOfAllocations(hlo_text, 3);
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

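// On Ampere and newer, bf16 gemm operands are padded so each dimension is a
// multiple of 8 (12x4 * 4x8 becomes 16x8 * 8x8); on older GPUs the original
// shapes are passed to cuBLAS unchanged.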
TEST_F(GemmRewriteTest, BF16Gemm) {
  const char* hlo_text = R"(
HloModule bf16gemm

ENTRY bf16gemm {
  %parameter.1 = bf16[12,4]{1,0} parameter(0)
  %parameter.2 = bf16[4,8]{1,0} parameter(1)
  ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: bf16[16,8]{1,0} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="__cublas$gemm"
)",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: bf16[12,8]{1,0} custom-call(bf16[12,4]{1,0} [[P0:%[^ ]+]], bf16[4,8]{1,0} [[P1:%[^ ]+]]), custom_call_target="__cublas$gemm"

)",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, BF16GemmStrided) {
  const char* hlo_text = R"(
HloModule bf16gemm

ENTRY bf16gemm {
  %parameter.1 = bf16[3,3,4] parameter(0)
  %parameter.2 = bf16[3,3,2] parameter(1)
  ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest}
}

)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: bf16[3,8,8]{2,1,0} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="__cublas$gemm"
)",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: ROOT [[OUT:%[^ ]+]] = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} [[A:%[^ ]+]], bf16[3,3,2]{2,1,0} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
)",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, BF16GemmCodeGen) {
  const char* hlo_text = R"(
HloModule bf16codegendgemm

ENTRY bf16gemm {
  %parameter.1 = bf16[3]{0} parameter(0)
  %parameter.2 = bf16[3]{0} parameter(1)
  ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest}
}
)";

  MatchOptimizedHlo(hlo_text, R"(
; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
)");

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

TEST_F(GemmRewriteTest, BF16Transpose) {
  const char* hlo_text = R"(
HloModule broadcast

ENTRY broadcast {
  p = bf16[9] parameter(0)
  ROOT out = bf16[1,9] broadcast(p), dimensions={1}
}
)";

  MatchOptimizedHlo(hlo_text, R"(
; CHECK: bf16[1,9]{1,0} bitcast
; CHECK: bf16[1,9]{1,0} copy
)");

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

TEST_F(GemmRewriteTest, Int8Gemm) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[12,4]{1,0} parameter(0)
  %parameter.2 = s8[4,8]{1,0} parameter(1)
  ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
)",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

)",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, Int8GemmNoAlphaRewrite) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[12,4]{1,0} parameter(0)
  %parameter.2 = s8[4,8]{1,0} parameter(1)
  k = s32[] constant(2)
  k_broadcast = s32[12,8] broadcast(k), dimensions={}
  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_multiplied = s32[12,8] multiply(%dot.8, k_broadcast)
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
)",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

)",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, Int8GemmNoBetaRewrite) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[12,4]{1,0} parameter(0)
  %parameter.2 = s8[4,8]{1,0} parameter(1)
  bias = s32[12,8] parameter(2)
  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = s32[12,8] add(%dot.8, bias)
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
; CHECK: custom_call_target="__cublas$gemm",
; CHECK: backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
)",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

)",
                      /*print_operand_shape=*/true);
  }
}

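// cuBLAS int8 gemms require dimensions that are multiples of four, so the
// 13x4 * 4x9 dot is padded up to 16x4 * 4x12 before the custom call on Volta
// and newer.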
TEST_F(GemmRewriteTest, Int8GemmNotMultipleOfFour) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[13,4]{1,0} parameter(0)
  %parameter.2 = s8[4,9]{1,0} parameter(1)
  ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[16,12]{1,0} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
)",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[13,9]{1,0} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

)",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, BF16GemmWithBias) {
  const char* hlo_text = R"(
HloModule BF16GemmWithBias

ENTRY BF16GemmWithBias {
  x = bf16[8,8]{1,0} parameter(0)
  y = bf16[8,8]{1,0} parameter(1)
  dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  bias = bf16[8,8]{1,0} parameter(2)
  ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], bias: bf16[8,8]) -> bf16[8,8] {
; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
; CHECK-NEXT: [[P2_COPY:%[^ ]+]] = bf16[8,8]{1,0} copy([[P2]])
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2_COPY]]), custom_call_target="__cublas$gemm", output_to_operand_aliasing={{{{}: \(2, {}\)}}}, backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":1
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

#if GOOGLE_CUDA

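// Variant of the rewrite tests with xla_gpu_enable_cublaslt turned on, so
// dots are lowered to __cublas$lt$matmul custom calls (which support fused
// epilogues) instead of __cublas$gemm.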
class CublasLtMatmulRewriteTest : public GpuCodegenTest {
  DebugOptions GetDebugOptionsForTest() override {
    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
    debug_options.set_xla_gpu_enable_cublaslt(true);
    return debug_options;
  }
};

TEST_F(CublasLtMatmulRewriteTest, Simple) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(CublasLtMatmulRewriteTest, MatrixBias) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[2,4] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[2,4] add(dot_a, z)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":1
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

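// A bias broadcast along the last output dimension can be passed to cublasLt
// as a vector operand and fused via the BIAS epilogue instead of becoming a
// matrix bias with beta=1.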
TEST_F(CublasLtMatmulRewriteTest, VectorBias) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[4] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={1}
  ROOT out = f32[2,4] add(dot_a, z_bcast)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"BIAS\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(CublasLtMatmulRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[2] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={0}
  ROOT out = f32[2,4] add(dot_a, z_bcast)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[2,4] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
; CHECK-NEXT: [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} broadcast([[P2]]), dimensions={0}
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2_BCAST]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":1
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

TEST_F(CublasLtMatmulRewriteTest, VectorBiasTransposed) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[2] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={0}
  add = f32[2,4] add(dot_a, z_bcast)
  ROOT out = f32[4,2] transpose(add), dimensions={1,0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"BIAS\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
)");
}

TEST_F(CublasLtMatmulRewriteTest, VectorBiasThenMatrixBias) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[4] parameter(2)
  z2 = f32[2,4] parameter(3)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={1}
  add0 = f32[2,4] add(dot_a, z_bcast)
  ROOT add1 = f32[2,4] add(add0, z2)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4], z2: f32[2,4]) -> f32[2,4] {
; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-DAG: [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P3]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":1
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"BIAS\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

#endif  // GOOGLE_CUDA

using GemmRewriteHloTest = HloTestBase;

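// The tests below run the GemmRewriter pass in isolation (without the full
// GPU compilation pipeline) and inspect the resulting HLO with the pattern
// matcher.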
TEST_F(GemmRewriteHloTest, MergeBitcastAndAdd) {
  const char* hlo_text = R"(
HloModule test
ENTRY test {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[4] parameter(2)
  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[4] add(f32[4] bitcast(dot), bias)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GemmRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
  EXPECT_TRUE(changed);

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(
          m::Bitcast(
              m::CustomCall("__cublas$gemm", m::Parameter(0), m::Parameter(1),
                            m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})))
              .WithShape(F32, {4})));
}

TEST_F(GemmRewriteHloTest, FoldConstantBias) {
  const char* hlo_text = R"(
HloModule test
ENTRY test {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}

  dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  bias1 = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
  sum1 = add(dot1, bias1)

  dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  sum2 = add(dot2, f32[2,2] reshape(bias))

  dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  bias3 = f32[2,2] transpose(bias), dimensions={1,0}
  sum3 = add(dot3, bias3)

  dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  sum4 = add(dot4, f32[2,2] bitcast(bias))

  ROOT root = tuple(sum1, sum2, sum3, sum4)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GemmRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
  SCOPED_TRACE(module->ToString());
  EXPECT_TRUE(changed);

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()))));
}

}  // namespace
}  // namespace gpu
}  // namespace xla