xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/conditional_simplifier_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/stream_executor/lib/statusor.h"
34 
35 namespace xla {
36 namespace {
37 
38 namespace op = xla::testing::opcode_matchers;
39 
40 class ConditionalSimplifierTest : public HloTestBase {
41  public:
42   // Makes a computation that contains a conditional with constant predicate.
43   HloComputation* MakeConditional(HloModule* module, bool is_constant = true);
44 };
45 
MakeConditional(HloModule * module,bool is_constant)46 HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module,
47                                                            bool is_constant) {
48   HloComputation::Builder builder(TestName());
49 
50   // true_computation returns param+1.
51   HloComputation* true_computation;
52   {
53     HloComputation::Builder true_computation_builder(TestName() +
54                                                      ".true_computation");
55     auto param =
56         true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
57             0, ShapeUtil::MakeShape(S32, {}), "param"));
58     auto one = true_computation_builder.AddInstruction(
59         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
60 
61     true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
62         ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
63 
64     true_computation =
65         module->AddEmbeddedComputation(true_computation_builder.Build());
66   }
67 
68   // false_computation returns param+42.
69   HloComputation* false_computation;
70   {
71     HloComputation::Builder false_computation_builder(TestName() +
72                                                       ".false_computation");
73     auto param = false_computation_builder.AddInstruction(
74         HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
75                                         "param"));
76     auto forty_two = false_computation_builder.AddInstruction(
77         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42)));
78 
79     false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
80         ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
81     false_computation =
82         module->AddEmbeddedComputation(false_computation_builder.Build());
83   }
84 
85   auto false_instrn = builder.AddInstruction(
86       is_constant
87           ? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))
88           : HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}),
89                                             "cond"));
90   auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
91       0, ShapeUtil::MakeShape(S32, {}), "false_param"));
92   auto one = builder.AddInstruction(
93       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
94 
95   builder.AddInstruction(HloInstruction::CreateConditional(
96       ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
97       false_param, false_computation));
98 
99   return module->AddEntryComputation(builder.Build());
100 }
101 
// With a constant `false` predicate the conditional should be replaced by the
// inlined false branch: an add of an entry parameter and a constant.
TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  EXPECT_THAT(entry->root_instruction(),
              op::Add(op::Parameter(), op::Constant()));
}
109 
// With a runtime predicate both branches are inlined, and their results are
// combined by a select on the predicate parameter (entry parameter 1).
TEST_F(ConditionalSimplifierTest, BranchGetsInlined) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      MakeConditional(module.get(), /*is_constant=*/false);

  ASSERT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());

  // True branch: constant operand + constant 1.
  const auto true_result = op::Add(op::Constant(), op::Constant());
  // False branch: entry parameter 0 + constant 42.
  const auto false_result = op::Add(op::Parameter(0), op::Constant());
  EXPECT_THAT(entry->root_instruction(),
              op::Select(op::Parameter(1), true_result, false_result));
}
119 
// Adding a control dependency onto the conditional should make the simplifier
// leave the module unchanged, even though the predicate is a constant.
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* conditional = entry->root_instruction();

  HloInstruction* control_predecessor = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  TF_ASSERT_OK(control_predecessor->AddControlDependencyTo(conditional));

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
131 
// A Send/SendDone pair inside the true branch should keep the simplifier from
// changing the module, even though the predicate is a constant.
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* conditional = entry->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  HloComputation* branch = conditional->true_computation();
  HloInstruction* token = branch->AddInstruction(HloInstruction::CreateToken());
  HloInstruction* payload = branch->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* send = branch->AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateSendDone(send));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
147 
// A Recv/RecvDone pair inside the true branch should keep the simplifier from
// changing the module, even though the predicate is a constant.
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* conditional = entry->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  HloComputation* branch = conditional->true_computation();
  HloInstruction* token = branch->AddInstruction(HloInstruction::CreateToken());
  const auto received_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* recv = branch->AddInstruction(
      HloInstruction::CreateRecv(received_shape, token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateRecvDone(recv));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
161 
// An infeed inside the false branch should keep the simplifier from changing
// the module, even though the predicate is a constant.
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* conditional = entry->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  HloComputation* branch = conditional->false_computation();
  HloInstruction* token = branch->AddInstruction(HloInstruction::CreateToken());
  const auto infeed_shape = ShapeUtil::MakeShape(F32, {1});
  branch->AddInstruction(
      HloInstruction::CreateInfeed(infeed_shape, token, "config"));

  EXPECT_FALSE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
173 
// on_false reads only tuple elements 0 and 1 of its operand, and on_true reads
// only elements 2 and 3. The simplifier is expected to shrink both conditional
// operands to 2-tuples. on_true is also used as the to_apply of a `call`,
// whose parameter must keep all four elements.
TEST_F(ConditionalSimplifierTest, TrivalOperandsRemoved) {
  absl::string_view hlo_string =
      R"(
HloModule UnusedTupleOperands
on_false {
  t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0)
  lhs = f32[20,40] get-tuple-element(t), index=0
  rhs = f32[40,40] get-tuple-element(t), index=1
  dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT result = (f32[20,40]) tuple(dot)
}

on_true {
  t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0)
  lhs = f32[20,40] get-tuple-element(t), index=2
  rhs = f32[40,40] get-tuple-element(t), index=3
  dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT result = (f32[20,40]) tuple(dot)
}

