xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17 
18 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 
30 namespace xla {
31 
32 // A class specifying the BF16 support used to test the propagation pass. It
33 // specifies that BF16 and mixed precision are supported in all HloInstructions,
34 // and that kDot reduces its operands precision to BF16.
35 class TestBFloat16Support : public BFloat16Support {
36  public:
TestBFloat16Support()37   TestBFloat16Support() {}
~TestBFloat16Support()38   ~TestBFloat16Support() override {}
39 
SupportsBF16Operand(const HloInstruction & hlo,int64_t operand_index) const40   bool SupportsBF16Operand(const HloInstruction& hlo,
41                            int64_t operand_index) const override {
42     return true;
43   }
44 
SupportsBF16Output(const HloInstruction & hlo) const45   bool SupportsBF16Output(const HloInstruction& hlo) const override {
46     return true;
47   }
48 
SupportsMixedPrecisions(const HloInstruction & hlo) const49   bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
50     return true;
51   }
52 
EffectiveOperandPrecisionIsBF16(const HloInstruction & hlo,int64_t operand_index) const53   bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
54                                        int64_t operand_index) const override {
55     return hlo.opcode() == HloOpcode::kDot;
56   }
57 };
58 
59 class BFloat16PropagationTest : public HloTestBase {
60  protected:
BFloat16PropagationTest()61   BFloat16PropagationTest()
62       : HloTestBase(/*verifier_layout_sensitive=*/false,
63                     /*allow_mixed_precision_in_hlo_verifier=*/true) {}
64 
65   // Runs the propagation pass on the given module, and returns whether the
66   // module is changed after this pass.
PropagatePrecision(HloModule * module)67   bool PropagatePrecision(HloModule* module) {
68     TestBFloat16Support bfloat16_support;
69     BFloat16Propagation propagation(&bfloat16_support);
70     StatusOr<bool> result = propagation.Run(module);
71     EXPECT_IS_OK(result.status());
72     return result.ValueOrDie();
73   }
74 
75   // Returns whether the given HloInstruction's output element type is BF16 or
76   // the only use of it is converting to BF16.
OutputsBF16(const HloInstruction * inst)77   bool OutputsBF16(const HloInstruction* inst) {
78     if (inst->shape().element_type() == BF16) {
79       return true;
80     }
81     return inst->user_count() == 1 &&
82            inst->users()[0]->opcode() == HloOpcode::kConvert &&
83            inst->users()[0]->shape().element_type() == BF16;
84   }
85 
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)86   std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
87                                             HloInstruction* lhs,
88                                             HloInstruction* rhs) {
89     DotDimensionNumbers dot_dnums;
90     dot_dnums.add_lhs_contracting_dimensions(1);
91     dot_dnums.add_rhs_contracting_dimensions(0);
92     return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
93                                      DefaultPrecisionConfig(2));
94   }
95 };
96 
97 // Tests that BF16 can propagate through select over non-tuple buffers, but not
98 // through add where reducing operand precision can affect the result.
TEST_F(BFloat16PropagationTest,PropagateThroughSelectButNotAdd)99 TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
100   auto builder = HloComputation::Builder(TestName());
101   Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
102 
103   HloInstruction* a =
104       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
105   HloInstruction* b =
106       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
107   HloInstruction* c =
108       builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c"));
109   HloInstruction* add0 = builder.AddInstruction(
110       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
111   HloInstruction* add1 = builder.AddInstruction(
112       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
113   HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare(
114       ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq));
115   HloInstruction* sel = builder.AddInstruction(
116       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
117   HloInstruction* xpose =
118       builder.AddInstruction(HloInstruction::CreateTranspose(
119           ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
120   HloInstruction* dot = builder.AddInstruction(
121       CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
122   HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
123       ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
124 
125   auto module = CreateNewVerifiedModule();
126   auto computation = module->AddEntryComputation(builder.Build());
127 
128   EXPECT_TRUE(PropagatePrecision(module.get()));
129 
130   EXPECT_EQ(computation->root_instruction(), root);
131   EXPECT_TRUE(OutputsBF16(xpose));
132   EXPECT_TRUE(OutputsBF16(sel));
133   EXPECT_TRUE(OutputsBF16(add1));
134   EXPECT_FALSE(OutputsBF16(add0));
135   EXPECT_FALSE(OutputsBF16(a));
136   EXPECT_FALSE(OutputsBF16(b));
137   EXPECT_FALSE(OutputsBF16(c));
138 }
139 
// Checks that BF16 propagates through a reduce-window whose reduction
// computation is a scalar maximum: the add feeding the reduce-window, the
// reduce-window itself and the transpose on the other dot operand are all
// expected to output BF16.
TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) {
  auto module = CreateNewVerifiedModule();

  // Scalar max(a, b) used as the reduce-window's reduction function.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder max_builder("max");
  HloInstruction* max_lhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "a"));
  HloInstruction* max_rhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "b"));
  max_builder.AddInstruction(HloInstruction::CreateBinary(
      f32_scalar, HloOpcode::kMaximum, max_lhs, max_rhs));
  HloComputation* max_computation =
      module->AddEmbeddedComputation(max_builder.Build());

  HloComputation::Builder builder(TestName());
  const Shape f32_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape f32_4x4 = ShapeUtil::MakeShape(F32, {4, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_2x4, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_2x4, "b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_2x4, "c"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x4, HloOpcode::kAdd, param_a, param_b));

  // Both window dimensions share the same settings.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  for (int i = 0; i < 2; ++i) {
    *window.add_dimensions() = window_dim;
  }

  HloInstruction* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          f32_2x4, sum, init_value, window, max_computation));
  HloInstruction* transposed =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {4, 2}), param_c, {1, 0}));
  HloInstruction* dot =
      builder.AddInstruction(CreateDot(f32_4x4, transposed, reduce_window));
  HloInstruction* root = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_4x4, HloOpcode::kAdd, dot, dot));

  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));

  EXPECT_EQ(entry->root_instruction(), root);
  EXPECT_TRUE(OutputsBF16(sum));
  EXPECT_TRUE(OutputsBF16(transposed));
  EXPECT_TRUE(OutputsBF16(reduce_window));
}
195 
196 // Tests that side-effecting all-reduce should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeAllReduce)197 TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
198   auto module = CreateNewVerifiedModule();
199 
200   auto builder = HloComputation::Builder(TestName());
201   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
202   HloInstruction* a =
203       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
204   HloInstruction* b =
205       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
206   auto rb = HloComputation::Builder(TestName());
207   rb.AddInstruction(HloInstruction::CreateBinary(
208       shape, HloOpcode::kAdd,
209       rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")),
210       rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"))));
211   auto reduction = module->AddEmbeddedComputation(rb.Build());
212   HloInstruction* all_reduce =
213       builder.AddInstruction(HloInstruction::CreateAllReduce(
214           ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
215           /*replica_groups=*/{}, /*constrain_layout=*/false,
216           /*channel_id=*/1, /*use_global_device_ids=*/false));
217   HloInstruction* gte0 = builder.AddInstruction(
218       HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
219   HloInstruction* gte1 = builder.AddInstruction(
220       HloInstruction::CreateGetTupleElement(shape, all_reduce, 1));
221   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
222   HloInstruction* root = builder.AddInstruction(
223       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
224 
225   auto computation = module->AddEntryComputation(builder.Build());
226 
227   EXPECT_FALSE(PropagatePrecision(module.get()));
228   EXPECT_EQ(computation->root_instruction(), root);
229 }
230 
231 // Tests that if a constant is converted to BF16 then its literal must also be
232 // converted.
TEST_F(BFloat16PropagationTest,ConvertConstantLiteral)233 TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
234   auto builder = HloComputation::Builder(TestName());
235   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
236   Array2D<float> array_a(4, 4);
237   array_a.FillUnique(1.0f);
238   Array2D<float> array_b(4, 4);
239   array_b.FillUnique(10.0f);
240 
241   HloInstruction* a = builder.AddInstruction(
242       HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
243   HloInstruction* b = builder.AddInstruction(
244       HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
245   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
246 
247   auto module = CreateNewVerifiedModule();
248   auto computation = module->AddEntryComputation(builder.Build());
249 
250   EXPECT_TRUE(PropagatePrecision(module.get()));
251 
252   EXPECT_EQ(computation->root_instruction(), dot);
253   EXPECT_TRUE(OutputsBF16(dot->operand(0)));
254   EXPECT_TRUE(OutputsBF16(dot->operand(1)));
255   EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
256   EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
257   EXPECT_TRUE(LiteralTestUtil::Equal(
258       LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
259       dot->operand(0)->literal()));
260   EXPECT_TRUE(LiteralTestUtil::Equal(
261       LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
262       dot->operand(1)->literal()));
263 }
264 
265 // Tests that BF16 can be propagated through nested tuples.
TEST_F(BFloat16PropagationTest,PropagateThroughTuples)266 TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
267   auto builder = HloComputation::Builder(TestName());
268   Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
269 
270   HloInstruction* a =
271       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
272   HloInstruction* b =
273       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
274   HloInstruction* add0 = builder.AddInstruction(
275       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
276   HloInstruction* add1 = builder.AddInstruction(
277       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
278   HloInstruction* add2 = builder.AddInstruction(
279       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b));
280   HloInstruction* xpose =
281       builder.AddInstruction(HloInstruction::CreateTranspose(
282           ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
283 
284   HloInstruction* tuple0 =
285       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2}));
286   HloInstruction* tuple1 =
287       builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose}));
288 
289   HloInstruction* lhs = builder.AddInstruction(
290       HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1));
291   HloInstruction* rhs =
292       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
293           add0->shape(),
294           builder.AddInstruction(HloInstruction::CreateGetTupleElement(
295               tuple0->shape(), tuple1, 0)),
296           0));
297   HloInstruction* dot = builder.AddInstruction(
298       CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
299 
300   HloInstruction* output_tuple =
301       builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
302 
303   auto module = CreateNewVerifiedModule();
304   auto computation = module->AddEntryComputation(builder.Build());
305 
306   EXPECT_TRUE(PropagatePrecision(module.get()));
307 
308   EXPECT_EQ(computation->root_instruction(), output_tuple);
309   EXPECT_TRUE(OutputsBF16(xpose));
310   EXPECT_TRUE(OutputsBF16(add0));
311   EXPECT_TRUE(OutputsBF16(add1));
312   EXPECT_FALSE(OutputsBF16(add2));
313 }
314 
315 // Tests that even if an instruction does not define a buffer in its output, its
316 // shape must match the defining instruction.
TEST_F(BFloat16PropagationTest,SameValueReferencedTwice)317 TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
318   auto builder = HloComputation::Builder(TestName());
319   Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
320 
321   HloInstruction* a =
322       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
323   HloInstruction* b =
324       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
325   HloInstruction* add0 = builder.AddInstruction(
326       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
327   HloInstruction* add1 = builder.AddInstruction(
328       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
329 
330   HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose(
331       ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
332 
333   HloInstruction* tuple =
334       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
335   HloInstruction* rhs = builder.AddInstruction(
336       HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
337 
338   // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
339   HloInstruction* dot = builder.AddInstruction(
340       CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
341 
342   auto module = CreateNewVerifiedModule();
343   auto computation = module->AddEntryComputation(builder.Build());
344 
345   EXPECT_TRUE(PropagatePrecision(module.get()));
346 
347   EXPECT_EQ(computation->root_instruction(), dot);
348   EXPECT_TRUE(OutputsBF16(add1));
349   EXPECT_TRUE(OutputsBF16(lhs));
350 
351   // add0 and rhs have been eliminated by simplification and DCE.
352 }
353 
354 // Tests that a non-fusion computation's root should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeComputationRoot)355 TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
356   auto builder = HloComputation::Builder(TestName());
357   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
358 
359   HloInstruction* a =
360       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
361   HloInstruction* b =
362       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
363   HloInstruction* add = builder.AddInstruction(
364       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
365 
366   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
367 
368   HloInstruction* tuple =
369       builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
370 
371   auto module = CreateNewVerifiedModule();
372   auto computation = module->AddEntryComputation(builder.Build());
373 
374   EXPECT_FALSE(PropagatePrecision(module.get()));
375 
376   EXPECT_EQ(computation->root_instruction(), tuple);
377   EXPECT_FALSE(OutputsBF16(add));
378 }
379 
380 // Tests that BF16 is propagated properly through fused computations.
TEST_F(BFloat16PropagationTest,PropagateThroughFusion)381 TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
382   auto module = CreateNewVerifiedModule();
383   auto builder = HloComputation::Builder(TestName());
384   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
385 
386   HloInstruction* param = builder.AddInstruction(
387       HloInstruction::CreateParameter(0, shape, "param"));
388   HloInstruction* add = builder.AddInstruction(
389       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
390 
391   auto builder_f0 = HloComputation::Builder("fusion0");
392   HloInstruction* a_f0 =
393       builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
394   HloInstruction* b_f0 =
395       builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
396   HloInstruction* tuple_f0 =
397       builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0}));
398   auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build());
399   auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
400       tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add},
401       comp_f0));
402 
403   auto builder_f1 = HloComputation::Builder("fusion1");
404   HloInstruction* p_f1 = builder_f1.AddInstruction(
405       HloInstruction::CreateParameter(0, tuple_f0->shape(), "param"));
406   HloInstruction* a_f1 = builder_f1.AddInstruction(
407       HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
408   HloInstruction* b_f1 = builder_f1.AddInstruction(
409       HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
410   HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
411   auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
412   auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
413       dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
414 
415   auto computation = module->AddEntryComputation(builder.Build());
416 
417   EXPECT_TRUE(PropagatePrecision(module.get()));
418 
419   EXPECT_EQ(computation->root_instruction(), fusion1);
420   EXPECT_TRUE(OutputsBF16(add));
421   EXPECT_TRUE(OutputsBF16(a_f0));
422   EXPECT_TRUE(OutputsBF16(b_f0));
423   EXPECT_TRUE(OutputsBF16(a_f1));
424   EXPECT_TRUE(OutputsBF16(b_f1));
425 }
426 
427 // Tests that a fusion with a bitcast-convert as its root is changed via adding
428 // extra convert, instead of changing the type in-place.
TEST_F(BFloat16PropagationTest,FusionWithBitcastConvertRoot)429 TEST_F(BFloat16PropagationTest, FusionWithBitcastConvertRoot) {
430   auto module = CreateNewVerifiedModule();
431   auto builder = HloComputation::Builder(TestName());
432   Shape u32_shape = ShapeUtil::MakeShape(U32, {4, 4});
433   Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});
434 
435   HloInstruction* param = builder.AddInstruction(
436       HloInstruction::CreateParameter(0, u32_shape, "param"));
437 
438   auto builder_f = HloComputation::Builder("fusion");
439   HloInstruction* a_f = builder_f.AddInstruction(
440       HloInstruction::CreateParameter(0, u32_shape, "a"));
441   HloInstruction* bc_f = builder_f.AddInstruction(
442       HloInstruction::CreateBitcastConvert(f32_shape, a_f));
443   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
444   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
445       f32_shape, HloInstruction::FusionKind::kLoop, {param}, comp_f));
446   auto dot = builder.AddInstruction(CreateDot(f32_shape, fusion, fusion));
447 
448   auto computation = module->AddEntryComputation(builder.Build());
449   EXPECT_TRUE(PropagatePrecision(module.get()));
450 
451   EXPECT_EQ(computation->root_instruction(), dot);
452   EXPECT_EQ(bc_f->shape(), f32_shape);
453   EXPECT_TRUE(OutputsBF16(bc_f));
454 }
455 
456 // Tests that changes to BF16 that cannot be propagated outside a fusion are
457 // discarded.
TEST_F(BFloat16PropagationTest,DiscardFusionInternalBF16Changes)458 TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
459   auto module = CreateNewVerifiedModule();
460   auto builder = HloComputation::Builder(TestName());
461   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
462 
463   HloInstruction* param = builder.AddInstruction(
464       HloInstruction::CreateParameter(0, shape, "param"));
465   HloInstruction* add = builder.AddInstruction(
466       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
467 
468   auto builder_f = HloComputation::Builder("fusion");
469   HloInstruction* a_f =
470       builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
471   HloInstruction* b_f =
472       builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
473   HloInstruction* add_f = builder_f.AddInstruction(
474       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
475   HloInstruction* dot_f =
476       builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
477   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
478   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
479       dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
480 
481   auto computation = module->AddEntryComputation(builder.Build());
482 
483   EXPECT_FALSE(PropagatePrecision(module.get()));
484   EXPECT_EQ(computation->root_instruction(), fusion);
485 }
486 
487 // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
488 // outputs are only used by a dot, and 3) one element of the tuple is used by
489 // an add in the fusion computation, then the propagation pass should create a
490 // convert in the fusion computation to keep the add's operand in F32 but change
491 // the fusion output to BF16. E.g., the following fusion computation
492 //   (F32, F32) fusion_computation(F32 a, F32 b)
493 //     = tuple(F32 a, F32 add(F32 a, F32 b))
494 // will be changed to
495 //   (BF16, BF16) fusion_computation(F32 a, F32 b)
496 //     = tuple(BF16 convert(a), BF16 add(F32 a, F32 b))
TEST_F(BFloat16PropagationTest,ConvertTupleFusionElementIfUsedByAdd)497 TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
498   auto module = CreateNewVerifiedModule();
499   auto builder = HloComputation::Builder(TestName());
500   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
501 
502   HloInstruction* param = builder.AddInstruction(
503       HloInstruction::CreateParameter(0, shape, "param"));
504   HloInstruction* add = builder.AddInstruction(
505       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
506 
507   auto builder_f = HloComputation::Builder("fusion0");
508   HloInstruction* a_f =
509       builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
510   HloInstruction* b_f =
511       builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
512   HloInstruction* add_f = builder_f.AddInstruction(
513       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
514   HloInstruction* tuple_f =
515       builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f}));
516   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
517   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
518       tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add},
519       comp_f));
520 
521   HloInstruction* gte0 = builder.AddInstruction(
522       HloInstruction::CreateGetTupleElement(shape, fusion, 0));
523   HloInstruction* gte1 = builder.AddInstruction(
524       HloInstruction::CreateGetTupleElement(shape, fusion, 1));
525   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
526 
527   auto computation = module->AddEntryComputation(builder.Build());
528 
529   EXPECT_TRUE(PropagatePrecision(module.get()));
530 
531   EXPECT_EQ(computation->root_instruction(), dot);
532   EXPECT_TRUE(OutputsBF16(gte0));
533   EXPECT_TRUE(OutputsBF16(gte1));
534   EXPECT_FALSE(OutputsBF16(a_f));
535   EXPECT_FALSE(OutputsBF16(b_f));
536   EXPECT_TRUE(OutputsBF16(add_f));
537   auto new_fusion_root = comp_f->root_instruction();
538   EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple);
539   EXPECT_EQ(new_fusion_root->operand(1), add_f);
540   EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert);
541   EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0)));
542 }
543 
544 
545 // Tests that BF16 is propagated properly through a while computation with
546 // non-tuple input/output.
TEST_F(BFloat16PropagationTest,PropagateThroughSimpleWhile)547 TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
548   auto module = CreateNewVerifiedModule();
549   auto builder = HloComputation::Builder(TestName());
550   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
551 
552   HloInstruction* param0 = builder.AddInstruction(
553       HloInstruction::CreateParameter(0, shape, "param0"));
554   HloInstruction* param1 = builder.AddInstruction(
555       HloInstruction::CreateParameter(1, shape, "param1"));
556   HloInstruction* add = builder.AddInstruction(
557       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
558 
559   auto builder_cond = HloComputation::Builder("cond");
560   auto cond_param = builder_cond.AddInstruction(
561       HloInstruction::CreateParameter(0, shape, "cond_param"));
562   auto cond_dot =
563       builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
564   auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare(
565       ShapeUtil::MakeShape(PRED, {}),
566       builder_cond.AddInstruction(HloInstruction::CreateReshape(
567           ShapeUtil::MakeShape(F32, {}),
568           builder_cond.AddInstruction(
569               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
570                                           cond_dot, {0, 0}, {1, 1}, {1, 1})))),
571       builder_cond.AddInstruction(HloInstruction::CreateReshape(
572           ShapeUtil::MakeShape(F32, {}),
573           builder_cond.AddInstruction(
574               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
575                                           cond_dot, {1, 1}, {2, 2}, {1, 1})))),
576       ComparisonDirection::kGt));
577   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
578 
579   auto builder_body = HloComputation::Builder("body");
580   auto body_param = builder_body.AddInstruction(
581       HloInstruction::CreateParameter(0, shape, "body_param"));
582   auto body_dot =
583       builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
584   auto body = module->AddEmbeddedComputation(builder_body.Build());
585 
586   auto while_hlo = builder.AddInstruction(
587       HloInstruction::CreateWhile(shape, cond, body, add));
588 
589   auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
590   auto computation = module->AddEntryComputation(builder.Build());
591 
592   EXPECT_TRUE(PropagatePrecision(module.get()));
593 
594   EXPECT_EQ(computation->root_instruction(), dot);
595   EXPECT_TRUE(
596       ShapeUtil::Equal(cond_root->shape(), ShapeUtil::MakeShape(PRED, {})));
597   EXPECT_TRUE(OutputsBF16(add));
598   EXPECT_TRUE(OutputsBF16(body_dot));
599   EXPECT_TRUE(OutputsBF16(body_param));
600   EXPECT_TRUE(OutputsBF16(cond_param));
601   EXPECT_FALSE(OutputsBF16(dot));
602 }
603 
604 // Tests that if the while condition prevents using BF16, no changes should be
605 // made to the while body and thus the fusion node inside it.
TEST_F(BFloat16PropagationTest,ConditionPreventsPropagationForFusionInsideWhile)606 TEST_F(BFloat16PropagationTest,
607        ConditionPreventsPropagationForFusionInsideWhile) {
608   auto module = CreateNewVerifiedModule();
609   auto builder = HloComputation::Builder(TestName());
610   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
611 
612   HloInstruction* param0 = builder.AddInstruction(
613       HloInstruction::CreateParameter(0, shape, "param0"));
614   HloInstruction* param1 = builder.AddInstruction(
615       HloInstruction::CreateParameter(1, shape, "param1"));
616   HloInstruction* add = builder.AddInstruction(
617       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
618 
619   auto builder_cond = HloComputation::Builder("cond");
620   auto cond_param = builder_cond.AddInstruction(
621       HloInstruction::CreateParameter(0, shape, "cond_param"));
622   builder_cond.AddInstruction(HloInstruction::CreateCompare(
623       ShapeUtil::MakeShape(PRED, {}),
624       builder_cond.AddInstruction(HloInstruction::CreateReshape(
625           ShapeUtil::MakeShape(F32, {}),
626           builder_cond.AddInstruction(HloInstruction::CreateSlice(
627               ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
628               {1, 1})))),
629       builder_cond.AddInstruction(HloInstruction::CreateReshape(
630           ShapeUtil::MakeShape(F32, {}),
631           builder_cond.AddInstruction(HloInstruction::CreateSlice(
632               ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
633               {1, 1})))),
634       ComparisonDirection::kGt));
635   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
636 
637   auto builder_body = HloComputation::Builder("body");
638   auto body_param = builder_body.AddInstruction(
639       HloInstruction::CreateParameter(0, shape, "body_param"));
640   auto body_transpose = builder_body.AddInstruction(
641       HloInstruction::CreateTranspose(shape, body_param, {0, 1}));
642 
643   auto builder_f = HloComputation::Builder("fusion");
644   HloInstruction* a_f =
645       builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
646   builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1}));
647   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
648   auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion(
649       shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f));
650   auto body = module->AddEmbeddedComputation(builder_body.Build());
651 
652   auto while_hlo = builder.AddInstruction(
653       HloInstruction::CreateWhile(shape, cond, body, add));
654 
655   auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
656   auto computation = module->AddEntryComputation(builder.Build());
657 
658   EXPECT_FALSE(PropagatePrecision(module.get()));
659   EXPECT_EQ(computation->root_instruction(), dot);
660   EXPECT_FALSE(OutputsBF16(add));
661   EXPECT_FALSE(OutputsBF16(body_fusion));
662   EXPECT_FALSE(OutputsBF16(body_param));
663   EXPECT_FALSE(OutputsBF16(body_transpose));
664   EXPECT_FALSE(OutputsBF16(a_f));
665 }
666 
667 // Tests that BF16 is propagated properly through while computations with
668 // tuple-shaped input/output.
TEST_F(BFloat16PropagationTest,PropagateThroughTupleWhile)669 TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
670   auto module = CreateNewVerifiedModule();
671   auto builder = HloComputation::Builder(TestName());
672   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
673 
674   HloInstruction* param0 = builder.AddInstruction(
675       HloInstruction::CreateParameter(0, shape, "param0"));
676   HloInstruction* param1 = builder.AddInstruction(
677       HloInstruction::CreateParameter(1, shape, "param1"));
678   HloInstruction* add0 = builder.AddInstruction(
679       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
680   HloInstruction* add1 = builder.AddInstruction(
681       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
682   HloInstruction* tuple =
683       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
684 
685   auto builder_cond = HloComputation::Builder("cond");
686   auto cond_param = builder_cond.AddInstruction(
687       HloInstruction::CreateParameter(0, tuple->shape(), "cond_param"));
688   auto cond_lhs = builder_cond.AddInstruction(
689       HloInstruction::CreateGetTupleElement(shape, cond_param, 0));
690   auto cond_rhs = builder_cond.AddInstruction(
691       HloInstruction::CreateGetTupleElement(shape, cond_param, 1));
692   // This add should prevent RHS from using BF16
693   auto cond_add_rhs = builder_cond.AddInstruction(
694       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
695   auto cond_dot =
696       builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
697   builder_cond.AddInstruction(HloInstruction::CreateCompare(
698       ShapeUtil::MakeShape(PRED, {}),
699       builder_cond.AddInstruction(HloInstruction::CreateReshape(
700           ShapeUtil::MakeShape(F32, {}),
701           builder_cond.AddInstruction(
702               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
703                                           cond_dot, {0, 0}, {1, 1}, {1, 1})))),
704       builder_cond.AddInstruction(HloInstruction::CreateReshape(
705           ShapeUtil::MakeShape(F32, {}),
706           builder_cond.AddInstruction(
707               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
708                                           cond_dot, {1, 1}, {2, 2}, {1, 1})))),
709       ComparisonDirection::kGt));
710   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
711 
712   auto builder_body = HloComputation::Builder("body");
713   auto body_param = builder_body.AddInstruction(
714       HloInstruction::CreateParameter(0, tuple->shape(), "body_param"));
715   auto body_lhs = builder_body.AddInstruction(
716       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
717   auto body_rhs = builder_body.AddInstruction(
718       HloInstruction::CreateGetTupleElement(shape, body_param, 1));
719   auto body_dot1 =
720       builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
721   auto body_dot2 =
722       builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs));
723   auto body_transpose = builder_body.AddInstruction(
724       HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
725   builder_body.AddInstruction(
726       HloInstruction::CreateTuple({body_dot1, body_transpose}));
727   auto body = module->AddEmbeddedComputation(builder_body.Build());
728 
729   auto while_hlo = builder.AddInstruction(
730       HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple));
731 
732   auto lhs = builder.AddInstruction(
733       HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
734   auto rhs = builder.AddInstruction(
735       HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
736   auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
737   auto computation = module->AddEntryComputation(builder.Build());
738 
739   EXPECT_TRUE(PropagatePrecision(module.get()));
740 
741   EXPECT_EQ(computation->root_instruction(), dot);
742   EXPECT_TRUE(OutputsBF16(lhs));
743   EXPECT_FALSE(OutputsBF16(rhs));
744   EXPECT_TRUE(OutputsBF16(body_dot1));
745   EXPECT_TRUE(OutputsBF16(body_lhs));
746   EXPECT_FALSE(OutputsBF16(body_rhs));
747   EXPECT_FALSE(OutputsBF16(body_dot2));
748   EXPECT_FALSE(OutputsBF16(body_transpose));
749   EXPECT_TRUE(OutputsBF16(cond_lhs));
750   EXPECT_FALSE(OutputsBF16(cond_rhs));
751   EXPECT_TRUE(OutputsBF16(add0));
752   EXPECT_FALSE(OutputsBF16(add1));
753 }
754 
755 // Tests that BF16 is not propagated through multiple whiles that invoke the
756 // same computation as long as one while prevents the propagation.
TEST_F(BFloat16PropagationTest,DoNotPropagateWhilesCallingSameComputation)757 TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
758   auto module = CreateNewVerifiedModule();
759   auto builder = HloComputation::Builder(TestName());
760   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
761 
762   HloInstruction* param0 = builder.AddInstruction(
763       HloInstruction::CreateParameter(0, shape, "param0"));
764   HloInstruction* param1 = builder.AddInstruction(
765       HloInstruction::CreateParameter(1, shape, "param1"));
766   HloInstruction* add0 = builder.AddInstruction(
767       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
768   HloInstruction* add1 = builder.AddInstruction(
769       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
770   HloInstruction* add2 = builder.AddInstruction(
771       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
772   HloInstruction* add3 = builder.AddInstruction(
773       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
774   HloInstruction* tuple0 =
775       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
776   HloInstruction* tuple1 =
777       builder.AddInstruction(HloInstruction::CreateTuple({add2, add3}));
778 
779   // Condition computation for the first while.
780   auto builder_cond0 = HloComputation::Builder("cond0");
781   auto cond0_param = builder_cond0.AddInstruction(
782       HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param"));
783   auto cond0_lhs = builder_cond0.AddInstruction(
784       HloInstruction::CreateGetTupleElement(shape, cond0_param, 0));
785   auto cond0_rhs = builder_cond0.AddInstruction(
786       HloInstruction::CreateGetTupleElement(shape, cond0_param, 1));
787   // This add should prevent RHS from using BF16
788   auto cond0_add_rhs =
789       builder_cond0.AddInstruction(HloInstruction::CreateBinary(
790           shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
791   auto cond0_dot =
792       builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
793   builder_cond0.AddInstruction(HloInstruction::CreateCompare(
794       ShapeUtil::MakeShape(PRED, {}),
795       builder_cond0.AddInstruction(HloInstruction::CreateReshape(
796           ShapeUtil::MakeShape(F32, {}),
797           builder_cond0.AddInstruction(
798               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
799                                           cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
800       builder_cond0.AddInstruction(HloInstruction::CreateReshape(
801           ShapeUtil::MakeShape(F32, {}),
802           builder_cond0.AddInstruction(
803               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
804                                           cond0_dot, {1, 1}, {2, 2}, {1, 1})))),
805       ComparisonDirection::kGt));
806   auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
807 
808   // Condition computation for the second while.
809   auto builder_cond1 = HloComputation::Builder("cond1");
810   auto cond1_param = builder_cond1.AddInstruction(
811       HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param"));
812   auto cond1_lhs = builder_cond1.AddInstruction(
813       HloInstruction::CreateGetTupleElement(shape, cond1_param, 0));
814   auto cond1_rhs = builder_cond1.AddInstruction(
815       HloInstruction::CreateGetTupleElement(shape, cond1_param, 1));
816   // This add should prevent LHS from using BF16
817   auto cond1_add_lhs =
818       builder_cond1.AddInstruction(HloInstruction::CreateBinary(
819           shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
820   auto cond1_dot =
821       builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
822   builder_cond1.AddInstruction(HloInstruction::CreateCompare(
823       ShapeUtil::MakeShape(PRED, {}),
824       builder_cond1.AddInstruction(HloInstruction::CreateReshape(
825           ShapeUtil::MakeShape(F32, {}),
826           builder_cond1.AddInstruction(
827               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
828                                           cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
829       builder_cond1.AddInstruction(HloInstruction::CreateReshape(
830           ShapeUtil::MakeShape(F32, {}),
831           builder_cond1.AddInstruction(
832               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
833                                           cond1_dot, {1, 1}, {2, 2}, {1, 1})))),
834       ComparisonDirection::kGt));
835   auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
836 
837   // Body computation shared by both whiles.
838   auto builder_body = HloComputation::Builder("body");
839   auto body_param = builder_body.AddInstruction(
840       HloInstruction::CreateParameter(0, tuple0->shape(), "body_param"));
841   auto body_lhs = builder_body.AddInstruction(
842       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
843   auto body_rhs = builder_body.AddInstruction(
844       HloInstruction::CreateGetTupleElement(shape, body_param, 1));
845   auto body_dot =
846       builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
847   builder_body.AddInstruction(
848       HloInstruction::CreateTuple({body_dot, body_rhs}));
849   auto body = module->AddEmbeddedComputation(builder_body.Build());
850 
851   auto while0 = builder.AddInstruction(
852       HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0));
853   auto while1 = builder.AddInstruction(
854       HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
855 
856   auto lhs = builder.AddInstruction(
857       CreateDot(shape,
858                 builder.AddInstruction(
859                     HloInstruction::CreateGetTupleElement(shape, while0, 0)),
860                 builder.AddInstruction(
861                     HloInstruction::CreateGetTupleElement(shape, while0, 1))));
862   auto rhs = builder.AddInstruction(
863       CreateDot(shape,
864                 builder.AddInstruction(
865                     HloInstruction::CreateGetTupleElement(shape, while1, 0)),
866                 builder.AddInstruction(
867                     HloInstruction::CreateGetTupleElement(shape, while1, 1))));
868   auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
869   auto computation = module->AddEntryComputation(builder.Build());
870 
871   EXPECT_TRUE(PropagatePrecision(module.get()));
872   EXPECT_FALSE(OutputsBF16(body_dot));
873   EXPECT_FALSE(OutputsBF16(body_rhs));
874   EXPECT_FALSE(OutputsBF16(body_lhs));
875   EXPECT_FALSE(OutputsBF16(cond0_lhs));
876   EXPECT_FALSE(OutputsBF16(cond0_rhs));
877   EXPECT_FALSE(OutputsBF16(cond1_lhs));
878   EXPECT_FALSE(OutputsBF16(cond1_rhs));
879   EXPECT_TRUE(OutputsBF16(cond0_add_rhs));
880   EXPECT_TRUE(OutputsBF16(cond1_add_lhs));
881   EXPECT_EQ(computation->root_instruction(), dot);
882 }
883 
884 // Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 ->
885 // BF16 conversion), then it will remove that conversion.
TEST_F(BFloat16PropagationTest,NoopConversionRemoved)886 TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
887   auto builder = HloComputation::Builder(TestName());
888   Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});
889   Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4});
890 
891   HloInstruction* param = builder.AddInstruction(
892       HloInstruction::CreateParameter(0, f32_shape, "param"));
893   HloInstruction* add0 = builder.AddInstruction(
894       HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param));
895   HloInstruction* add1 = builder.AddInstruction(
896       HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param));
897   HloInstruction* tuple =
898       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
899   HloInstruction* gte0 = builder.AddInstruction(
900       HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
901   HloInstruction* gte1 = builder.AddInstruction(
902       HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1));
903   HloInstruction* convert0 =
904       builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte0));
905   HloInstruction* convert1 =
906       builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte1));
907   HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary(
908       bf16_shape, HloOpcode::kAdd, convert0, convert1));
909 
910   auto module = CreateNewVerifiedModule();
911   auto computation = module->AddEntryComputation(builder.Build());
912 
913   EXPECT_TRUE(PropagatePrecision(module.get()));
914 
915   EXPECT_EQ(computation->root_instruction(), add2);
916   EXPECT_EQ(add2->operand(0), add0);
917   EXPECT_EQ(add2->operand(1), add1);
918   EXPECT_EQ(add0->shape().element_type(), BF16);
919   EXPECT_EQ(add1->shape().element_type(), BF16);
920 }
921 
// Checks that BF16 is propagated through a domain instruction wrapping a
// tuple: both tuple elements of the domain, the transposes feeding it and the
// GTEs reading it are expected to be BF16, while the parameters stay F32.
TEST_F(BFloat16PropagationTest, TupleDomain) {
  HloComputation::Builder entry_builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {4, 4});

  auto* param_a = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "a"));
  auto* param_b = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "b"));
  auto* transposed_a = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, param_a, {0, 1}));
  auto* transposed_b = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, param_b, {0, 1}));
  auto* packed = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({transposed_a, transposed_b}));
  auto* domain_inst = entry_builder.AddInstruction(
      HloInstruction::CreateDomain(packed->shape(), packed, nullptr, nullptr));
  auto* domain_elem_a = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain_inst, 0));
  auto* domain_elem_b = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain_inst, 1));
  auto* product = entry_builder.AddInstruction(
      CreateDot(matrix_shape, domain_elem_a, domain_elem_b));
  auto* root_add = entry_builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kAdd, product, product));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry =
      hlo_module->AddEntryComputation(entry_builder.Build());

  EXPECT_TRUE(PropagatePrecision(hlo_module.get()));
  EXPECT_EQ(entry->root_instruction(), root_add);

  // BF16 must have reached both elements of the domain's tuple shape.
  EXPECT_EQ(
      ShapeUtil::GetTupleElementShape(domain_inst->shape(), 0).element_type(),
      BF16);
  EXPECT_EQ(
      ShapeUtil::GetTupleElementShape(domain_inst->shape(), 1).element_type(),
      BF16);

  EXPECT_TRUE(OutputsBF16(transposed_a));
  EXPECT_TRUE(OutputsBF16(transposed_b));
  EXPECT_TRUE(OutputsBF16(domain_elem_a));
  EXPECT_TRUE(OutputsBF16(domain_elem_b));
  EXPECT_FALSE(OutputsBF16(param_a));
  EXPECT_FALSE(OutputsBF16(param_b));
}
965 
966 // Tests that bf16 is not propagated through a domain in case its input cannot
967 // be propagated. In the case below the input of the domain is the parameter
968 // tuple which cannot be propagated, so the domain instruction is not propagated
969 // either.
TEST_F(BFloat16PropagationTest,TupleDomainNoPropagation)970 TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
971   auto builder = HloComputation::Builder(TestName());
972   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
973   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
974 
975   HloInstruction* param = builder.AddInstruction(
976       HloInstruction::CreateParameter(0, tuple_shape, "param"));
977   HloInstruction* domain = builder.AddInstruction(
978       HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr));
979   HloInstruction* a_gte = builder.AddInstruction(
980       HloInstruction::CreateGetTupleElement(shape, domain, 0));
981   HloInstruction* b_gte = builder.AddInstruction(
982       HloInstruction::CreateGetTupleElement(shape, domain, 1));
983   HloInstruction* a_trans = builder.AddInstruction(
984       HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
985   HloInstruction* b_trans = builder.AddInstruction(
986       HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
987   HloInstruction* dot =
988       builder.AddInstruction(CreateDot(shape, a_trans, b_trans));
989   HloInstruction* root = builder.AddInstruction(
990       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
991 
992   auto module = CreateNewVerifiedModule();
993   auto computation = module->AddEntryComputation(builder.Build());
994 
995   EXPECT_TRUE(PropagatePrecision(module.get()));
996 
997   EXPECT_EQ(computation->root_instruction(), root);
998   EXPECT_TRUE(OutputsBF16(a_trans));
999   EXPECT_TRUE(OutputsBF16(b_trans));
1000   EXPECT_FALSE(OutputsBF16(a_gte));
1001   EXPECT_FALSE(OutputsBF16(b_gte));
1002   EXPECT_FALSE(OutputsBF16(domain));
1003   EXPECT_FALSE(OutputsBF16(param));
1004 }
1005 
// A conditional whose two branches receive different operands (copy0 for the
// true branch, copy1 for the false branch) and whose result feeds a dot. The
// expectations require the conditional and copy0 to output BF16 while copy1
// stays F32.
TEST_F(BFloat16PropagationTest, ConditionalSeparateBranchOperands) {
  const char* const kModuleStr = R"(
HloModule module

true_branch {
  true_param = f32[4096,4096] parameter(0)
  ROOT max = f32[4096,4096] maximum(true_param, true_param)
}

false_branch {
  false_param = f32[4096,4096] parameter(0)
  ROOT add = f32[4096,4096] add(false_param, false_param)
}

ENTRY entry {
  param0 = f32[4096,4096] parameter(0)
  param1 = f32[4096,4096] parameter(1)
  copy0 = f32[4096,4096] copy(param0)
  copy1 = f32[4096,4096] copy(param1)
  param2 = pred[] parameter(2)
  conditional = f32[4096,4096] conditional(param2, copy0, copy1),
    true_computation=true_branch, false_computation=false_branch
  ROOT dot = f32[4096,4096] dot(conditional, conditional),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  EXPECT_TRUE(PropagatePrecision(module.get()));

  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
  HloInstruction* copy0 = FindInstruction(module.get(), "copy0");
  HloInstruction* copy1 = FindInstruction(module.get(), "copy1");
  // Stop with a regular test failure, rather than dereferencing a null
  // pointer, if a lookup by name ever fails.
  ASSERT_NE(conditional, nullptr);
  ASSERT_NE(copy0, nullptr);
  ASSERT_NE(copy1, nullptr);

  EXPECT_TRUE(OutputsBF16(conditional));
  EXPECT_TRUE(OutputsBF16(copy0));
  EXPECT_FALSE(OutputsBF16(copy1));
}
1044 
// A conditional whose two branches both receive the same operand (copy0) and
// whose result feeds a dot. The expectations require the conditional to
// output BF16 while the shared operand copy0 stays F32.
TEST_F(BFloat16PropagationTest, ConditionalSharedBranchOperands) {
  const char* const kModuleStr = R"(
HloModule module

true_branch {
  true_param = f32[4096,4096] parameter(0)
  ROOT max = f32[4096,4096] maximum(true_param, true_param)
}

false_branch {
  false_param = f32[4096,4096] parameter(0)
  ROOT add = f32[4096,4096] add(false_param, false_param)
}

ENTRY entry {
  param0 = f32[4096,4096] parameter(0)
  copy0 = f32[4096,4096] copy(param0)
  param1 = pred[] parameter(1)
  conditional = f32[4096,4096] conditional(param1, copy0, copy0),
    true_computation=true_branch, false_computation=false_branch
  ROOT dot = f32[4096,4096] dot(conditional, conditional),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  EXPECT_TRUE(PropagatePrecision(module.get()));

  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
  HloInstruction* copy0 = FindInstruction(module.get(), "copy0");
  // Stop with a regular test failure, rather than dereferencing a null
  // pointer, if a lookup by name ever fails.
  ASSERT_NE(conditional, nullptr);
  ASSERT_NE(copy0, nullptr);

  EXPECT_TRUE(OutputsBF16(conditional));
  EXPECT_FALSE(OutputsBF16(copy0));
}
1079 
// A tuple-returning conditional where the true branch returns the same
// instruction (max) for both tuple elements. Element 0 only feeds a dot, but
// element 1 also feeds the entry root tuple. The pass is expected to report
// that it changed nothing.
TEST_F(BFloat16PropagationTest, ConditionalAliasingOutputs) {
  const char* const kModuleStr = R"(
HloModule module

true_branch {
  true_param = f32[4096,4096] parameter(0)
  max = f32[4096,4096] maximum(true_param, true_param)
  ROOT true_tuple = (f32[4096,4096], f32[4096,4096]) tuple(max, max)
}

false_branch {
  false_param = f32[4096,4096] parameter(0)
  min = f32[4096,4096] minimum(false_param, false_param)
  max2 = f32[4096,4096] maximum(false_param, false_param)
  ROOT false_tuple = (f32[4096,4096], f32[4096,4096]) tuple(min, max2)
}

ENTRY entry {
  param0 = f32[4096,4096] parameter(0)
  copy0 = f32[4096,4096] copy(param0)
  param1 = pred[] parameter(1)
  conditional = (f32[4096,4096], f32[4096,4096]) conditional(param1, copy0, copy0),
    true_computation=true_branch, false_computation=false_branch
  gte0 = f32[4096,4096] get-tuple-element(conditional), index=0
  gte1 = f32[4096,4096] get-tuple-element(conditional), index=1
  dot = f32[4096,4096] dot(gte0, gte1),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT tuple = (f32[4096,4096], f32[4096,4096]) tuple(dot, gte1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  EXPECT_FALSE(PropagatePrecision(parsed_module.get()));
}
1115 
// Dynamic-update-slice must keep operand 0 and its output at the same
// precision because it is an in-place operation.
TEST_F(BFloat16PropagationTest, DynamicUpdateSlice) {
  // This test is crafted so that the DUS has an f32 input (due to parameter)
  // and bf16 output (due to dot). But we should enforce DUS operand 0 and
  // output to get the same precision since it's an in-place operation. The
  // expected outcome is therefore no change at all, with the DUS left in F32.
  const char* const kModuleStr = R"(
HloModule Module

ENTRY main {
  param = f32[128,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  dynamic-update-slice = f32[128,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3)
  ROOT dot = f32[128,128] dot(dynamic-update-slice, dynamic-update-slice), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  EXPECT_FALSE(PropagatePrecision(module.get()));

  HloInstruction* dus = module->entry_computation()->GetInstructionWithName(
      "dynamic-update-slice");
  // Stop with a regular test failure, rather than dereferencing a null
  // pointer, if the lookup by name ever fails.
  ASSERT_NE(dus, nullptr);
  EXPECT_FALSE(OutputsBF16(dus));
}
1141 
1142 // This test demonstrates the need for invoking the ResolveAliasingBuffer
1143 // multiple times via a fixed-point algorithm. The key was the aliasing of the
1144 // two output buffers of the conditional, at subshape 0 (first element). This
1145 // aliasing is not resolved until after the gte0 variale is already processed,
1146 // triggering incorrect type for gte0 if not repeating the aliasing analysis.
TEST_F(BFloat16PropagationTest,ConditionalGTEWithFusion)1147 TEST_F(BFloat16PropagationTest, ConditionalGTEWithFusion) {
1148   const std::string module_str = R"(
1149 HloModule module
1150 
1151 %add.0 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1152   x.1 = f32[4096,4096] parameter(0)
1153   y.1 = f32[4096,4096] parameter(1)
1154   ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1155     lhs_contracting_dims={1}, rhs_contracting_dims={0}
1156 }
1157 
1158 %add.1 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1159   x.1 = f32[4096,4096] parameter(0)
1160   y.1 = f32[4096,4096] parameter(1)
1161   ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1162     lhs_contracting_dims={1}, rhs_contracting_dims={0}
1163 }
1164 
1165 %add.2 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1166   x.1 = f32[4096,4096] parameter(0)
1167   y.1 = f32[4096,4096] parameter(1)
1168   ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1169     lhs_contracting_dims={1}, rhs_contracting_dims={0}
1170 }
1171 
1172 %add.3 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
1173   x.1 = f32[4096,4096] parameter(0)
1174   y.1 = f32[4096,4096] parameter(1)
1175   ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
1176     lhs_contracting_dims={1}, rhs_contracting_dims={0}
1177 }
1178 
1179 true_branch {
1180   true_param = f32[4096,4096] parameter(0)
1181   constant.1 = f32[4096,4096] constant(0)
1182   add0 = f32[4096,4096] fusion(true_param,true_param), kind=kLoop, calls=add.0
1183   constant.2 = f32[4096,4096] constant(0)
1184   ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) tuple(true_param,add0,constant.2)
1185 }
1186 
1187 false_branch {
1188   false_param = f32[4096,4096] parameter(0)
1189   add3 = f32[4096,4096] fusion(false_param,false_param), kind=kLoop, calls=add.1
1190   constant.1 = f32[4096,4096] constant(0)
1191   ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) tuple(add3, add3,constant.1)
1192 }
1193 
1194 ENTRY entry {
1195   param0 = f32[4096,4096] parameter(0)
1196   copy0 = f32[4096,4096] copy(param0)
1197   param1 = pred[] parameter(1)
1198   conditional = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) conditional(param1, param0, copy0),
1199     true_computation=true_branch, false_computation=false_branch
1200   gte = f32[4096,4096] get-tuple-element(conditional), index=0
1201   gte1 = f32[4096,4096] get-tuple-element(conditional), index=1
1202   gte2 = f32[4096,4096] get-tuple-element(conditional), index=2
1203   add2 = f32[4096,4096] fusion(gte, gte1), kind=kLoop, calls=add.2
1204   ROOT add3 = f32[4096,4096] fusion(add2, gte2), kind=kLoop, calls=add.3
1205   }
1206 )";
1207 
1208   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
1209                           ParseAndReturnVerifiedModule(module_str));
1210   EXPECT_TRUE(PropagatePrecision(module.get()));
1211   VLOG(2) << module->ToString() << "\n";
1212   EXPECT_TRUE(HloVerifier(/*layout_sensitive=*/false,
1213                           /*allow_mixed_precision=*/true)
1214                   .Run(module.get())
1215                   .status()
1216                   .ok());
1217   auto gte = FindInstruction(module.get(), "gte");
1218   auto gte1 = FindInstruction(module.get(), "gte1");
1219   auto gte2 = FindInstruction(module.get(), "gte2");
1220   EXPECT_FALSE(OutputsBF16(gte));
1221   EXPECT_FALSE(OutputsBF16(gte1));
1222   EXPECT_TRUE(OutputsBF16(gte2));
1223 }
1224 
1225 }  // namespace xla
1226