xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
20 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
21 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
26 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
27 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/filecheck.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33 
34 namespace xla {
35 namespace gpu {
36 namespace {
37 
38 namespace op = xla::testing::opcode_matchers;
39 
40 class HorizontalLoopFusionTest : public HloTestBase {};
41 
// Two independent kLoop fusions (a multiply over f16[1024] and an add over
// f16[123]) are run through GpuHorizontalLoopFusion followed by HloDCE. The
// test then checks that both tuple elements of the entry root are read from a
// fusion through get-tuple-element + bitcast, and that the fusion feeding the
// first element is a multi-output fusion whose root is a tuple of slices of
// concatenated reshapes.
TEST_F(HorizontalLoopFusionTest, BasicTest) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = f16[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  // ASSERT rather than EXPECT: the three-level operand walk below is only
  // valid when the root really has this shape. On a mismatch we want a clean
  // test failure, not an out-of-range operand access.
  ASSERT_THAT(entry_root,
              op::Tuple(op::Bitcast(op::GetTupleElement(op::Fusion())),
                        op::Bitcast(op::GetTupleElement(op::Fusion()))));

  // root -> bitcast -> get-tuple-element -> fusion.
  const HloInstruction* fusion = entry_root->operand(0)->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(
      fusion->fused_expression_root(),
      op::Tuple(op::Slice(op::Concatenate(op::Reshape(), op::Reshape())),
                op::Slice(op::Concatenate(op::Reshape(), op::Reshape()))));
}
89 
90 // Horizontal fusion should not be triggered as fusion will create cycles.
TEST_F(HorizontalLoopFusionTest,NegativeTestForCycle)91 TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
92   auto module = ParseAndReturnVerifiedModule(R"(
93  HloModule NegativeTestForCycle
94 
95  fused_computation.1 {
96    arg.1 = f16[123]{0} parameter(0)
97    arg.2 = f16[123]{0} parameter(1)
98    ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
99  }
100 
101  fused_computation.2 {
102    arg.1 = f16[123]{0} parameter(0)
103    arg.2 = f16[123]{0} parameter(1)
104    ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
105  }
106 
107  ENTRY entry_computation {
108    arg.1 = f16[123]{0} parameter(0)
109    arg.2 = f16[123]{0} parameter(1)
110    arg.3 = f16[123]{0} parameter(2)
111    arg.4 = f16[123]{0} parameter(3)
112    // fusion.1 and fusion.2 will not be horizontally fused as it will create
113    // a cycle through fusion.1 -> add.2 -> fusion.2
114    fusion.1 = f16[123]{0}
115        fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
116    add.2 = f16[123]{0} add(fusion.1, arg.4)
117    fusion.2 = f16[123]{0}
118        fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
119    ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
120        tuple(fusion.1, fusion.2, add.2)
121  }
122 )")
123                     .ValueOrDie();
124 
125   EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
126 }
127 
// One fusion produces f16 and the other s32. Because the output element types
// differ, the pass is expected to leave the module untouched.
TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
  const char* const kHloText = R"(
 HloModule NegativeTestForIncompatibleTypes

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = s32[123]{0} parameter(0)
   arg.2 = s32[123]{0} parameter(1)
   ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = s32[123]{0} parameter(2)
   arg.4 = s32[123]{0} parameter(3)
   // fusion.1 and fusion.2 will not be horizontally fused because their output
   // types are different.
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = s32[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
       tuple(fusion.1, fusion.2)
 }
)";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  const bool changed =
      GpuHorizontalLoopFusion().Run(hlo_module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
163 
// Starts from unfused elementwise HLO, applies GpuInstructionFusion twice
// (first without, then with duplication), then applies horizontal loop fusion.
// Both stages must report a change, and the resulting module is checked with
// RunAndCompareNoHloPasses using ErrorSpec{0, 0}.
TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
  const char* const kHloText = R"(
 HloModule MergeSharedFusionInstruction

 ENTRY MergeSharedFusionInstruction.Computation0 {
  param.1.1   = f32[4,1024]{1,0} parameter(0)
  param.1.2   = f32[4,1024]{1,0} parameter(1)
  param.1.3   = f32[4,1024]{1,0} parameter(2)
  param.2.1   = f32[321,5]{1,0} parameter(3)
  param.2.2   = f32[321,5]{1,0} parameter(4)
  param.2.3   = f32[321,5]{1,0} parameter(5)
  const.1     = f32[] constant(3)
  const.2     = f32[] constant(3)
  broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
  broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
  mul.1.1     = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
  mul.1.2     = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
  add.1       = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
  mul.2.1     = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
  mul.2.2     = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
  add.2       = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
  ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
})";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // Vertical (producer/consumer) fusion first.
  HloPassPipeline vertical_fusion("fusion");
  vertical_fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
  vertical_fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
  const bool vertical_changed =
      vertical_fusion.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(vertical_changed);

  // Then the pass under test.
  const bool horizontal_changed =
      GpuHorizontalLoopFusion().Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(horizontal_changed);

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << hlo_module->ToString();

  EXPECT_TRUE(
      RunAndCompareNoHloPasses(std::move(hlo_module), ErrorSpec{0, 0}));
}
200 
// Builds 128 independent updates of the form
//   var.out = var.in - broadcast(alpha) * delta
// over shapes {1, 1024} ... {128, 1024}, tuples the results, and runs the
// module through RunAndCompare with ErrorSpec{0, 0}.
TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
  constexpr int64_t kNumVariables = 128;
  // Each variable owns three consecutive parameters: var.in, alpha, delta.
  constexpr int64_t kParamsPerVariable = 3;

  HloComputation::Builder builder(TestName());
  std::vector<HloInstruction*> updated_vars;

  for (int64_t v = 0; v < kNumVariables; ++v) {
    const int64_t first_param = v * kParamsPerVariable;
    const Shape var_shape = ShapeUtil::MakeShape(F32, {v + 1, 1024});
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

    HloInstruction* var_in = builder.AddInstruction(
        HloInstruction::CreateParameter(first_param + 0, var_shape, "var.in"));
    HloInstruction* alpha = builder.AddInstruction(
        HloInstruction::CreateParameter(first_param + 1, scalar_shape,
                                        "alpha"));
    HloInstruction* delta = builder.AddInstruction(
        HloInstruction::CreateParameter(first_param + 2, var_shape, "delta"));

    HloInstruction* alpha_bcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(var_shape, alpha, {}));
    HloInstruction* step = builder.AddInstruction(HloInstruction::CreateBinary(
        var_shape, HloOpcode::kMultiply, alpha_bcast, delta));
    HloInstruction* var_out = builder.AddInstruction(
        HloInstruction::CreateBinary(var_shape, HloOpcode::kSubtract, var_in,
                                     step));
    updated_vars.push_back(var_out);
  }
  builder.AddInstruction(HloInstruction::CreateTuple(updated_vars));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());

  // Testing with the entire gpu optimization pipeline.
  EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{0, 0}));
}
233 
// Both inputs are already multi-output kLoop fusions (each returns a
// two-element tuple) over different sizes, f16[1024] and f16[123]. The pass
// and the following DCE must both report a change, and the result is checked
// with RunAndCompareNoHloPasses using ErrorSpec{0, 0}.
TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
  const char* const kHloText = R"(
 HloModule HeterogeneousMultiOutputFusions

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
   mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
   add.1 = f16[1024]{0} add(mul.1, mul.2)
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   add.1 = f16[123]{0} add(arg.1, arg.2)
   add.2 = f16[123]{0} add(arg.3, arg.4)
   mul.1 = f16[123]{0} multiply(add.1, add.2)
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   arg.5 = f16[123]{0} parameter(4)
   arg.6 = f16[123]{0} parameter(5)
   arg.7 = f16[123]{0} parameter(6)
   arg.8 = f16[123]{0} parameter(7)
   fusion.1 = (f16[1024]{0}, f16[1024]{0})
       fusion(arg.1, arg.2, arg.3, arg.4),
           kind=kLoop, calls=fused_computation.1
   fusion.2 = (f16[123]{0}, f16[123]{0})
       fusion(arg.5, arg.6, arg.7, arg.8),
           kind=kLoop, calls=fused_computation.2
   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
   gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
   gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
   gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
       tuple(gte.1, gte.2, gte.3, gte.4)
 }
)";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  const bool fused =
      GpuHorizontalLoopFusion().Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(fused);
  const bool dce_changed = HloDCE().Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(dce_changed);

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << hlo_module->ToString();

  EXPECT_TRUE(
      RunAndCompareNoHloPasses(std::move(hlo_module), ErrorSpec{0, 0}));
}
293 
TEST_F(HorizontalLoopFusionTest,RMSPropLike)294 TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
295   HloComputation::Builder builder(TestName());
296 
297   std::vector<HloInstruction*> all_outputs;
298   for (int64_t i = 0; i < 48; ++i) {
299     Shape shape = ShapeUtil::MakeShape(F32, {2, 1024 + i});
300     // ms <- grad**2 (1 - rho) + ms * rho
301     HloInstruction* grad = builder.AddInstruction(
302         HloInstruction::CreateParameter(i * 9 + 0, shape, "grad"));
303     HloInstruction* ms = builder.AddInstruction(
304         HloInstruction::CreateParameter(i * 9 + 1, shape, "ms"));
305     HloInstruction* rho =
306         builder.AddInstruction(HloInstruction::CreateParameter(
307             i * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
308     HloInstruction* one_minus_rho =
309         builder.AddInstruction(HloInstruction::CreateParameter(
310             i * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
311     HloInstruction* rho_broadcasted =
312         builder.AddInstruction(HloInstruction::CreateBroadcast(shape, rho, {}));
313     HloInstruction* one_mins_rho_broadcasted = builder.AddInstruction(
314         HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
315     HloInstruction* grad_squared = builder.AddInstruction(
316         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
317     HloInstruction* ms_1st_term = builder.AddInstruction(
318         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad_squared,
319                                      one_mins_rho_broadcasted));
320     HloInstruction* ms_2nd_term =
321         builder.AddInstruction(HloInstruction::CreateBinary(
322             shape, HloOpcode::kMultiply, ms, rho_broadcasted));
323     HloInstruction* ms_out =
324         builder.AddInstruction(HloInstruction::CreateBinary(
325             shape, HloOpcode::kAdd, ms_1st_term, ms_2nd_term));
326 
327     // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
328     HloInstruction* momentum = builder.AddInstruction(
329         HloInstruction::CreateParameter(i * 9 + 4, shape, "momemtum"));
330     HloInstruction* mom = builder.AddInstruction(
331         HloInstruction::CreateParameter(i * 9 + 5, shape, "mom"));
332     HloInstruction* lr = builder.AddInstruction(HloInstruction::CreateParameter(
333         i * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
334     HloInstruction* epsilon =
335         builder.AddInstruction(HloInstruction::CreateParameter(
336             i * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
337     HloInstruction* lr_broadcasted =
338         builder.AddInstruction(HloInstruction::CreateBroadcast(shape, lr, {}));
339     HloInstruction* epsilon_broadcasted = builder.AddInstruction(
340         HloInstruction::CreateBroadcast(shape, epsilon, {}));
341     HloInstruction* mom_1st_term =
342         builder.AddInstruction(HloInstruction::CreateBinary(
343             shape, HloOpcode::kMultiply, momentum, mom));
344     HloInstruction* ms_eps =
345         builder.AddInstruction(HloInstruction::CreateBinary(
346             shape, HloOpcode::kAdd, ms_out, epsilon_broadcasted));
347     HloInstruction* ms_eps_rsq = builder.AddInstruction(
348         HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_eps));
349     HloInstruction* grad_ms_eps_rsq =
350         builder.AddInstruction(HloInstruction::CreateBinary(
351             shape, HloOpcode::kMultiply, grad, ms_eps_rsq));
352     HloInstruction* mom_2nd_term =
353         builder.AddInstruction(HloInstruction::CreateBinary(
354             shape, HloOpcode::kMultiply, lr_broadcasted, grad_ms_eps_rsq));
355     HloInstruction* mom_out =
356         builder.AddInstruction(HloInstruction::CreateBinary(
357             shape, HloOpcode::kAdd, mom_1st_term, mom_2nd_term));
358 
359     // var <- var - mom
360     HloInstruction* var = builder.AddInstruction(
361         HloInstruction::CreateParameter(i * 9 + 8, shape, "var"));
362     HloInstruction* var_out =
363         builder.AddInstruction(HloInstruction::CreateBinary(
364             shape, HloOpcode::kSubtract, var, mom_out));
365 
366     all_outputs.push_back(ms_out);
367     all_outputs.push_back(mom_out);
368     all_outputs.push_back(var_out);
369   }
370   builder.AddInstruction(HloInstruction::CreateTuple(all_outputs));
371 
372   auto module = CreateNewVerifiedModule();
373   module->AddEntryComputation(builder.Build());
374 
375   EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
376 }
377 
// Two kLoop fusions whose roots are dynamic-update-slice instructions. The
// pass and the following DCE are both expected to report a change, and the
// result is checked with RunAndCompareNoHloPasses using ErrorSpec{0, 0}.
// NOTE(review): the HLO module is named "NegativeTest..." although this test
// expects the pass to change the module; the name looks like a leftover.
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
  const char* const kHloText = R"(
  HloModule NegativeTestForDynamicUpdateSlice

  fusion.1 {
    p.0 = f16[5,9,10]{2,1,0} parameter(0)
    p.1 = s32[] parameter(1)
    p.2 = f16[1,9,10]{2,1,0} parameter(2)
    c.0 = s32[] constant(0)
    ROOT %dynamic-update-slice =
        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
  }

  fusion.2 {
    p.0 = f16[5,9,10]{2,1,0} parameter(0)
    p.1 = s32[] parameter(1)
    p.2 = f16[1,9,10]{2,1,0} parameter(2)
    c.0 = s32[] constant(0)
    ROOT %dynamic-update-slice =
        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
  }

  ENTRY entry {
    p.00 = f16[5,9,10]{2,1,0} parameter(0)
    p.01 = f16[5,9,10]{2,1,0} parameter(1)
    p.10 = s32[] parameter(2)
    p.11 = s32[] parameter(3)
    p.20 = f16[1,9,10]{2,1,0} parameter(4)
    p.21 = f16[1,9,10]{2,1,0} parameter(5)

    f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
    f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
    ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  const bool fused =
      GpuHorizontalLoopFusion().Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(fused);
  const bool dce_changed = HloDCE().Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(dce_changed);

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << hlo_module->ToString();

  EXPECT_TRUE(
      RunAndCompareNoHloPasses(std::move(hlo_module), ErrorSpec{0, 0}));
}
422 
// fusion.1 and fusion.2 both take arg.2 as an operand. With a shared operand
// the pass is expected to report that it changed nothing.
TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
  const char* const kHloText = R"(
 HloModule BasicTest

 fused_computation.1 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[123]{0} parameter(0)
   // arg.2 is shared by fusion.1 and fusion.2
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   fusion.1 = f16[123]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = f16[123]{0}
       fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
 }
)";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  const bool changed =
      GpuHorizontalLoopFusion().Run(hlo_module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
456 
// Runs GpuHorizontalLoopFusion + HloDCE to a fixed point on a module that
// contains two kLoop fusions and two stand-alone sqrt instructions feeding
// them. Afterwards the entry root must read both results from a fusion via
// get-tuple-element + bitcast, that fusion must be multi-output, and the entry
// computation must contain exactly two fusion instructions in total (one for
// the original fusions, one for the sqrt pair).
TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NonfusionInstrs

 fused_computation.0 {
   arg.0 = f16[] parameter(0)
   arg.1 = f16[123]{0} parameter(1)
   broadcast.0 = f16[123]{0} broadcast(arg.0), dimensions={}
   ROOT mul.1 = f16[123]{0} multiply(broadcast.0, arg.1)
 }

 fused_computation.1 {
   arg.0 = f16[] parameter(0)
   arg.1 = f16[456]{0} parameter(1)
   broadcast.0 = f16[456]{0} broadcast(arg.0), dimensions={}
   ROOT add.1 = f16[456]{0} add(broadcast.0, arg.1)
 }

 ENTRY entry_computation {
   arg.0 = f16[] parameter(0)
   arg.1 = f16[] parameter(1)
   arg.2 = f16[123]{0} parameter(2)
   arg.3 = f16[456]{0} parameter(3)
   // Test fusion of non-fusion instructions. sqrt.0 and sqrt.1 are to be
   // fused.
   sqrt.0 = f16[] sqrt(arg.0)
   sqrt.1 = f16[] sqrt(arg.1)
   // fusion.0 and fusion.1 are to be fused.
   fusion.0 = f16[123]{0}
       fusion(sqrt.0, arg.2), kind=kLoop, calls=fused_computation.0
   fusion.1 = f16[456]{0}
       fusion(sqrt.1, arg.3), kind=kLoop, calls=fused_computation.1
   ROOT tuple.1 = (f16[123]{0}, f16[456]{0}) tuple(fusion.0, fusion.1)
 }
)")
                    .ValueOrDie();

  HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
  iterative_h_fusion.AddPass<GpuHorizontalLoopFusion>();
  iterative_h_fusion.AddPass<HloDCE>();
  EXPECT_TRUE(iterative_h_fusion.Run(module.get()).ValueOrDie());

  // Verify that fusion.0 and fusion.1 are fused.
  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  // ASSERT rather than EXPECT: the operand walk on the next statement is only
  // valid when the root has this exact shape. A mismatch should fail the test
  // cleanly instead of indexing past an instruction's operands.
  ASSERT_THAT(entry_root,
              op::Tuple(op::Bitcast(op::GetTupleElement(op::Fusion())),
                        op::Bitcast(op::GetTupleElement(op::Fusion()))));
  // root -> bitcast -> get-tuple-element -> fusion.
  const HloInstruction* fusion = entry_root->operand(0)->operand(0)->operand(0);
  EXPECT_TRUE(fusion->IsMultiOutputFusion());

  // Verify that the total number of fusion instructions is 2 so that we
  // know sqrt.0 and sqrt.1 are fused.
  size_t total_fusion_instrs = 0;
  for (const HloInstruction* instr :
       module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kFusion) {
      ++total_fusion_instrs;
    }
  }
  // Unsigned literal so the comparison with the size_t counter is like-signed.
  EXPECT_EQ(total_fusion_instrs, 2u);
}
519 
// Runs GpuHorizontalLoopFusion to a fixed point on a module with two kLoop
// fusions, each fed by two scalar sqrt instructions, and checks how many
// fusion instructions remain in the entry computation.
TEST_F(HorizontalLoopFusionTest, TraversalOrder) {
  const char* const kHloText = R"(
 HloModule cluster

 %fused_computation (param_0: f32[256,256], param_1: f32[], param_2: f32[])
     -> f32[256,256] {
   %param_0 = f32[256,256]{1,0} parameter(0)
   %param_1 = f32[] parameter(1)
   %param_2 = f32[] parameter(2)
   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
   ROOT %multiply.1 = f32[256,256]{1,0}
       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
 }

 %fused_computation.1 (param_0: f32[256,256], param_1: f32[], param_2: f32[])
     -> f32[256,256] {
   %param_0 = f32[256,256]{1,0} parameter(0)
   %param_1 = f32[] parameter(1)
   %param_2 = f32[] parameter(2)
   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
   ROOT %multiply.1 = f32[256,256]{1,0}
       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
 }

 ENTRY %entry_computation (arg0: f32[256,256], arg1: f32[256,256], arg2: f32[],
                           arg3: f32[], arg4: f32[], arg5: f32[])
                               -> (f32[256,256], f32[256,256]) {
   %arg0 = f32[256,256]{1,0} parameter(0), parameter_replication={false}
   %arg1 = f32[256,256]{1,0} parameter(1), parameter_replication={false}
   %arg2 = f32[] parameter(2), parameter_replication={false}
   %arg3 = f32[] parameter(3), parameter_replication={false}
   %arg4 = f32[] parameter(4), parameter_replication={false}
   %arg5 = f32[] parameter(5), parameter_replication={false}
   %sqrt = f32[] sqrt(f32[] %arg2)
   %sqrt.1 = f32[] sqrt(f32[] %arg3)
   %fusion = f32[256,256]{1,0}
       fusion(f32[256,256]{1,0} %arg0, f32[] %sqrt, f32[] %sqrt.1),
       kind=kLoop, calls=%fused_computation
   %sqrt.2 = f32[] sqrt(f32[] %arg4)
   %sqrt.3 = f32[] sqrt(f32[] %arg5)
   %fusion.1 = f32[256,256]{1,0}
       fusion(f32[256,256]{1,0} %arg1, f32[] %sqrt.2, f32[] %sqrt.3),
       kind=kLoop, calls=%fused_computation.1
   ROOT %tuple.163 = (f32[256,256]{1,0}, f32[256,256]{1,0})
       tuple(f32[256,256]{1,0} %fusion.1, f32[256,256]{1,0} %fusion)
 }
)";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  HloPassFix<HloPassPipeline> fixed_point_pipeline("iterative_h_fusion");
  fixed_point_pipeline.AddPass<GpuHorizontalLoopFusion>();
  const bool changed = fixed_point_pipeline.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Exactly two fusion instructions should remain, which tells us all four
  // sqrt instructions ended up in a single kernel. Traversing def-to-use
  // (top-down) rather than use-to-def would leave three fusions instead.
  size_t fusion_count = 0;
  for (const HloInstruction* instr :
       hlo_module->entry_computation()->instructions()) {
    fusion_count += (instr->opcode() == HloOpcode::kFusion) ? 1 : 0;
  }
  EXPECT_EQ(fusion_count, 2);
}
588 
589 }  // namespace
590 }  // namespace gpu
591 }  // namespace xla
592