1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
20 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
21 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
26 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
27 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/filecheck.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33
34 namespace xla {
35 namespace gpu {
36 namespace {
37
38 namespace op = xla::testing::opcode_matchers;
39
// Test fixture for GPU horizontal loop fusion; inherits HLO parsing, module
// construction, and run-and-compare helpers from HloTestBase.
class HorizontalLoopFusionTest : public HloTestBase {};
41
TEST_F(HorizontalLoopFusionTest, BasicTest) {
  // Two independent kLoop fusions with different output shapes (f16[1024] and
  // f16[123]); horizontal fusion should pack both into one multi-output
  // fusion kernel.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = f16[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  // The pass must report a change; DCE then removes the dead original fusions.
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  // Both tuple operands now come from the same fused kernel: each output is a
  // bitcast of a get-tuple-element of a fusion.
  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root,
              op::Tuple(op::Bitcast(op::GetTupleElement(op::Fusion())),
                        op::Bitcast(op::GetTupleElement(op::Fusion()))));

  // The fused computation reshapes the two sub-results, concatenates them,
  // and slices the concatenation back into per-output pieces.
  const HloInstruction* fusion = entry_root->operand(0)->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(
      fusion->fused_expression_root(),
      op::Tuple(op::Slice(op::Concatenate(op::Reshape(), op::Reshape())),
                op::Slice(op::Concatenate(op::Reshape(), op::Reshape()))));
}
89
// Horizontal fusion should not be triggered as fusion will create cycles.
TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForCycle

 fused_computation.1 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   // fusion.1 and fusion.2 will not be horizontally fused as it will create
   // a cycle through fusion.1 -> add.2 -> fusion.2
   fusion.1 = f16[123]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   add.2 = f16[123]{0} add(fusion.1, arg.4)
   fusion.2 = f16[123]{0}
       fusion(add.2, arg.3), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2, add.2)
 }
)")
                    .ValueOrDie();

  // The pass must be a no-op: merging the two fusions would make the merged
  // instruction both a producer and a consumer of add.2.
  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}
127
TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
  // One fusion produces f16[1024], the other s32[123]; differing element
  // types must prevent horizontal fusion.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForIncompatibleTypes

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   ROOT mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = s32[123]{0} parameter(0)
   arg.2 = s32[123]{0} parameter(1)
   ROOT add.1 = s32[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = s32[123]{0} parameter(2)
   arg.4 = s32[123]{0} parameter(3)
   // fusion.1 and fusion.2 will not be horizontally fused because their output
   // types are different.
   fusion.1 = f16[1024]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = s32[123]{0}
       fusion(arg.3, arg.4), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  // Expect no change to the module.
  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}