ENTRY main {
  c0_0 = f32[20,40] parameter(0)
  c0_1 = f32[40,40] parameter(1)
  c1_0 = f32[20,40] parameter(2)
  c1_1 = f32[40,40] parameter(3)
  p = pred[] parameter(4)
  t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) tuple(c0_0, c0_1, c1_0, c1_1)
  call = (f32[20,40]) call(t), to_apply=on_true
  ROOT result = (f32[20,40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());
  EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
  TF_ASSERT_OK(v.Run(module.get()).status());

  HloInstruction* conditional = module->entry_computation()->root_instruction();
  // ASSERT rather than EXPECT: the checks below dereference `conditional` and
  // read its operands 1 and 2, which only makes sense for a conditional.
  ASSERT_NE(conditional, nullptr);
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
  EXPECT_EQ(conditional->operand(1)->shape().tuple_shapes().size(), 2);
  EXPECT_EQ(conditional->operand(2)->shape().tuple_shapes().size(), 2);

  // For the call operation, nothing should have changed.
  HloInstruction* call = FindInstruction(module.get(), "call");
  ASSERT_NE(call, nullptr);
  EXPECT_EQ(
      call->to_apply()->parameter_instruction(0)->shape().tuple_shapes().size(),
      4);
}
222 
// conditional_2 consumes conditional_1. The test replaces conditional_1 with a
// clone that is added to the computation after conditional_2, so the producer
// is created after its user. It then checks that the simplifier still reports
// a change.
TEST_F(ConditionalSimplifierTest,
       TwoConditionalsCreatedInReversedLexicalOrder) {
  absl::string_view hlo_string = R"(
  HloModule DeadConditional
    computation.1 {
      param.1 = s64[] parameter(0)
      constant.1 = s64[] constant(1)
      ROOT add.1 = s64[] add(param.1, constant.1)
    }

    computation.2 {
      param.2 = s64[] parameter(0)
      constant.2 = s64[] constant(2)
      ROOT add.2 = s64[] add(param.2, constant.2)
   }

    computation.3 {
      param.3 = s64[] parameter(0)
      constant.3 = s64[] constant(3)
      ROOT add.3 = s64[] add(param.3, constant.3)
    }

    computation.4 {
      param.4 = s64[] parameter(0)
      constant.4 = s64[] constant(4)
      ROOT add.4 = s64[] add(param.4, constant.4)
    }

    ENTRY KernelEntry {
      param.1 = s64[] parameter(0)
      param.2 = s64[] parameter(1)
      param.3 = s64[] parameter(2)
      param.4 = pred[] parameter(3)

      conditional_1 = s64[] conditional(param.4, param.3, param.2),
        true_computation=computation.3, false_computation=computation.4
      constant.1 = pred[] constant(false)
      ROOT conditional_2 = s64[] conditional(constant.1, conditional_1,
        param.1), true_computation=computation.1,
        false_computation=computation.2
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());

  // Replace conditional_1 with a clone that is created after conditional_2.
  HloInstruction* conditional_1 =
      FindInstruction(module.get(), "conditional_1");
  // Dereferenced immediately below; fail cleanly instead of crashing.
  ASSERT_NE(conditional_1, nullptr);
  HloInstruction* conditional_1_clone =
      conditional_1->parent()->AddInstruction(conditional_1->Clone());
  TF_ASSERT_OK(conditional_1->ReplaceAllUsesWith(conditional_1_clone));
  TF_ASSERT_OK(conditional_1->parent()->RemoveInstruction(conditional_1));

  EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
}
280 
// The conditional's result is not used by the entry root, but on_false
// contains an outfeed. The test expects the conditional to survive the pass
// with its (f32[20,40]) result tuple shrunk to an empty tuple.
TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveDeadRoots
on_false {
  t = (f32[20,40], f32[40,40]) parameter(0)
  lhs = f32[20,40] get-tuple-element(t), index=0
  rhs = f32[40,40] get-tuple-element(t), index=1
  dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  after-all = token[] after-all()
  outfeed = token[] outfeed(dot, after-all)
  ROOT result = (f32[20,40]) tuple(dot)
}

