/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {

namespace {

namespace m = ::xla::match;

class GemmRewriteTest : public GpuCodegenTest {
 public:
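  // Compiles `hlo` end to end and checks that the resulting GpuExecutable
  // uses exactly the expected number of buffer allocations.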
  void CheckNumberOfAllocations(const std::string& hlo,
                                int expected_number_of_allocations) {
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
                            GetOptimizedModule(hlo));
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<Executable> executable,
        backend().compiler()->RunBackend(
            std::move(optimized_module), backend().default_stream_executor(),
            backend().default_stream_executor()->GetAllocator()));
    GpuExecutable* gpu_executable =
        static_cast<GpuExecutable*>(executable.get());
    absl::Span<const BufferAllocation> allocations =
        gpu_executable->GetAllocations();
    CHECK_EQ(allocations.size(), expected_number_of_allocations);
  }

  se::CudaComputeCapability GetCudaComputeCapability() {
    return backend()
        .default_stream_executor()
        ->GetDeviceDescription()
        .cuda_compute_capability();
  }
};

TEST_F(GemmRewriteTest, SimpleRewrite) {
  const char* hlo_text = R"(
HloModule SimpleGemm

ENTRY AddDotsFunc {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, TestBatchedAutotuning) {
  const char* hlo_text = R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
}

)";

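  // Batched GEMMs should be autotuned like any other: the rewritten custom
  // call must carry a selected_algorithm in its backend config.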
  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK: selected_algorithm
      )");
}

TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) {
  const char* hlo_text = R"(
HloModule SimpleGemm

ENTRY AddDotsFunc {
  x = f32[128,128] parameter(0)
  y = f32[128,128] parameter(1)
  ROOT dot_a = f32[128,128] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  auto get_module = [&]() -> StatusOr<std::unique_ptr<HloModule>> {
    HloModuleConfig config;
    DebugOptions debug_options = GetDebugOptionsForTest();
    debug_options.set_xla_gpu_deterministic_ops(true);
    config.set_debug_options(debug_options);
    return ParseAndReturnVerifiedModule(hlo_text, config);
  };

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> optimized_module,
      backend().compiler()->RunHloPasses(
          *get_module(), backend().default_stream_executor(),
          backend().default_stream_executor()->GetAllocator()));
  DebugOptions debug_options = GetDebugOptionsForTest();
  if (!debug_options.xla_gpu_enable_cublaslt()) {
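    // With deterministic ops, no algorithm is autotuned; -1 denotes the
    // default cuBLAS algorithm.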
    StatusOr<bool> filecheck_result_cublas =
        RunFileCheck(optimized_module->ToString(),
                     R"(
; CHECK:    \"selected_algorithm\":\"-1\"
      )");
    TF_ASSERT_OK(filecheck_result_cublas.status());
    EXPECT_TRUE(filecheck_result_cublas.ValueOrDie());
    EXPECT_TRUE(RunAndCompare(*get_module(), ErrorSpec{1e-5, 1e-5}));
  } else {
    // With cublaslt enabled, selected_algorithm is se::blas::kNoAlgorithm.
    StatusOr<bool> filecheck_result_cublas =
        RunFileCheck(optimized_module->ToString(),
                     R"(
; CHECK:    \"selected_algorithm\":\"-4\"
      )");
    TF_ASSERT_OK(filecheck_result_cublas.status());
    EXPECT_TRUE(filecheck_result_cublas.ValueOrDie());
    EXPECT_TRUE(RunAndCompare(*get_module(), ErrorSpec{1e-3, 1e-5}));
  }
}

TEST_F(GemmRewriteTest, MultipleContractingDims) {
  const char* hlo_text = R"(
HloModule MultipleContractingCheckGemm

ENTRY AddDotsFunc {
  x = f32[3,4,2] parameter(0)
  y = f32[3,4,5] parameter(1)
  ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
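  // Both contracting dimensions should be collapsed into one via bitcasts
  // rather than copies (hence the CHECK-NOT below).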
  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK-NOT:     copy
;
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,4,2], y: f32[3,4,5]) -> f32[2,5] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0)
; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1)
; CHECK-DAG:     [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]])
; CHECK-DAG:     [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]])
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,5]{1,0} custom-call([[BITCAST0]], [[BITCAST1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, ArgTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule ArgTransposeFoldGemm

ENTRY AddDotsFunc {
  x = f32[3,2] parameter(0)
  y = f32[3,4] parameter(1)
  x_transposed = f32[2,3] transpose(x), dimensions={1, 0}
  ROOT dot_a = f32[2,4] dot(x_transposed, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2], y: f32[3,4]) -> f32[2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, BatchedArgRowColTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule BatchedArgRowColTransposeFoldGemm

ENTRY AddDotsFunc {
  x = f32[5,3,2] parameter(0)
  y = f32[5,3,4] parameter(1)
  x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1}
  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,3,2], y: f32[5,3,4]) -> f32[5,2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, BatchRowTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule BatchRowTransposeFoldCheck

ENTRY AddDotsFunc {
  x = f32[2,5,3] parameter(0)
  y = f32[5,3,4] parameter(1)
  x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2}
  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"1\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) {
  const char* hlo_text = R"(
HloModule BatchFromMinorDimTransposeDoesntFold

ENTRY AddDotsFunc {
  x = f32[3,2,5] parameter(0)
  y = f32[5,3,4] parameter(1)
  x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0}
  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
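  // The transpose pulls the batch dimension out of the minor-most input
  // dimension, so it cannot be folded into the GEMM and remains as a fusion.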
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
; CHECK-DAG:     [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-DAG:     [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} fusion([[P0]])
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[FUSION]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, InstrTransposeFoldCheck) {
  const char* hlo_text = R"(
HloModule InstrTransposeFoldGemm

ENTRY AddDotsFunc {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
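  // Transposing the result is folded by swapping the operands and their
  // contracting dimensions, since (AB)^T = B^T A^T.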
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[4,2] {
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} custom-call([[P1]], [[P0]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, BatchedInstrLayoutTransposed) {
  const char* hlo_text = R"(
HloModule BatchedInstrLayoutCheck

ENTRY AddDotsFunc {
  x = f32[5,2,3] parameter(0)
  y = f32[5,3,4] parameter(1)
  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
  ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[5,2,4]{2,0,1} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast([[GEMM]])
      )");
}

TEST_F(GemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) {
  const char* hlo_text = R"(
HloModule BatchedInstrLayoutBatchNotInMinorDim

ENTRY AddDotsFunc {
  x = f32[5,2,3] parameter(0)
  y = f32[5,3,4] parameter(1)
  dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
  ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]([[GEMM]])
      )");
}

TEST_F(GemmRewriteTest, AlphaSimpleRewrite) {
  const char* hlo_text = R"(
HloModule AlphaSimpleRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
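  // The scalar multiplier on the dot result is folded into alpha_real.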
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":3
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, ComplexAlphaSimpleRewrite) {
  const char* hlo_text = R"(
HloModule ComplexAlphaSimpleRewrite

ENTRY AddDotsFunc {
  x = c64[2,2] parameter(0)
  y = c64[2,2] parameter(1)
  k = c64[] constant((3.0, 3.0))
  k_broadcast = c64[2, 2] broadcast(k), dimensions={}
  dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
}

)";

  DebugOptions debug_options = GetDebugOptionsForTest();
  if (!debug_options.xla_gpu_enable_cublaslt()) {
    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  } else {
    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-2}));
  }
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: c64[2,2], y: c64[2,2]) -> c64[2,2] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = c64[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":3
; CHECK-DAG:       \"alpha_imag\":3
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, AlphaMultipleUsersNoRewrite) {
  const char* hlo_text = R"(
HloModule AlphaMultipleUsersNoRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
  ROOT out = f32[2,2] add(dot_a_multiplied, dot_a)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
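  // dot_a has a second user, so the multiplier is not folded: alpha stays 1
  // and the multiply remains outside the custom call.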
  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK:    [[C:%[^ ]+]] = f32[2,2]{1,0} custom-call([[A:%[^ ]+]], [[B:%[^ ]+]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, AlphaVectorNoRewrite) {
  const char* hlo_text = R"(
HloModule AlphaVectorNoRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  alpha = f32[2] constant({1, 2})
  alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1}
  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
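  // A per-column (vector) multiplier cannot be expressed by the scalar alpha,
  // so it stays outside the custom call.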
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT:    [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, AlphaBetaRewrite) {
  const char* hlo_text = R"(
HloModule NonZeroAlphaBeta

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] parameter(2)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
  ROOT out = f32[2,2] add(dot_a_multiplied, bias)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
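  // The additive bias is fused with beta=1; the bias is copied first because
  // the custom call writes its output in place over that operand.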
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
; CHECK-NEXT:    [[P2_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[P2]])
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[P2_COPY]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           output_to_operand_aliasing={{{{}: \(2, {}\)}}},
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":3
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":1
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, BiasMultipleUsersNoRewrite) {
  const char* hlo_text = R"(
HloModule BiasMultipleUsersNoRewrite

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] parameter(2)
  k = f32[] constant(3.0)
  k_broadcast = f32[2, 2] broadcast(k), dimensions={}
  dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast)
  biased_out = f32[2,2] add(dot_a_multiplied, bias)
  ROOT out = f32[2,2] add(biased_out, bias)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
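  // The bias has another user, so it cannot be overwritten in place: beta
  // stays 0 and the adds remain outside the custom call.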
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT:    [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":3
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(GemmRewriteTest, SharedBufferAssignment) {
  const char* hlo_text = R"(
HloModule SharedBufferAssignment

ENTRY AddDotsFunc {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] add(x, y)
  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[2,2] add(dot, bias)
}

)";

  // The bias add should be fused into the GEMM custom call, with the bias
  // buffer shared with the output, so only three allocations remain.
  CheckNumberOfAllocations(hlo_text, 3);
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

TEST_F(GemmRewriteTest, BF16Gemm) {
  const char* hlo_text = R"(
HloModule bf16gemm

ENTRY bf16gemm {
  %parameter.1 = bf16[12,4]{1,0} parameter(0)
  %parameter.2 = bf16[4,8]{1,0} parameter(1)
  ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
  )";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) {
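    // On Ampere and newer, bf16 operands are padded to multiples of 8:
    // [12,4] x [4,8] becomes [16,8] x [8,8].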
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: bf16[16,8]{1,0} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="__cublas$gemm"
  )",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: bf16[12,8]{1,0} custom-call(bf16[12,4]{1,0} [[P0:%[^ ]+]], bf16[4,8]{1,0} [[P1:%[^ ]+]]), custom_call_target="__cublas$gemm"

  )",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, BF16GemmStrided) {
  const char* hlo_text = R"(
HloModule bf16gemm

ENTRY bf16gemm {
  %parameter.1 = bf16[3,3,4] parameter(0)
  %parameter.2 = bf16[3,3,2] parameter(1)
  ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest}
}

  )";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
    ; CHECK: bf16[3,8,8]{2,1,0} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="__cublas$gemm"
    )",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
    ; CHECK: ROOT [[OUT:%[^ ]+]] = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} [[A:%[^ ]+]], bf16[3,3,2]{2,1,0} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
    )",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, BF16GemmCodeGen) {
  const char* hlo_text = R"(
HloModule bf16codegendgemm

ENTRY bf16gemm {
  %parameter.1 = bf16[3]{0} parameter(0)
  %parameter.2 = bf16[3]{0} parameter(1)
  ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest}
}
  )";

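  // A dot of two bf16 vectors is not rewritten to a cuBLAS call; it is
  // lowered to convert/multiply/reduce instead.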
  MatchOptimizedHlo(hlo_text, R"(
; CHECK:  [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
; CHECK:  [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
; CHECK:  [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
; CHECK:  [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
; CHECK:  [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
; CHECK:  [[INSTR_5:%[^ ]+]] = f32[] constant(0)
; CHECK:  [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
; CHECK:  ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
  )");

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

TEST_F(GemmRewriteTest, BF16Transpose) {
  const char* hlo_text = R"(
HloModule broadcast

ENTRY broadcast {
  p = bf16[9] parameter(0)
  ROOT out = bf16[1,9] broadcast(p), dimensions={1}
}
)";

  MatchOptimizedHlo(hlo_text, R"(
; CHECK: bf16[1,9]{1,0} bitcast
; CHECK: bf16[1,9]{1,0} copy
)");

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

TEST_F(GemmRewriteTest, Int8Gemm) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[12,4]{1,0} parameter(0)
  %parameter.2 = s8[4,8]{1,0} parameter(1)
  ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
  )";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
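    // On Volta and newer, the s8xs8->s32 dot becomes a cuBLAS custom call;
    // note the column-major {0,1} layout imposed on the RHS.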
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
  )",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

  )",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, Int8GemmNoAlphaRewrite) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[12,4]{1,0} parameter(0)
  %parameter.2 = s8[4,8]{1,0} parameter(1)
  k = s32[] constant(2)
  k_broadcast = s32[12,8] broadcast(k), dimensions={}
  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT dot_multiplied = s32[12,8] multiply(%dot.8, k_broadcast)
}
  )";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
  )",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

  )",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, Int8GemmNoBetaRewrite) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[12,4]{1,0} parameter(0)
  %parameter.2 = s8[4,8]{1,0} parameter(1)
  bias = s32[12,8] parameter(2)
  %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = s32[12,8] add(%dot.8, bias)
}
  )";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
; CHECK:           custom_call_target="__cublas$gemm",
; CHECK:           backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
  )",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

  )",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, Int8GemmNotMultipleOfFour) {
  const char* hlo_text = R"(
HloModule int8gemm

ENTRY int8gemm {
  %parameter.1 = s8[13,4]{1,0} parameter(0)
  %parameter.2 = s8[4,9]{1,0} parameter(1)
  ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
  )";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
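    // cuBLAS int8 GEMMs require dimensions that are multiples of 4, so the
    // operands are padded: [13,4] x [4,9] becomes [16,4] x [4,12].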
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[16,12]{1,0} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
  )",
                      /*print_operand_shape=*/true);
  } else {
    MatchOptimizedHlo(hlo_text,
                      R"(
; CHECK: s32[13,9]{1,0} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}

  )",
                      /*print_operand_shape=*/true);
  }
}

TEST_F(GemmRewriteTest, BF16GemmWithBias) {
  const char* hlo_text = R"(
HloModule BF16GemmWithBias

ENTRY BF16GemmWithBias {
  x = bf16[8,8]{1,0} parameter(0)
  y = bf16[8,8]{1,0} parameter(1)
  dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  bias = bf16[8,8]{1,0} parameter(2)
  ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
}
  )";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
  MatchOptimizedHlo(hlo_text,
                    R"(
; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], bias: bf16[8,8]) -> bf16[8,8] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
; CHECK-NEXT:    [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
; CHECK-NEXT:    [[P2_COPY:%[^ ]+]] = bf16[8,8]{1,0} copy([[P2]])
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2_COPY]]), custom_call_target="__cublas$gemm", output_to_operand_aliasing={{{{}: \(2, {}\)}}}, backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":1
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

#if GOOGLE_CUDA

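// Tests that GEMMs are rewritten to the cublasLt custom call
// (__cublas$lt$matmul) when xla_gpu_enable_cublaslt is set.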
class CublasLtMatmulRewriteTest : public GpuCodegenTest {
  DebugOptions GetDebugOptionsForTest() override {
    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
    debug_options.set_xla_gpu_enable_cublaslt(true);
    return debug_options;
  }
};

TEST_F(CublasLtMatmulRewriteTest, Simple) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(CublasLtMatmulRewriteTest, MatrixBias) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[2,4] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[2,4] add(dot_a, z)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":1
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(CublasLtMatmulRewriteTest, VectorBias) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[4] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={1}
  ROOT out = f32[2,4] add(dot_a, z_bcast)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
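  // A vector bias broadcast along dimension 1 is fused as the cublasLt BIAS
  // epilogue.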
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"BIAS\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(CublasLtMatmulRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[2] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={0}
  ROOT out = f32[2,4] add(dot_a, z_bcast)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
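  // A bias broadcast along dimension 0 does not match the BIAS epilogue, so
  // it is fused as a matrix bias (beta=1) instead.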
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[2,4] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
; CHECK-NEXT:    [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} broadcast([[P2]]), dimensions={0}
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2_BCAST]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":1
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

TEST_F(CublasLtMatmulRewriteTest, VectorBiasTransposed) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[2] parameter(2)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={0}
  add = f32[2,4] add(dot_a, z_bcast)
  ROOT out = f32[4,2] transpose(add), dimensions={1,0}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] {
; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[2]{0} parameter(2)
; CHECK-NEXT:    [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"BIAS\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]])
      )");
}

TEST_F(CublasLtMatmulRewriteTest, VectorBiasThenMatrixBias) {
  const char* hlo_text = R"(
HloModule test

ENTRY test {
  x = f32[2,3] parameter(0)
  y = f32[3,4] parameter(1)
  z = f32[4] parameter(2)
  z2 = f32[2,4] parameter(3)
  dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  z_bcast = f32[2,4] broadcast(z), dimensions={1}
  add0 = f32[2,4] add(dot_a, z_bcast)
  ROOT add1 = f32[2,4] add(add0, z2)
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4], z2: f32[2,4]) -> f32[2,4] {
; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
; CHECK-DAG:     [[P2:%[^ ]+]] = f32[4]{0} parameter(2)
; CHECK-DAG:     [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3)
; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P3]], [[P2]]), custom_call_target="__cublas$lt$matmul", backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":1
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"BIAS\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}

#endif  // GOOGLE_CUDA

using GemmRewriteHloTest = HloTestBase;

TEST_F(GemmRewriteHloTest, MergeBitcastAndAdd) {
  const char* hlo_text = R"(
HloModule test
ENTRY test {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[4] parameter(2)
  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT out = f32[4] add(f32[4] bitcast(dot), bias)
}
)";

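  // The bitcast between the dot and the add is merged into the custom call:
  // the bias is bitcast to the dot's shape and the result bitcast back.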
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GemmRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
  EXPECT_TRUE(changed);

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(
          m::Bitcast(
              m::CustomCall("__cublas$gemm", m::Parameter(0), m::Parameter(1),
                            m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})))
              .WithShape(F32, {4})));
}

TEST_F(GemmRewriteHloTest, FoldConstantBias) {
  const char* hlo_text = R"(
HloModule test
ENTRY test {
  x = f32[2,2] parameter(0)
  y = f32[2,2] parameter(1)
  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}

  dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  bias1 = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
  sum1 = add(dot1, bias1)

  dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  sum2 = add(dot2, f32[2,2] reshape(bias))

  dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  bias3 = f32[2,2] transpose(bias), dimensions={1,0}
  sum3 = add(dot3, bias3)

  dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  sum4 = add(dot4, f32[2,2] bitcast(bias))

  ROOT root = tuple(sum1, sum2, sum3, sum4)
}
)";

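  // Constant biases are fused into the custom call even when hidden behind a
  // broadcast, reshape, transpose, or bitcast.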
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GemmRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
  SCOPED_TRACE(module->ToString());
  EXPECT_TRUE(changed);

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()))));
}

}  // namespace
}  // namespace gpu
}  // namespace xla