163
TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
  // Two independent elementwise chains of different shapes; vertical
  // (instruction) fusion first collapses each chain into its own fusion,
  // then horizontal fusion should merge the resulting fusions.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MergeSharedFusionInstruction

 ENTRY MergeSharedFusionInstruction.Computation0 {
   param.1.1 = f32[4,1024]{1,0} parameter(0)
   param.1.2 = f32[4,1024]{1,0} parameter(1)
   param.1.3 = f32[4,1024]{1,0} parameter(2)
   param.2.1 = f32[321,5]{1,0} parameter(3)
   param.2.2 = f32[321,5]{1,0} parameter(4)
   param.2.3 = f32[321,5]{1,0} parameter(5)
   const.1 = f32[] constant(3)
   const.2 = f32[] constant(3)
   broadcast.1 = f32[4,1024]{1,0} broadcast(const.1), dimensions={}
   broadcast.2 = f32[321,5]{1,0} broadcast(const.2), dimensions={}
   mul.1.1 = f32[4,1024]{1,0} multiply(param.1.1, param.1.2)
   mul.1.2 = f32[4,1024]{1,0} multiply(param.1.3, broadcast.1)
   add.1 = f32[4,1024]{1,0} add(mul.1.1, mul.1.2)
   mul.2.1 = f32[321,5]{1,0} multiply(param.2.1, param.2.2)
   mul.2.2 = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
   add.2 = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
   ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
 })")
                    .ValueOrDie();

  // Vertical fusion: a non-duplicating pass followed by a duplicating one.
  HloPassPipeline fusion("fusion");
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false);
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true);
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();

  // Numerically compare against the reference backend with zero tolerance.
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
200
TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
  HloComputation::Builder builder(TestName());

  // Emit 128 independent SGD-style updates, var <- var - alpha * delta, over
  // shapes {1,1024}, {2,1024}, ..., {128,1024}, so the GPU pipeline has many
  // similar but unconnected subgraphs to fuse horizontally.
  std::vector<HloInstruction*> updated_vars;
  for (int64_t idx = 0; idx < 128; ++idx) {
    Shape shape = ShapeUtil::MakeShape(F32, {idx + 1, 1024});
    // Each iteration consumes three fresh parameters: the variable, the
    // scalar learning rate, and the gradient delta.
    HloInstruction* var_in = builder.AddInstruction(
        HloInstruction::CreateParameter(idx * 3 + 0, shape, "var.in"));
    HloInstruction* alpha =
        builder.AddInstruction(HloInstruction::CreateParameter(
            idx * 3 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
    HloInstruction* delta = builder.AddInstruction(
        HloInstruction::CreateParameter(idx * 3 + 2, shape, "delta"));
    HloInstruction* alpha_bcast = builder.AddInstruction(
        HloInstruction::CreateBroadcast(shape, alpha, {}));
    HloInstruction* scaled_delta =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kMultiply, alpha_bcast, delta));
    HloInstruction* new_var =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kSubtract, var_in, scaled_delta));
    updated_vars.push_back(new_var);
  }
  builder.AddInstruction(HloInstruction::CreateTuple(updated_vars));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Testing with the entire gpu optimization pipeline, comparing against the
  // reference backend with zero tolerance.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
}
233
TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
  // Two multi-output fusions with different shapes (f16[1024] vs f16[123])
  // and different internal structure; horizontal fusion should still merge
  // them, and the result must stay numerically identical.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule HeterogeneousMultiOutputFusions

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
   mul.2 = f16[1024]{0} multiply(arg.3, arg.4)
   add.1 = f16[1024]{0} add(mul.1, mul.2)
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(add.1, mul.1)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   arg.4 = f16[123]{0} parameter(3)
   add.1 = f16[123]{0} add(arg.1, arg.2)
   add.2 = f16[123]{0} add(arg.3, arg.4)
   mul.1 = f16[123]{0} multiply(add.1, add.2)
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) tuple(mul.1, add.1)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   arg.3 = f16[1024]{0} parameter(2)
   arg.4 = f16[1024]{0} parameter(3)
   arg.5 = f16[123]{0} parameter(4)
   arg.6 = f16[123]{0} parameter(5)
   arg.7 = f16[123]{0} parameter(6)
   arg.8 = f16[123]{0} parameter(7)
   fusion.1 = (f16[1024]{0}, f16[1024]{0})
       fusion(arg.1, arg.2, arg.3, arg.4),
       kind=kLoop, calls=fused_computation.1
   fusion.2 = (f16[123]{0}, f16[123]{0})
       fusion(arg.5, arg.6, arg.7, arg.8),
       kind=kLoop, calls=fused_computation.2
   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=0
   gte.2 = f16[1024]{0} get-tuple-element(fusion.1), index=1
   gte.3 = f16[123]{0} get-tuple-element(fusion.2), index=0
   gte.4 = f16[123]{0} get-tuple-element(fusion.2), index=1
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
       tuple(gte.1, gte.2, gte.3, gte.4)
 }
)")
                    .ValueOrDie();

  // The pass must fire; DCE cleans up the dead originals.
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();

  // Numerically compare against the reference backend with zero tolerance.
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
293
TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
  // Builds 48 independent RMSProp-style update chains with slightly different
  // shapes so horizontal fusion (run as part of the full GPU pipeline inside
  // RunAndCompare) has many similar, unconnected subgraphs to pack together.
  HloComputation::Builder builder(TestName());

  std::vector<HloInstruction*> all_outputs;
  for (int64_t i = 0; i < 48; ++i) {
    Shape shape = ShapeUtil::MakeShape(F32, {2, 1024 + i});
    // ms <- grad**2 (1 - rho) + ms * rho
    HloInstruction* grad = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 0, shape, "grad"));
    HloInstruction* ms = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 1, shape, "ms"));
    HloInstruction* rho =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 9 + 2, ShapeUtil::MakeShape(F32, {}), "rho"));
    HloInstruction* one_minus_rho =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 9 + 3, ShapeUtil::MakeShape(F32, {}), "one_minus_rho"));
    HloInstruction* rho_broadcasted =
        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, rho, {}));
    HloInstruction* one_minus_rho_broadcasted = builder.AddInstruction(
        HloInstruction::CreateBroadcast(shape, one_minus_rho, {}));
    HloInstruction* grad_squared = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad, grad));
    HloInstruction* ms_1st_term = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, grad_squared,
                                     one_minus_rho_broadcasted));
    HloInstruction* ms_2nd_term =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kMultiply, ms, rho_broadcasted));
    HloInstruction* ms_out =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kAdd, ms_1st_term, ms_2nd_term));

    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
    HloInstruction* momentum = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 4, shape, "momentum"));
    HloInstruction* mom = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 5, shape, "mom"));
    HloInstruction* lr = builder.AddInstruction(HloInstruction::CreateParameter(
        i * 9 + 6, ShapeUtil::MakeShape(F32, {}), "lr"));
    HloInstruction* epsilon =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 9 + 7, ShapeUtil::MakeShape(F32, {}), "epsilon"));
    HloInstruction* lr_broadcasted =
        builder.AddInstruction(HloInstruction::CreateBroadcast(shape, lr, {}));
    HloInstruction* epsilon_broadcasted = builder.AddInstruction(
        HloInstruction::CreateBroadcast(shape, epsilon, {}));
    HloInstruction* mom_1st_term =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kMultiply, momentum, mom));
    HloInstruction* ms_eps =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kAdd, ms_out, epsilon_broadcasted));
    HloInstruction* ms_eps_rsq = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kRsqrt, ms_eps));
    HloInstruction* grad_ms_eps_rsq =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kMultiply, grad, ms_eps_rsq));
    HloInstruction* mom_2nd_term =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kMultiply, lr_broadcasted, grad_ms_eps_rsq));
    HloInstruction* mom_out =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kAdd, mom_1st_term, mom_2nd_term));

    // var <- var - mom
    HloInstruction* var = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 9 + 8, shape, "var"));
    HloInstruction* var_out =
        builder.AddInstruction(HloInstruction::CreateBinary(
            shape, HloOpcode::kSubtract, var, mom_out));

    all_outputs.push_back(ms_out);
    all_outputs.push_back(mom_out);
    all_outputs.push_back(var_out);
  }
  builder.AddInstruction(HloInstruction::CreateTuple(all_outputs));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Rsqrt-based math is not bit-exact across backends, so allow a small
  // tolerance instead of ErrorSpec{0, 0}.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}