on_true {
  t = (f32[20,40], f32[40,40]) parameter(0)
  lhs = f32[20,40] get-tuple-element(t), index=0
  add = f32[20,40] add(lhs, lhs)
  ROOT result = (f32[20,40]) tuple(add)
}

ENTRY main {
  c0_0 = f32[20,40] parameter(0)
  c0_1 = f32[40,40] parameter(1)
  p = pred[] parameter(2)
  t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1)
  conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
  ROOT result = () tuple()
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());
  EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
  TF_ASSERT_OK(v.Run(module.get()).status());

  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
  // Dereferenced below; fail cleanly instead of crashing if it disappeared.
  ASSERT_NE(conditional, nullptr);
  // The conditional's result tuple should have been reduced to an empty tuple.
  EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0);
}
323 
// Only tuple index 0 of the conditional's two-element result is read, so the
// simplifier is expected to drop the unused second element.
TEST_F(ConditionalSimplifierTest, SecondTupleElementUnusedAndRemoved) {
  absl::string_view hlo_string =
      R"(
HloModule SecondTupleElementUnusedAndRemoved

on_true {
  arg_tuple.7 = (f32[10,10]{1,0}) parameter(0)
  get-tuple-element.9 = f32[10,10]{1,0} get-tuple-element(arg_tuple.7), index=0
  copy = f32[10,10]{1,0} copy(get-tuple-element.9)
  ROOT tuple.6 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(copy, get-tuple-element.9)
}

on_false {
  constant.17 = f32[] constant(0)
  constant.18 = f32[] constant(1)
  rng.19 = f32[10,10]{1,0} rng(constant.17, constant.18), distribution=rng_uniform
  arg_tuple.14 = (f32[10,10]{1,0}) parameter(0)
  get-tuple-element.16 = f32[10,10]{1,0} get-tuple-element(arg_tuple.14), index=0
  ROOT tuple.7 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(rng.19, get-tuple-element.16)
}

ENTRY main {
  constant.38 = pred[] constant(true)
  arg_tuple.30 = (s32[], f32[10,10]{1,0}) parameter(0)
  get-tuple-element.21 = f32[10,10]{1,0} get-tuple-element(arg_tuple.30), index=1
  tuple.1 = (f32[10,10]{1,0}) tuple(get-tuple-element.21)
  conditional = (f32[10,10]{1,0}, f32[10,10]{1,0}) conditional(constant.38, tuple.1, tuple.1), true_computation=on_true, false_computation=on_false
  get-first-index = f32[10,10]{1,0} get-tuple-element(conditional), index=0
  ROOT result = (f32[10,10]{1,0}) tuple(get-first-index)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());
  EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
  TF_ASSERT_OK(v.Run(module.get()).status());

  const HloInstruction* conditional =
      FindInstruction(module.get(), "conditional");
  // Dereferenced below; fail cleanly instead of crashing if it disappeared.
  ASSERT_NE(conditional, nullptr);
  // The second element of "conditional" result tuple (f32[10,10], f32[10,10])
  // should be removed since it is not referenced by any GTE instructions
  // (see "get-first-index" instruction in hlo_string).
  EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1);
}
369 
// Only tuple index 1 of the conditional's two-element result is read, so the
// simplifier is expected to drop the unused first element.
TEST_F(ConditionalSimplifierTest, FirstTupleElementUnusedAndRemoved) {
  absl::string_view hlo_string =
      R"(
HloModule FirstTupleElementUnusedAndRemoved

on_true {
  arg_tuple.7 = (f32[10,10]{1,0}) parameter(0)
  get-tuple-element.9 = f32[10,10]{1,0} get-tuple-element(arg_tuple.7), index=0
  copy = f32[10,10]{1,0} copy(get-tuple-element.9)
  ROOT tuple.6 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(copy, get-tuple-element.9)
}

on_false {
  constant.17 = f32[] constant(0)
  constant.18 = f32[] constant(1)
  rng.19 = f32[10,10]{1,0} rng(constant.17, constant.18), distribution=rng_uniform
  arg_tuple.14 = (f32[10,10]{1,0}) parameter(0)
  get-tuple-element.16 = f32[10,10]{1,0} get-tuple-element(arg_tuple.14), index=0
  ROOT tuple.7 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(rng.19, get-tuple-element.16)
}

ENTRY main {
  constant.38 = pred[] constant(true)
  arg_tuple.30 = (s32[], f32[10,10]{1,0}) parameter(0)
  get-tuple-element.21 = f32[10,10]{1,0} get-tuple-element(arg_tuple.30), index=1
  tuple.1 = (f32[10,10]{1,0}) tuple(get-tuple-element.21)
  conditional = (f32[10,10]{1,0}, f32[10,10]{1,0}) conditional(constant.38, tuple.1, tuple.1), true_computation=on_true, false_computation=on_false
  get-second-index = f32[10,10]{1,0} get-tuple-element(conditional), index=1
  ROOT result = (f32[10,10]{1,0}) tuple(get-second-index)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());
  EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
  TF_ASSERT_OK(v.Run(module.get()).status());

  const HloInstruction* conditional =
      FindInstruction(module.get(), "conditional");
  // Dereferenced below; fail cleanly instead of crashing if it disappeared.
  ASSERT_NE(conditional, nullptr);
  // The first element of "conditional" result tuple (f32[10,10], f32[10,10])
  // should be removed since it is not referenced by any GTE instructions (see
  // "get-second-index" instruction in hlo_string).
  EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1);
}
415 
416 // Before:
417 //       gte          rng
418 //      /   \        /   \
419 //      |   |        |   |
420 //     on_true      on_false
421 //    (f32, f32)   (f32, f32)
422 //         |           |
423 //          \         /
424 //          conditional
425 //          (f32, f32)
426 //
427 // After:
428 //       gte          rng
429 //        |            |
430 //     on_true      on_false
431 //      (f32)        (f32)
432 //         |           |
433 //          \         /
434 //          conditional
435 //             (f32)
436 //
437 // The 'rng' in on_false is to add side-effect so that conditional is not being
438 // simplified and replaced with 'select' instruction by TryRemoveConditional.
TEST_F(ConditionalSimplifierTest,MergeDuplicateTupleElements)439 TEST_F(ConditionalSimplifierTest, MergeDuplicateTupleElements) {
440   absl::string_view hlo_string =
441       R"(
442 HloModule MergeDuplicateTupleElements
443 
444 on_true {
445   param-true = (f32[]) parameter(0)
446   gte-true = f32[] get-tuple-element(param-true), index=0
447   ROOT tuple-true = (f32[], f32[]) tuple(gte-true, gte-true)
448 }
449 
450 on_false {
451   param-false = (f32[]) parameter(0)
452   constant.0 = f32[] constant(0)
453   constant.1 = f32[] constant(1)
454   rng = f32[] rng(constant.0, constant.1), distribution=rng_uniform
455   ROOT tuple-false = (f32[], f32[]) tuple(rng, rng)
456 }
457 
458 ENTRY main {
459   comp = pred[] parameter(0)
460   arg = (f32[]) parameter(1)
461   conditional = (f32[], f32[]) conditional(comp, arg, arg), true_computation=on_true, false_computation=on_false
462   gte.0 = f32[] get-tuple-element(conditional), index=0
463   gte.1 = f32[] get-tuple-element(conditional), index=1
464   ROOT add = f32[] add(gte.0, gte.1)
465 }
466 )";
467   auto status = ParseAndReturnVerifiedModule(hlo_string);
468   TF_ASSERT_OK(status.status());
469   HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
470   TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
471   EXPECT_TRUE(
472       ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie());
473   TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
474   const HloInstruction* conditional =
475       FindInstruction(status.ValueOrDie().get(), "conditional");
476   EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1);
477   const HloInstruction* gte_0 =
478       FindInstruction(status.ValueOrDie().get(), "gte.0");
479   const HloInstruction* gte_1 =
480       FindInstruction(status.ValueOrDie().get(), "gte.1");
481   EXPECT_EQ(gte_0->tuple_index(), 0);
482   EXPECT_EQ(gte_1->tuple_index(), 0);
483 }
484 
485 // Since select can only be used on arrays, use after-all for token types.
TEST_F(ConditionalSimplifierTest,SimplifyConditionalWithTokens)486 TEST_F(ConditionalSimplifierTest, SimplifyConditionalWithTokens) {
487   absl::string_view hlo_string =
488       R"(
489 HloModule SimplifyConditionalWithTokens
490 
491 true_comp {
492   ROOT parameter.13 = (token[]) parameter(0)
493 }
494 
495 false_comp {
496   ROOT parameter.21 = (token[]) parameter(0)
497 }
498 
499 ENTRY entry {
500   parameter.29 = pred[] parameter(0)
501   token.1 = token[] after-all()
502   token.2 = token[] after-all()
503   tuple.3 = (token[]) tuple(token.1)
504   tuple.4 = (token[]) tuple(token.2)
505   ROOT conditional.5 = (token[]) conditional(parameter.29, tuple.3, tuple.4), true_computation=true_comp, false_computation=false_comp
506 }
507 )";
508   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
509                           ParseAndReturnVerifiedModule(hlo_string));
510   HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
511   TF_ASSERT_OK(v.Run(module.get()).status());
512   EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
513   EXPECT_THAT(module->entry_computation()->root_instruction(),
514               op::Tuple(op::AfterAll(
515                   op::GetTupleElement(op::Tuple(op::AfterAll()), 0),
516                   op::GetTupleElement(op::Tuple(op::AfterAll()), 0))));
517 }
518 
519 }  // namespace
520 
521 }  // namespace xla
522