1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17
18 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29
30 namespace xla {
31
32 // A class specifying the BF16 support used to test the propagation pass. It
33 // specifies that BF16 and mixed precision are supported in all HloInstructions,
34 // and that kDot reduces its operands precision to BF16.
35 class TestBFloat16Support : public BFloat16Support {
36 public:
TestBFloat16Support()37 TestBFloat16Support() {}
~TestBFloat16Support()38 ~TestBFloat16Support() override {}
39
SupportsBF16Operand(const HloInstruction & hlo,int64_t operand_index) const40 bool SupportsBF16Operand(const HloInstruction& hlo,
41 int64_t operand_index) const override {
42 return true;
43 }
44
SupportsBF16Output(const HloInstruction & hlo) const45 bool SupportsBF16Output(const HloInstruction& hlo) const override {
46 return true;
47 }
48
SupportsMixedPrecisions(const HloInstruction & hlo) const49 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
50 return true;
51 }
52
EffectiveOperandPrecisionIsBF16(const HloInstruction & hlo,int64_t operand_index) const53 bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
54 int64_t operand_index) const override {
55 return hlo.opcode() == HloOpcode::kDot;
56 }
57 };
58
59 class BFloat16PropagationTest : public HloTestBase {
60 protected:
BFloat16PropagationTest()61 BFloat16PropagationTest()
62 : HloTestBase(/*verifier_layout_sensitive=*/false,
63 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
64
65 // Runs the propagation pass on the given module, and returns whether the
66 // module is changed after this pass.
PropagatePrecision(HloModule * module)67 bool PropagatePrecision(HloModule* module) {
68 TestBFloat16Support bfloat16_support;
69 BFloat16Propagation propagation(&bfloat16_support);
70 StatusOr<bool> result = propagation.Run(module);
71 EXPECT_IS_OK(result.status());
72 return result.ValueOrDie();
73 }
74
75 // Returns whether the given HloInstruction's output element type is BF16 or
76 // the only use of it is converting to BF16.
OutputsBF16(const HloInstruction * inst)77 bool OutputsBF16(const HloInstruction* inst) {
78 if (inst->shape().element_type() == BF16) {
79 return true;
80 }
81 return inst->user_count() == 1 &&
82 inst->users()[0]->opcode() == HloOpcode::kConvert &&
83 inst->users()[0]->shape().element_type() == BF16;
84 }
85
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)86 std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
87 HloInstruction* lhs,
88 HloInstruction* rhs) {
89 DotDimensionNumbers dot_dnums;
90 dot_dnums.add_lhs_contracting_dimensions(1);
91 dot_dnums.add_rhs_contracting_dimensions(0);
92 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
93 DefaultPrecisionConfig(2));
94 }
95 };
96
97 // Tests that BF16 can propagate through select over non-tuple buffers, but not
98 // through add where reducing operand precision can affect the result.
TEST_F(BFloat16PropagationTest,PropagateThroughSelectButNotAdd)99 TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
100 auto builder = HloComputation::Builder(TestName());
101 Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
102
103 HloInstruction* a =
104 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
105 HloInstruction* b =
106 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
107 HloInstruction* c =
108 builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c"));
109 HloInstruction* add0 = builder.AddInstruction(
110 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
111 HloInstruction* add1 = builder.AddInstruction(
112 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
113 HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare(
114 ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq));
115 HloInstruction* sel = builder.AddInstruction(
116 HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
117 HloInstruction* xpose =
118 builder.AddInstruction(HloInstruction::CreateTranspose(
119 ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
120 HloInstruction* dot = builder.AddInstruction(
121 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
122 HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
123 ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
124
125 auto module = CreateNewVerifiedModule();
126 auto computation = module->AddEntryComputation(builder.Build());
127
128 EXPECT_TRUE(PropagatePrecision(module.get()));
129
130 EXPECT_EQ(computation->root_instruction(), root);
131 EXPECT_TRUE(OutputsBF16(xpose));
132 EXPECT_TRUE(OutputsBF16(sel));
133 EXPECT_TRUE(OutputsBF16(add1));
134 EXPECT_FALSE(OutputsBF16(add0));
135 EXPECT_FALSE(OutputsBF16(a));
136 EXPECT_FALSE(OutputsBF16(b));
137 EXPECT_FALSE(OutputsBF16(c));
138 }
139
// Checks that BF16 reaches the add feeding a reduce-window whose reduction is
// a scalar maximum, as well as the reduce-window itself and the transpose on
// the other side of the dot.
TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) {
  auto module = CreateNewVerifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Scalar max(a, b), used as the reduce-window reduction computation.
  HloComputation::Builder max_builder("max");
  auto* max_lhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "a"));
  auto* max_rhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "b"));
  max_builder.AddInstruction(HloInstruction::CreateBinary(
      f32_scalar, HloOpcode::kMaximum, max_lhs, max_rhs));
  HloComputation* max_computation =
      module->AddEmbeddedComputation(max_builder.Build());

  HloComputation::Builder builder(TestName());
  const Shape f32_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape f32_4x2 = ShapeUtil::MakeShape(F32, {4, 2});
  const Shape f32_4x4 = ShapeUtil::MakeShape(F32, {4, 4});

  auto* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_2x4, "a"));
  auto* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_2x4, "b"));
  auto* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_2x4, "c"));
  auto* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x4, HloOpcode::kAdd, a, b));

  // Both window dimensions use the same settings: size 2, stride 1, high
  // padding 1, and dilation factors of 1.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  // The F32 zero constant is the reduce-window init value.
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  auto* rw = builder.AddInstruction(HloInstruction::CreateReduceWindow(
      f32_2x4, add, init_value, window, max_computation));
  auto* xpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_4x2, c, {1, 0}));
  auto* dot = builder.AddInstruction(CreateDot(f32_4x4, xpose, rw));
  auto* root = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_4x4, HloOpcode::kAdd, dot, dot));

  HloComputation* entry = module->AddEntryComputation(builder.Build());

  const bool changed = PropagatePrecision(module.get());
  EXPECT_TRUE(changed);

  EXPECT_EQ(entry->root_instruction(), root);
  EXPECT_TRUE(OutputsBF16(add));
  EXPECT_TRUE(OutputsBF16(xpose));
  EXPECT_TRUE(OutputsBF16(rw));
}
195
196 // Tests that side-effecting all-reduce should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeAllReduce)197 TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
198 auto module = CreateNewVerifiedModule();
199
200 auto builder = HloComputation::Builder(TestName());
201 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
202 HloInstruction* a =
203 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
204 HloInstruction* b =
205 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
206 auto rb = HloComputation::Builder(TestName());
207 rb.AddInstruction(HloInstruction::CreateBinary(
208 shape, HloOpcode::kAdd,
209 rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")),
210 rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"))));
211 auto reduction = module->AddEmbeddedComputation(rb.Build());
212 HloInstruction* all_reduce =
213 builder.AddInstruction(HloInstruction::CreateAllReduce(
214 ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
215 /*replica_groups=*/{}, /*constrain_layout=*/false,
216 /*channel_id=*/1, /*use_global_device_ids=*/false));
217 HloInstruction* gte0 = builder.AddInstruction(
218 HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
219 HloInstruction* gte1 = builder.AddInstruction(
220 HloInstruction::CreateGetTupleElement(shape, all_reduce, 1));
221 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
222 HloInstruction* root = builder.AddInstruction(
223 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
224
225 auto computation = module->AddEntryComputation(builder.Build());
226
227 EXPECT_FALSE(PropagatePrecision(module.get()));
228 EXPECT_EQ(computation->root_instruction(), root);
229 }
230
231 // Tests that if a constant is converted to BF16 then its literal must also be
232 // converted.
TEST_F(BFloat16PropagationTest,ConvertConstantLiteral)233 TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
234 auto builder = HloComputation::Builder(TestName());
235 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
236 Array2D<float> array_a(4, 4);
237 array_a.FillUnique(1.0f);
238 Array2D<float> array_b(4, 4);
239 array_b.FillUnique(10.0f);
240
241 HloInstruction* a = builder.AddInstruction(
242 HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
243 HloInstruction* b = builder.AddInstruction(
244 HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
245 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
246
247 auto module = CreateNewVerifiedModule();
248 auto computation = module->AddEntryComputation(builder.Build());
249
250 EXPECT_TRUE(PropagatePrecision(module.get()));
251
252 EXPECT_EQ(computation->root_instruction(), dot);
253 EXPECT_TRUE(OutputsBF16(dot->operand(0)));
254 EXPECT_TRUE(OutputsBF16(dot->operand(1)));
255 EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
256 EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
257 EXPECT_TRUE(LiteralTestUtil::Equal(
258 LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
259 dot->operand(0)->literal()));
260 EXPECT_TRUE(LiteralTestUtil::Equal(
261 LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
262 dot->operand(1)->literal()));
263 }
264
265 // Tests that BF16 can be propagated through nested tuples.
TEST_F(BFloat16PropagationTest,PropagateThroughTuples)266 TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
267 auto builder = HloComputation::Builder(TestName());
268 Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
269
270 HloInstruction* a =
271 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
272 HloInstruction* b =
273 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
274 HloInstruction* add0 = builder.AddInstruction(
275 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
276 HloInstruction* add1 = builder.AddInstruction(
277 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
278 HloInstruction* add2 = builder.AddInstruction(
279 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b));
280 HloInstruction* xpose =
281 builder.AddInstruction(HloInstruction::CreateTranspose(
282 ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
283
284 HloInstruction* tuple0 =
285 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2}));
286 HloInstruction* tuple1 =
287 builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose}));
288
289 HloInstruction* lhs = builder.AddInstruction(
290 HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1));
291 HloInstruction* rhs =
292 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
293 add0->shape(),
294 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
295 tuple0->shape(), tuple1, 0)),
296 0));
297 HloInstruction* dot = builder.AddInstruction(
298 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
299
300 HloInstruction* output_tuple =
301 builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
302
303 auto module = CreateNewVerifiedModule();
304 auto computation = module->AddEntryComputation(builder.Build());
305
306 EXPECT_TRUE(PropagatePrecision(module.get()));
307
308 EXPECT_EQ(computation->root_instruction(), output_tuple);
309 EXPECT_TRUE(OutputsBF16(xpose));
310 EXPECT_TRUE(OutputsBF16(add0));
311 EXPECT_TRUE(OutputsBF16(add1));
312 EXPECT_FALSE(OutputsBF16(add2));
313 }
314
315 // Tests that even if an instruction does not define a buffer in its output, its
316 // shape must match the defining instruction.
TEST_F(BFloat16PropagationTest,SameValueReferencedTwice)317 TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
318 auto builder = HloComputation::Builder(TestName());
319 Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
320
321 HloInstruction* a =
322 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
323 HloInstruction* b =
324 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
325 HloInstruction* add0 = builder.AddInstruction(
326 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
327 HloInstruction* add1 = builder.AddInstruction(
328 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
329
330 HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose(
331 ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
332
333 HloInstruction* tuple =
334 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
335 HloInstruction* rhs = builder.AddInstruction(
336 HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
337
338 // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
339 HloInstruction* dot = builder.AddInstruction(
340 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
341
342 auto module = CreateNewVerifiedModule();
343 auto computation = module->AddEntryComputation(builder.Build());
344
345 EXPECT_TRUE(PropagatePrecision(module.get()));
346
347 EXPECT_EQ(computation->root_instruction(), dot);
348 EXPECT_TRUE(OutputsBF16(add1));
349 EXPECT_TRUE(OutputsBF16(lhs));
350
351 // add0 and rhs have been eliminated by simplification and DCE.
352 }
353
354 // Tests that a non-fusion computation's root should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeComputationRoot)355 TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
356 auto builder = HloComputation::Builder(TestName());
357 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
358
359 HloInstruction* a =
360 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
361 HloInstruction* b =
362 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
363 HloInstruction* add = builder.AddInstruction(
364 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
365
366 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
367
368 HloInstruction* tuple =
369 builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
370
371 auto module = CreateNewVerifiedModule();
372 auto computation = module->AddEntryComputation(builder.Build());
373
374 EXPECT_FALSE(PropagatePrecision(module.get()));
375
376 EXPECT_EQ(computation->root_instruction(), tuple);
377 EXPECT_FALSE(OutputsBF16(add));
378 }
379
380 // Tests that BF16 is propagated properly through fused computations.
TEST_F(BFloat16PropagationTest,PropagateThroughFusion)381 TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
382 auto module = CreateNewVerifiedModule();
383 auto builder = HloComputation::Builder(TestName());
384 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
385
386 HloInstruction* param = builder.AddInstruction(
387 HloInstruction::CreateParameter(0, shape, "param"));
388 HloInstruction* add = builder.AddInstruction(
389 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
390
391 auto builder_f0 = HloComputation::Builder("fusion0");
392 HloInstruction* a_f0 =
393 builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
394 HloInstruction* b_f0 =
395 builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
396 HloInstruction* tuple_f0 =
397 builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0}));
398 auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build());
399 auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
400 tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add},
401 comp_f0));
402
403 auto builder_f1 = HloComputation::Builder("fusion1");
404 HloInstruction* p_f1 = builder_f1.AddInstruction(
405 HloInstruction::CreateParameter(0, tuple_f0->shape(), "param"));
406 HloInstruction* a_f1 = builder_f1.AddInstruction(
407 HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
408 HloInstruction* b_f1 = builder_f1.AddInstruction(
409 HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
410 HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
411 auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
412 auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
413 dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
414
415 auto computation = module->AddEntryComputation(builder.Build());
416
417 EXPECT_TRUE(PropagatePrecision(module.get()));
418
419 EXPECT_EQ(computation->root_instruction(), fusion1);
420 EXPECT_TRUE(OutputsBF16(add));
421 EXPECT_TRUE(OutputsBF16(a_f0));
422 EXPECT_TRUE(OutputsBF16(b_f0));
423 EXPECT_TRUE(OutputsBF16(a_f1));
424 EXPECT_TRUE(OutputsBF16(b_f1));
425 }
426
427 // Tests that a fusion with a bitcast-convert as its root is changed via adding
428 // extra convert, instead of changing the type in-place.
TEST_F(BFloat16PropagationTest,FusionWithBitcastConvertRoot)429 TEST_F(BFloat16PropagationTest, FusionWithBitcastConvertRoot) {
430 auto module = CreateNewVerifiedModule();
431 auto builder = HloComputation::Builder(TestName());
432 Shape u32_shape = ShapeUtil::MakeShape(U32, {4, 4});
433 Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});
434
435 HloInstruction* param = builder.AddInstruction(
436 HloInstruction::CreateParameter(0, u32_shape, "param"));
437
438 auto builder_f = HloComputation::Builder("fusion");
439 HloInstruction* a_f = builder_f.AddInstruction(
440 HloInstruction::CreateParameter(0, u32_shape, "a"));
441 HloInstruction* bc_f = builder_f.AddInstruction(
442 HloInstruction::CreateBitcastConvert(f32_shape, a_f));
443 auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
444 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
445 f32_shape, HloInstruction::FusionKind::kLoop, {param}, comp_f));
446 auto dot = builder.AddInstruction(CreateDot(f32_shape, fusion, fusion));
447
448 auto computation = module->AddEntryComputation(builder.Build());
449 EXPECT_TRUE(PropagatePrecision(module.get()));
450
451 EXPECT_EQ(computation->root_instruction(), dot);
452 EXPECT_EQ(bc_f->shape(), f32_shape);
453 EXPECT_TRUE(OutputsBF16(bc_f));
454 }
455
456 // Tests that changes to BF16 that cannot be propagated outside a fusion are
457 // discarded.
TEST_F(BFloat16PropagationTest,DiscardFusionInternalBF16Changes)458 TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
459 auto module = CreateNewVerifiedModule();
460 auto builder = HloComputation::Builder(TestName());
461 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
462
463 HloInstruction* param = builder.AddInstruction(
464 HloInstruction::CreateParameter(0, shape, "param"));
465 HloInstruction* add = builder.AddInstruction(
466 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
467
468 auto builder_f = HloComputation::Builder("fusion");
469 HloInstruction* a_f =
470 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
471 HloInstruction* b_f =
472 builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
473 HloInstruction* add_f = builder_f.AddInstruction(
474 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
475 HloInstruction* dot_f =
476 builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
477 auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
478 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
479 dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
480
481 auto computation = module->AddEntryComputation(builder.Build());
482
483 EXPECT_FALSE(PropagatePrecision(module.get()));
484 EXPECT_EQ(computation->root_instruction(), fusion);
485 }
486
487 // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
488 // outputs are only used by a dot, and 3) one element of the tuple is used by
489 // an add in the fusion computation, then the propagation pass should create a
490 // convert in the fusion computation to keep the add's operand in F32 but change
491 // the fusion output to BF16. E.g., the following fusion computation
492 // (F32, F32) fusion_computation(F32 a, F32 b)
493 // = tuple(F32 a, F32 add(F32 a, F32 b))
494 // will be changed to
495 // (BF16, BF16) fusion_computation(F32 a, F32 b)
496 // = tuple(BF16 convert(a), BF16 add(F32 a, F32 b))
TEST_F(BFloat16PropagationTest,ConvertTupleFusionElementIfUsedByAdd)497 TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
498 auto module = CreateNewVerifiedModule();
499 auto builder = HloComputation::Builder(TestName());
500 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
501
502 HloInstruction* param = builder.AddInstruction(
503 HloInstruction::CreateParameter(0, shape, "param"));
504 HloInstruction* add = builder.AddInstruction(
505 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
506
507 auto builder_f = HloComputation::Builder("fusion0");
508 HloInstruction* a_f =
509 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
510 HloInstruction* b_f =
511 builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
512 HloInstruction* add_f = builder_f.AddInstruction(
513 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
514 HloInstruction* tuple_f =
515 builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f}));
516 auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
517 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
518 tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add},
519 comp_f));
520
521 HloInstruction* gte0 = builder.AddInstruction(
522 HloInstruction::CreateGetTupleElement(shape, fusion, 0));
523 HloInstruction* gte1 = builder.AddInstruction(
524 HloInstruction::CreateGetTupleElement(shape, fusion, 1));
525 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
526
527 auto computation = module->AddEntryComputation(builder.Build());
528
529 EXPECT_TRUE(PropagatePrecision(module.get()));
530
531 EXPECT_EQ(computation->root_instruction(), dot);
532 EXPECT_TRUE(OutputsBF16(gte0));
533 EXPECT_TRUE(OutputsBF16(gte1));
534 EXPECT_FALSE(OutputsBF16(a_f));
535 EXPECT_FALSE(OutputsBF16(b_f));
536 EXPECT_TRUE(OutputsBF16(add_f));
537 auto new_fusion_root = comp_f->root_instruction();
538 EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple);
539 EXPECT_EQ(new_fusion_root->operand(1), add_f);
540 EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert);
541 EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0)));
542 }
543
544
545 // Tests that BF16 is propagated properly through a while computation with
546 // non-tuple input/output.
TEST_F(BFloat16PropagationTest,PropagateThroughSimpleWhile)547 TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
548 auto module = CreateNewVerifiedModule();
549 auto builder = HloComputation::Builder(TestName());
550 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
551
552 HloInstruction* param0 = builder.AddInstruction(
553 HloInstruction::CreateParameter(0, shape, "param0"));
554 HloInstruction* param1 = builder.AddInstruction(
555 HloInstruction::CreateParameter(1, shape, "param1"));
556 HloInstruction* add = builder.AddInstruction(
557 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
558
559 auto builder_cond = HloComputation::Builder("cond");
560 auto cond_param = builder_cond.AddInstruction(
561 HloInstruction::CreateParameter(0, shape, "cond_param"));
562 auto cond_dot =
563 builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
564 auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare(
565 ShapeUtil::MakeShape(PRED, {}),
566 builder_cond.AddInstruction(HloInstruction::CreateReshape(
567 ShapeUtil::MakeShape(F32, {}),
568 builder_cond.AddInstruction(
569 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
570 cond_dot, {0, 0}, {1, 1}, {1, 1})))),
571 builder_cond.AddInstruction(HloInstruction::CreateReshape(
572 ShapeUtil::MakeShape(F32, {}),
573 builder_cond.AddInstruction(
574 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
575 cond_dot, {1, 1}, {2, 2}, {1, 1})))),
576 ComparisonDirection::kGt));
577 auto cond = module->AddEmbeddedComputation(builder_cond.Build());
578
579 auto builder_body = HloComputation::Builder("body");
580 auto body_param = builder_body.AddInstruction(
581 HloInstruction::CreateParameter(0, shape, "body_param"));
582 auto body_dot =
583 builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
584 auto body = module->AddEmbeddedComputation(builder_body.Build());
585
586 auto while_hlo = builder.AddInstruction(
587 HloInstruction::CreateWhile(shape, cond, body, add));
588
589 auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
590 auto computation = module->AddEntryComputation(builder.Build());
591
592 EXPECT_TRUE(PropagatePrecision(module.get()));
593
594 EXPECT_EQ(computation->root_instruction(), dot);
595 EXPECT_TRUE(
596 ShapeUtil::Equal(cond_root->shape(), ShapeUtil::MakeShape(PRED, {})));
597 EXPECT_TRUE(OutputsBF16(add));
598 EXPECT_TRUE(OutputsBF16(body_dot));
599 EXPECT_TRUE(OutputsBF16(body_param));
600 EXPECT_TRUE(OutputsBF16(cond_param));
601 EXPECT_FALSE(OutputsBF16(dot));
602 }
603
604 // Tests that if the while condition prevents using BF16, no changes should be
605 // made to the while body and thus the fusion node inside it.
TEST_F(BFloat16PropagationTest,ConditionPreventsPropagationForFusionInsideWhile)606 TEST_F(BFloat16PropagationTest,
607 ConditionPreventsPropagationForFusionInsideWhile) {
608 auto module = CreateNewVerifiedModule();
609 auto builder = HloComputation::Builder(TestName());
610 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
611
612 HloInstruction* param0 = builder.AddInstruction(
613 HloInstruction::CreateParameter(0, shape, "param0"));
614 HloInstruction* param1 = builder.AddInstruction(
615 HloInstruction::CreateParameter(1, shape, "param1"));
616 HloInstruction* add = builder.AddInstruction(
617 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
618
619 auto builder_cond = HloComputation::Builder("cond");
620 auto cond_param = builder_cond.AddInstruction(
621 HloInstruction::CreateParameter(0, shape, "cond_param"));
622 builder_cond.AddInstruction(HloInstruction::CreateCompare(
623 ShapeUtil::MakeShape(PRED, {}),
624 builder_cond.AddInstruction(HloInstruction::CreateReshape(
625 ShapeUtil::MakeShape(F32, {}),
626 builder_cond.AddInstruction(HloInstruction::CreateSlice(
627 ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
628 {1, 1})))),
629 builder_cond.AddInstruction(HloInstruction::CreateReshape(
630 ShapeUtil::MakeShape(F32, {}),
631 builder_cond.AddInstruction(HloInstruction::CreateSlice(
632 ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
633 {1, 1})))),
634 ComparisonDirection::kGt));
635 auto cond = module->AddEmbeddedComputation(builder_cond.Build());
636
637 auto builder_body = HloComputation::Builder("body");
638 auto body_param = builder_body.AddInstruction(
639 HloInstruction::CreateParameter(0, shape, "body_param"));
640 auto body_transpose = builder_body.AddInstruction(
641 HloInstruction::CreateTranspose(shape, body_param, {0, 1}));
642
643 auto builder_f = HloComputation::Builder("fusion");
644 HloInstruction* a_f =
645 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
646 builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1}));
647 auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
648 auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion(
649 shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f));
650 auto body = module->AddEmbeddedComputation(builder_body.Build());
651
652 auto while_hlo = builder.AddInstruction(
653 HloInstruction::CreateWhile(shape, cond, body, add));
654
655 auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
656 auto computation = module->AddEntryComputation(builder.Build());
657
658 EXPECT_FALSE(PropagatePrecision(module.get()));
659 EXPECT_EQ(computation->root_instruction(), dot);
660 EXPECT_FALSE(OutputsBF16(add));
661 EXPECT_FALSE(OutputsBF16(body_fusion));
662 EXPECT_FALSE(OutputsBF16(body_param));
663 EXPECT_FALSE(OutputsBF16(body_transpose));
664 EXPECT_FALSE(OutputsBF16(a_f));
665 }
666
667 // Tests that BF16 is propagated properly through while computations with
668 // tuple-shaped input/output.
TEST_F(BFloat16PropagationTest,PropagateThroughTupleWhile)669 TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
670 auto module = CreateNewVerifiedModule();
671 auto builder = HloComputation::Builder(TestName());
672 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
673
674 HloInstruction* param0 = builder.AddInstruction(
675 HloInstruction::CreateParameter(0, shape, "param0"));
676 HloInstruction* param1 = builder.AddInstruction(
677 HloInstruction::CreateParameter(1, shape, "param1"));
678 HloInstruction* add0 = builder.AddInstruction(
679 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
680 HloInstruction* add1 = builder.AddInstruction(
681 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
682 HloInstruction* tuple =
683 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
684
685 auto builder_cond = HloComputation::Builder("cond");
686 auto cond_param = builder_cond.AddInstruction(
687 HloInstruction::CreateParameter(0, tuple->shape(), "cond_param"));
688 auto cond_lhs = builder_cond.AddInstruction(
689 HloInstruction::CreateGetTupleElement(shape, cond_param, 0));
690 auto cond_rhs = builder_cond.AddInstruction(
691 HloInstruction::CreateGetTupleElement(shape, cond_param, 1));
692 // This add should prevent RHS from using BF16
693 auto cond_add_rhs = builder_cond.AddInstruction(
694 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
695 auto cond_dot =
696 builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
697 builder_cond.AddInstruction(HloInstruction::CreateCompare(
698 ShapeUtil::MakeShape(PRED, {}),
699 builder_cond.AddInstruction(HloInstruction::CreateReshape(
700 ShapeUtil::MakeShape(F32, {}),
701 builder_cond.AddInstruction(
702 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
703 cond_dot, {0, 0}, {1, 1}, {1, 1})))),
704 builder_cond.AddInstruction(HloInstruction::CreateReshape(
705 ShapeUtil::MakeShape(F32, {}),
706 builder_cond.AddInstruction(
707 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
708 cond_dot, {1, 1}, {2, 2}, {1, 1})))),
709 ComparisonDirection::kGt));
710 auto cond = module->AddEmbeddedComputation(builder_cond.Build());
711
712 auto builder_body = HloComputation::Builder("body");
713 auto body_param = builder_body.AddInstruction(
714 HloInstruction::CreateParameter(0, tuple->shape(), "body_param"));
715 auto body_lhs = builder_body.AddInstruction(
716 HloInstruction::CreateGetTupleElement(shape, body_param, 0));
717 auto body_rhs = builder_body.AddInstruction(
718 HloInstruction::CreateGetTupleElement(shape, body_param, 1));
719 auto body_dot1 =
720 builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
721 auto body_dot2 =
722 builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs));
723 auto body_transpose = builder_body.AddInstruction(
724 HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
725 builder_body.AddInstruction(
726 HloInstruction::CreateTuple({body_dot1, body_transpose}));
727 auto body = module->AddEmbeddedComputation(builder_body.Build());
728
729 auto while_hlo = builder.AddInstruction(
730 HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple));
731
732 auto lhs = builder.AddInstruction(
733 HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
734 auto rhs = builder.AddInstruction(
735 HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
736 auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
737 auto computation = module->AddEntryComputation(builder.Build());
738
739 EXPECT_TRUE(PropagatePrecision(module.get()));
740
741 EXPECT_EQ(computation->root_instruction(), dot);
742 EXPECT_TRUE(OutputsBF16(lhs));
743 EXPECT_FALSE(OutputsBF16(rhs));
744 EXPECT_TRUE(OutputsBF16(body_dot1));
745 EXPECT_TRUE(OutputsBF16(body_lhs));
746 EXPECT_FALSE(OutputsBF16(body_rhs));
747 EXPECT_FALSE(OutputsBF16(body_dot2));
748 EXPECT_FALSE(OutputsBF16(body_transpose));
749 EXPECT_TRUE(OutputsBF16(cond_lhs));
750 EXPECT_FALSE(OutputsBF16(cond_rhs));
751 EXPECT_TRUE(OutputsBF16(add0));
752 EXPECT_FALSE(OutputsBF16(add1));
753 }
754
755 // Tests that BF16 is not propagated through multiple whiles that invoke the
756 // same computation as long as one while prevents the propagation.
TEST_F(BFloat16PropagationTest,DoNotPropagateWhilesCallingSameComputation)757 TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
758 auto module = CreateNewVerifiedModule();
759 auto builder = HloComputation::Builder(TestName());
760 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
761
762 HloInstruction* param0 = builder.AddInstruction(
763 HloInstruction::CreateParameter(0, shape, "param0"));
764 HloInstruction* param1 = builder.AddInstruction(
765 HloInstruction::CreateParameter(1, shape, "param1"));
766 HloInstruction* add0 = builder.AddInstruction(
767 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
768 HloInstruction* add1 = builder.AddInstruction(
769 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
770 HloInstruction* add2 = builder.AddInstruction(
771 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
772 HloInstruction* add3 = builder.AddInstruction(
773 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
774 HloInstruction* tuple0 =
775 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
776 HloInstruction* tuple1 =
777 builder.AddInstruction(HloInstruction::CreateTuple({add2, add3}));
778
779 // Condition computation for the first while.
780 auto builder_cond0 = HloComputation::Builder("cond0");
781 auto cond0_param = builder_cond0.AddInstruction(
782 HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param"));
783 auto cond0_lhs = builder_cond0.AddInstruction(
784 HloInstruction::CreateGetTupleElement(shape, cond0_param, 0));
785 auto cond0_rhs = builder_cond0.AddInstruction(
786 HloInstruction::CreateGetTupleElement(shape, cond0_param, 1));
787 // This add should prevent RHS from using BF16
788 auto cond0_add_rhs =
789 builder_cond0.AddInstruction(HloInstruction::CreateBinary(
790 shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
791 auto cond0_dot =
792 builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
793 builder_cond0.AddInstruction(HloInstruction::CreateCompare(
794 ShapeUtil::MakeShape(PRED, {}),
795 builder_cond0.AddInstruction(HloInstruction::CreateReshape(
796 ShapeUtil::MakeShape(F32, {}),
797 builder_cond0.AddInstruction(
798 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
799 cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
800 builder_cond0.AddInstruction(HloInstruction::CreateReshape(
801 ShapeUtil::MakeShape(F32, {}),
802 builder_cond0.AddInstruction(
803 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
804 cond0_dot, {1, 1}, {2, 2}, {1, 1})))),
805 ComparisonDirection::kGt));
806 auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
807
808 // Condition computation for the second while.
809 auto builder_cond1 = HloComputation::Builder("cond1");
810 auto cond1_param = builder_cond1.AddInstruction(
811 HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param"));
812 auto cond1_lhs = builder_cond1.AddInstruction(
813 HloInstruction::CreateGetTupleElement(shape, cond1_param, 0));
814 auto cond1_rhs = builder_cond1.AddInstruction(
815 HloInstruction::CreateGetTupleElement(shape, cond1_param, 1));
816 // This add should prevent LHS from using BF16
817 auto cond1_add_lhs =
818 builder_cond1.AddInstruction(HloInstruction::CreateBinary(
819 shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
820 auto cond1_dot =
821 builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
822 builder_cond1.AddInstruction(HloInstruction::CreateCompare(
823 ShapeUtil::MakeShape(PRED, {}),
824 builder_cond1.AddInstruction(HloInstruction::CreateReshape(
825 ShapeUtil::MakeShape(F32, {}),
826 builder_cond1.AddInstruction(
827 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
828 cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
829 builder_cond1.AddInstruction(HloInstruction::CreateReshape(
830 ShapeUtil::MakeShape(F32, {}),
831 builder_cond1.AddInstruction(
832 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
833 cond1_dot, {1, 1}, {2, 2}, {1, 1})))),
834 ComparisonDirection::kGt));
835 auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
836
837 // Body computation shared by both whiles.
838 auto builder_body = HloComputation::Builder("body");
839 auto body_param = builder_body.AddInstruction(
840 HloInstruction::CreateParameter(0, tuple0->shape(), "body_param"));
841 auto body_lhs = builder_body.AddInstruction(
842 HloInstruction::CreateGetTupleElement(shape, body_param, 0));
843 auto body_rhs = builder_body.AddInstruction(
844 HloInstruction::CreateGetTupleElement(shape, body_param, 1));
845 auto body_dot =
846 builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
847 builder_body.AddInstruction(
848 HloInstruction::CreateTuple({body_dot, body_rhs}));
849 auto body = module->AddEmbeddedComputation(builder_body.Build());
850
851 auto while0 = builder.AddInstruction(
852 HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0));
853 auto while1 = builder.AddInstruction(
854 HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
855
856 auto lhs = builder.AddInstruction(
857 CreateDot(shape,
858 builder.AddInstruction(
859 HloInstruction::CreateGetTupleElement(shape, while0, 0)),
860 builder.AddInstruction(
861 HloInstruction::CreateGetTupleElement(shape, while0, 1))));
862 auto rhs = builder.AddInstruction(
863 CreateDot(shape,
864 builder.AddInstruction(
865 HloInstruction::CreateGetTupleElement(shape, while1, 0)),
866 builder.AddInstruction(
867 HloInstruction::CreateGetTupleElement(shape, while1, 1))));
868 auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
869 auto computation = module->AddEntryComputation(builder.Build());
870
871 EXPECT_TRUE(PropagatePrecision(module.get()));
872 EXPECT_FALSE(OutputsBF16(body_dot));
873 EXPECT_FALSE(OutputsBF16(body_rhs));
874 EXPECT_FALSE(OutputsBF16(body_lhs));
875 EXPECT_FALSE(OutputsBF16(cond0_lhs));
876 EXPECT_FALSE(OutputsBF16(cond0_rhs));
877 EXPECT_FALSE(OutputsBF16(cond1_lhs));
878 EXPECT_FALSE(OutputsBF16(cond1_rhs));
879 EXPECT_TRUE(OutputsBF16(cond0_add_rhs));
880 EXPECT_TRUE(OutputsBF16(cond1_add_lhs));
881 EXPECT_EQ(computation->root_instruction(), dot);
882 }
883
884 // Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 ->
885 // BF16 conversion), then it will remove that conversion.
TEST_F(BFloat16PropagationTest,NoopConversionRemoved)886 TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
887 auto builder = HloComputation::Builder(TestName());
888 Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});
889 Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4});
890
891 HloInstruction* param = builder.AddInstruction(
892 HloInstruction::CreateParameter(0, f32_shape, "param"));
893 HloInstruction* add0 = builder.AddInstruction(
894 HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param));
895 HloInstruction* add1 = builder.AddInstruction(
896 HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param));
897 HloInstruction* tuple =
898 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
899 HloInstruction* gte0 = builder.AddInstruction(
900 HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
901 HloInstruction* gte1 = builder.AddInstruction(
902 HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1));
903 HloInstruction* convert0 =
904 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte0));
905 HloInstruction* convert1 =
906 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte1));
907 HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary(
908 bf16_shape, HloOpcode::kAdd, convert0, convert1));
909
910 auto module = CreateNewVerifiedModule();
911 auto computation = module->AddEntryComputation(builder.Build());
912
913 EXPECT_TRUE(PropagatePrecision(module.get()));
914
915 EXPECT_EQ(computation->root_instruction(), add2);
916 EXPECT_EQ(add2->operand(0), add0);
917 EXPECT_EQ(add2->operand(1), add1);
918 EXPECT_EQ(add0->shape().element_type(), BF16);
919 EXPECT_EQ(add1->shape().element_type(), BF16);
920 }
921
// Checks that BF16 propagates through a domain instruction that wraps a tuple
// of transposes: both tuple elements are expected to become BF16, while the
// entry parameters are not.
TEST_F(BFloat16PropagationTest, TupleDomain) {
  HloComputation::Builder entry_builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {4, 4});

  HloInstruction* a = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "a"));
  HloInstruction* b = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "b"));
  HloInstruction* a_trans = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, a, {0, 1}));
  HloInstruction* b_trans = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, b, {0, 1}));
  HloInstruction* packed = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({a_trans, b_trans}));
  HloInstruction* domain =
      entry_builder.AddInstruction(HloInstruction::CreateDomain(
          packed->shape(), packed, nullptr, nullptr));
  HloInstruction* a_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain, 0));
  HloInstruction* b_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain, 1));
  HloInstruction* dot =
      entry_builder.AddInstruction(CreateDot(matrix_shape, a_gte, b_gte));
  HloInstruction* root = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(matrix_shape, HloOpcode::kAdd, dot, dot));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));
  EXPECT_EQ(entry->root_instruction(), root);

  // Both elements of the domain's tuple shape must have switched to BF16.
  EXPECT_EQ(
      ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(),
      BF16);
  EXPECT_EQ(
      ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(),
      BF16);

  EXPECT_TRUE(OutputsBF16(a_trans));
  EXPECT_TRUE(OutputsBF16(b_trans));
  EXPECT_TRUE(OutputsBF16(a_gte));
  EXPECT_TRUE(OutputsBF16(b_gte));
  EXPECT_FALSE(OutputsBF16(a));
  EXPECT_FALSE(OutputsBF16(b));
}
965
966 // Tests that bf16 is not propagated through a domain in case its input cannot
967 // be propagated. In the case below the input of the domain is the parameter
968 // tuple which cannot be propagated, so the domain instruction is not propagated
969 // either.
TEST_F(BFloat16PropagationTest,TupleDomainNoPropagation)970 TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
971 auto builder = HloComputation::Builder(TestName());
972 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
973 Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
974
975 HloInstruction* param = builder.AddInstruction(
976 HloInstruction::CreateParameter(0, tuple_shape, "param"));
977 HloInstruction* domain = builder.AddInstruction(
978 HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr));
979 HloInstruction* a_gte = builder.AddInstruction(
980 HloInstruction::CreateGetTupleElement(shape, domain, 0));
981 HloInstruction* b_gte = builder.AddInstruction(
982 HloInstruction::CreateGetTupleElement(shape, domain, 1));
983 HloInstruction* a_trans = builder.AddInstruction(
984 HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
985 HloInstruction* b_trans = builder.AddInstruction(
986 HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
987 HloInstruction* dot =
988 builder.AddInstruction(CreateDot(shape, a_trans, b_trans));
989 HloInstruction* root = builder.AddInstruction(
990 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
991
992 auto module = CreateNewVerifiedModule();
993 auto computation = module->AddEntryComputation(builder.Build());
994
995 EXPECT_TRUE(PropagatePrecision(module.get()));
996
997 EXPECT_EQ(computation->root_instruction(), root);
998 EXPECT_TRUE(OutputsBF16(a_trans));
999 EXPECT_TRUE(OutputsBF16(b_trans));
1000 EXPECT_FALSE(OutputsBF16(a_gte));
1001 EXPECT_FALSE(OutputsBF16(b_gte));
1002 EXPECT_FALSE(OutputsBF16(domain));
1003 EXPECT_FALSE(OutputsBF16(param));
1004 }
1005
// Conditional whose two branches receive different operands. The conditional
// and the operand of the maximum branch (copy0) are expected to become BF16,
// whereas the operand of the add branch (copy1) is not.
TEST_F(BFloat16PropagationTest, ConditionalSeparateBranchOperands) {
  const std::string hlo_text = R"(
HloModule module

true_branch {
  true_param = f32[4096,4096] parameter(0)
  ROOT max = f32[4096,4096] maximum(true_param, true_param)
}

false_branch {
  false_param = f32[4096,4096] parameter(0)
  ROOT add = f32[4096,4096] add(false_param, false_param)
}

ENTRY entry {
  param0 = f32[4096,4096] parameter(0)
  param1 = f32[4096,4096] parameter(1)
  copy0 = f32[4096,4096] copy(param0)
  copy1 = f32[4096,4096] copy(param1)
  param2 = pred[] parameter(2)
  conditional = f32[4096,4096] conditional(param2, copy0, copy1),
    true_computation=true_branch, false_computation=false_branch
  ROOT dot = f32[4096,4096] dot(conditional, conditional),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  EXPECT_TRUE(PropagatePrecision(module.get()));

  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
  HloInstruction* true_operand = FindInstruction(module.get(), "copy0");
  HloInstruction* false_operand = FindInstruction(module.get(), "copy1");
  EXPECT_TRUE(OutputsBF16(conditional));
  EXPECT_TRUE(OutputsBF16(true_operand));
  EXPECT_FALSE(OutputsBF16(false_operand));
}
1044
// Conditional whose two branches receive the same operand (copy0). The
// conditional itself is expected to become BF16, but the shared operand is
// not.
TEST_F(BFloat16PropagationTest, ConditionalSharedBranchOperands) {
  const std::string hlo_text = R"(
HloModule module

true_branch {
  true_param = f32[4096,4096] parameter(0)
  ROOT max = f32[4096,4096] maximum(true_param, true_param)
}

false_branch {
  false_param = f32[4096,4096] parameter(0)
  ROOT add = f32[4096,4096] add(false_param, false_param)
}

ENTRY entry {
  param0 = f32[4096,4096] parameter(0)
  copy0 = f32[4096,4096] copy(param0)
  param1 = pred[] parameter(1)
  conditional = f32[4096,4096] conditional(param1, copy0, copy0),
    true_computation=true_branch, false_computation=false_branch
  ROOT dot = f32[4096,4096] dot(conditional, conditional),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  EXPECT_TRUE(PropagatePrecision(module.get()));

  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
  HloInstruction* shared_operand = FindInstruction(module.get(), "copy0");
  EXPECT_TRUE(OutputsBF16(conditional));
  EXPECT_FALSE(OutputsBF16(shared_operand));
}
1079
// Conditional with tuple-shaped output where the true branch returns the same
// instruction in both tuple slots. One of the two results is also returned
// from the entry computation; the pass is expected to change nothing.
TEST_F(BFloat16PropagationTest, ConditionalAliasingOutputs) {
  const std::string hlo_text = R"(
HloModule module

true_branch {
  true_param = f32[4096,4096] parameter(0)
  max = f32[4096,4096] maximum(true_param, true_param)
  ROOT true_tuple = (f32[4096,4096], f32[4096,4096]) tuple(max, max)
}

false_branch {
  false_param = f32[4096,4096] parameter(0)
  min = f32[4096,4096] minimum(false_param, false_param)
  max2 = f32[4096,4096] maximum(false_param, false_param)
  ROOT false_tuple = (f32[4096,4096], f32[4096,4096]) tuple(min, max2)
}

ENTRY entry {
  param0 = f32[4096,4096] parameter(0)
  copy0 = f32[4096,4096] copy(param0)
  param1 = pred[] parameter(1)
  conditional = (f32[4096,4096], f32[4096,4096]) conditional(param1, copy0, copy0),
    true_computation=true_branch, false_computation=false_branch
  gte0 = f32[4096,4096] get-tuple-element(conditional), index=0
  gte1 = f32[4096,4096] get-tuple-element(conditional), index=1
  dot = f32[4096,4096] dot(gte0, gte1),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT tuple = (f32[4096,4096], f32[4096,4096]) tuple(dot, gte1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // The pass must report that it left the module unchanged.
  EXPECT_FALSE(PropagatePrecision(module.get()));
}
1115
TEST_F(BFloat16PropagationTest, DynamicUpdateSlice) {
  // The module is built so that the dynamic-update-slice would get an f32
  // input (because of the parameter) and a bf16 output (because of the dot).
  // Since the op works in place, operand 0 and the output must end up with
  // the same precision, so the pass is expected to leave the module alone.
  const std::string hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[128,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  dynamic-update-slice = f32[128,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3)
  ROOT dot = f32[128,128] dot(dynamic-update-slice, dynamic-update-slice), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  EXPECT_FALSE(PropagatePrecision(module.get()));

  HloComputation* entry = module->entry_computation();
  HloInstruction* update = entry->GetInstructionWithName("dynamic-update-slice");
  EXPECT_FALSE(OutputsBF16(update));
}
1141
1142 // This test demonstrates the need for invoking the ResolveAliasingBuffer
1143 // multiple times via a fixed-point algorithm. The key was the aliasing of the
1144 // two output buffers of the conditional, at subshape 0 (first element). This
1145 // aliasing is not resolved until after the gte0 variale is already processed,
1146 // triggering incorrect type for gte0 if not repeating the aliasing analysis.
TEST_F(BFloat16PropagationTest,ConditionalGTEWithFusion)1147 TEST_F(BFloat16PropagationTest, ConditionalGTEWithFusion) {
1148 const std::string module_str = R"(
1149 HloModule module
1150
1151 %add.0 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1152 x.1 = f32[4096,4096] parameter(0)
1153 y.1 = f32[4096,4096] parameter(1)
1154 ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1155 lhs_contracting_dims={1}, rhs_contracting_dims={0}
1156 }
1157
1158 %add.1 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1159 x.1 = f32[4096,4096] parameter(0)
1160 y.1 = f32[4096,4096] parameter(1)
1161 ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1162 lhs_contracting_dims={1}, rhs_contracting_dims={0}
1163 }
1164
1165 %add.2 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1166 x.1 = f32[4096,4096] parameter(0)
1167 y.1 = f32[4096,4096] parameter(1)
1168 ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1169 lhs_contracting_dims={1}, rhs_contracting_dims={0}
1170 }
1171
1172 %add.3 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1173 x.1 = f32[4096,4096] parameter(0)
1174 y.1 = f32[4096,4096] parameter(1)
1175 ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1176 lhs_contracting_dims={1}, rhs_contracting_dims={0}
1177 }
1178
1179 true_branch {
1180 true_param = f32[4096,4096] parameter(0)
1181 constant.1 = f32[4096,4096] constant(0)
1182 add0 = f32[4096,4096] fusion(true_param,true_param), kind=kLoop, calls=add.0
1183 constant.2 = f32[4096,4096] constant(0)
1184 ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) tuple(true_param,add0,constant.2)
1185 }
1186
1187 false_branch {
1188 false_param = f32[4096,4096] parameter(0)
1189 add3 = f32[4096,4096] fusion(false_param,false_param), kind=kLoop, calls=add.1
1190 constant.1 = f32[4096,4096] constant(0)
1191 ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) tuple(add3, add3,constant.1)
1192 }
1193
1194 ENTRY entry {
1195 param0 = f32[4096,4096] parameter(0)
1196 copy0 = f32[4096,4096] copy(param0)
1197 param1 = pred[] parameter(1)
1198 conditional = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) conditional(param1, param0, copy0),
1199 true_computation=true_branch, false_computation=false_branch
1200 gte = f32[4096,4096] get-tuple-element(conditional), index=0
1201 gte1 = f32[4096,4096] get-tuple-element(conditional), index=1
1202 gte2 = f32[4096,4096] get-tuple-element(conditional), index=2
1203 add2 = f32[4096,4096] fusion(gte, gte1), kind=kLoop, calls=add.2
1204 ROOT add3 = f32[4096,4096] fusion(add2, gte2), kind=kLoop, calls=add.3
1205 }
1206 )";
1207
1208 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
1209 ParseAndReturnVerifiedModule(module_str));
1210 EXPECT_TRUE(PropagatePrecision(module.get()));
1211 VLOG(2) << module->ToString() << "\n";
1212 EXPECT_TRUE(HloVerifier(/*layout_sensitive=*/false,
1213 /*allow_mixed_precision=*/true)
1214 .Run(module.get())
1215 .status()
1216 .ok());
1217 auto gte = FindInstruction(module.get(), "gte");
1218 auto gte1 = FindInstruction(module.get(), "gte1");
1219 auto gte2 = FindInstruction(module.get(), "gte2");
1220 EXPECT_FALSE(OutputsBF16(gte));
1221 EXPECT_FALSE(OutputsBF16(gte1));
1222 EXPECT_TRUE(OutputsBF16(gte2));
1223 }
1224
1225 } // namespace xla
1226