377
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
  // Two structurally identical dynamic-update-slice fusions operating on
  // disjoint inputs; the pass is expected to fuse them (EXPECT_TRUE below).
  // NOTE: the module previously carried the stale name
  // "NegativeTestForDynamicUpdateSlice", which contradicted the positive
  // expectation; renamed to match the test.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule DynamicUpdateSlice

 fusion.1 {
   p.0 = f16[5,9,10]{2,1,0} parameter(0)
   p.1 = s32[] parameter(1)
   p.2 = f16[1,9,10]{2,1,0} parameter(2)
   c.0 = s32[] constant(0)
   ROOT %dynamic-update-slice =
       f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
 }

 fusion.2 {
   p.0 = f16[5,9,10]{2,1,0} parameter(0)
   p.1 = s32[] parameter(1)
   p.2 = f16[1,9,10]{2,1,0} parameter(2)
   c.0 = s32[] constant(0)
   ROOT %dynamic-update-slice =
       f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
 }

 ENTRY entry {
   p.00 = f16[5,9,10]{2,1,0} parameter(0)
   p.01 = f16[5,9,10]{2,1,0} parameter(1)
   p.10 = s32[] parameter(2)
   p.11 = s32[] parameter(3)
   p.20 = f16[1,9,10]{2,1,0} parameter(4)
   p.21 = f16[1,9,10]{2,1,0} parameter(5)

   f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
   f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
   ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
 })")
                    .ValueOrDie();

  // The pass must fire; DCE cleans up the dead originals.
  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();

  // Numerically compare against the reference backend with zero tolerance.
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
422
TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
  // The two fusions share operand arg.2; the pass must leave them unfused.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 fused_computation.1 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[123]{0} parameter(0)
   // arg.2 is shared by fusion.1 and fusion.2
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   fusion.1 = f16[123]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = f16[123]{0}
       fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  // Expect no change to the module.
  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}
456
TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NonfusionInstrs

 fused_computation.0 {
   arg.0 = f16[] parameter(0)
   arg.1 = f16[123]{0} parameter(1)
   broadcast.0 = f16[123]{0} broadcast(arg.0), dimensions={}
   ROOT mul.1 = f16[123]{0} multiply(broadcast.0, arg.1)
 }

 fused_computation.1 {
   arg.0 = f16[] parameter(0)
   arg.1 = f16[456]{0} parameter(1)
   broadcast.0 = f16[456]{0} broadcast(arg.0), dimensions={}
   ROOT add.1 = f16[456]{0} add(broadcast.0, arg.1)
 }

 ENTRY entry_computation {
   arg.0 = f16[] parameter(0)
   arg.1 = f16[] parameter(1)
   arg.2 = f16[123]{0} parameter(2)
   arg.3 = f16[456]{0} parameter(3)
   // Test fusion of non-fusion instructions. sqrt.0 and sqrt.1 are to be
   // fused.
   sqrt.0 = f16[] sqrt(arg.0)
   sqrt.1 = f16[] sqrt(arg.1)
   // fusion.0 and fusion.1 are to be fused.
   fusion.0 = f16[123]{0}
       fusion(sqrt.0, arg.2), kind=kLoop, calls=fused_computation.0
   fusion.1 = f16[456]{0}
       fusion(sqrt.1, arg.3), kind=kLoop, calls=fused_computation.1
   ROOT tuple.1 = (f16[123]{0}, f16[456]{0}) tuple(fusion.0, fusion.1)
 }
)")
                    .ValueOrDie();

  // Run horizontal fusion + DCE to a fixed point, so instructions that only
  // become fusible after an earlier round also get fused.
  HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
  iterative_h_fusion.AddPass<GpuHorizontalLoopFusion>();
  iterative_h_fusion.AddPass<HloDCE>();
  EXPECT_TRUE(iterative_h_fusion.Run(module.get()).ValueOrDie());

  // Verify that fusion.0 and fusion.1 are fused.
  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root,
              op::Tuple(op::Bitcast(op::GetTupleElement(op::Fusion())),
                        op::Bitcast(op::GetTupleElement(op::Fusion()))));
  const HloInstruction* fusion = entry_root->operand(0)->operand(0)->operand(0);
  EXPECT_TRUE(fusion->IsMultiOutputFusion());

  // Verify that the total number of fusion instructions is 2 so that we
  // know sqrt.0 and sqrt.1 are fused.
  size_t total_fusion_instrs = 0;
  for (const HloInstruction* instr :
       module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kFusion) {
      ++total_fusion_instrs;
    }
  }
  EXPECT_EQ(total_fusion_instrs, 2);
}
519
TEST_F(HorizontalLoopFusionTest, TraversalOrder) {
  // Two identical fusions, each fed by a pair of sqrt instructions; checks
  // that the pass's traversal order groups all four sqrts into one kernel.
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule cluster

 %fused_computation (param_0: f32[256,256], param_1: f32[], param_2: f32[])
     -> f32[256,256] {
   %param_0 = f32[256,256]{1,0} parameter(0)
   %param_1 = f32[] parameter(1)
   %param_2 = f32[] parameter(2)
   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
   ROOT %multiply.1 = f32[256,256]{1,0}
       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
 }

 %fused_computation.1 (param_0: f32[256,256], param_1: f32[], param_2: f32[])
     -> f32[256,256] {
   %param_0 = f32[256,256]{1,0} parameter(0)
   %param_1 = f32[] parameter(1)
   %param_2 = f32[] parameter(2)
   %multiply.0 = f32[] multiply(f32[] %param_1, f32[] %param_2)
   %broadcast.0 = f32[256,256]{1,0} broadcast(f32[] %multiply.0), dimensions={}
   ROOT %multiply.1 = f32[256,256]{1,0}
       multiply(f32[256,256]{1,0} %param_0, f32[256,256]{1,0} %broadcast.0)
 }

 ENTRY %entry_computation (arg0: f32[256,256], arg1: f32[256,256], arg2: f32[],
                           arg3: f32[], arg4: f32[], arg5: f32[])
     -> (f32[256,256], f32[256,256]) {
   %arg0 = f32[256,256]{1,0} parameter(0), parameter_replication={false}
   %arg1 = f32[256,256]{1,0} parameter(1), parameter_replication={false}
   %arg2 = f32[] parameter(2), parameter_replication={false}
   %arg3 = f32[] parameter(3), parameter_replication={false}
   %arg4 = f32[] parameter(4), parameter_replication={false}
   %arg5 = f32[] parameter(5), parameter_replication={false}
   %sqrt = f32[] sqrt(f32[] %arg2)
   %sqrt.1 = f32[] sqrt(f32[] %arg3)
   %fusion = f32[256,256]{1,0}
       fusion(f32[256,256]{1,0} %arg0, f32[] %sqrt, f32[] %sqrt.1),
       kind=kLoop, calls=%fused_computation
   %sqrt.2 = f32[] sqrt(f32[] %arg4)
   %sqrt.3 = f32[] sqrt(f32[] %arg5)
   %fusion.1 = f32[256,256]{1,0}
       fusion(f32[256,256]{1,0} %arg1, f32[] %sqrt.2, f32[] %sqrt.3),
       kind=kLoop, calls=%fused_computation.1
   ROOT %tuple.163 = (f32[256,256]{1,0}, f32[256,256]{1,0})
       tuple(f32[256,256]{1,0} %fusion.1, f32[256,256]{1,0} %fusion)
 }
)")
                    .ValueOrDie();

  // Run the pass to a fixed point.
  HloPassFix<HloPassPipeline> iterative_h_fusion("iterative_h_fusion");
  iterative_h_fusion.AddPass<GpuHorizontalLoopFusion>();
  EXPECT_TRUE(iterative_h_fusion.Run(module.get()).ValueOrDie());

  // Verify that the total number of fusion instructions is 2 so that we
  // know all the sqrt instructions are fused into a kernel. Note that if we
  // traverse from def-to-use (i.e., top-to-down) instead of use-to-def, we
  // will end up having 3 fusions instead of 2.
  size_t total_fusion_instrs = 0;
  for (const HloInstruction* instr :
       module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kFusion) {
      ++total_fusion_instrs;
    }
  }
  EXPECT_EQ(total_fusion_instrs, 2);
}
588
589 } // namespace
590 } // namespace gpu
591 } // namespace xla
592