1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/stream_executor/lib/statusor.h"
34
35 namespace xla {
36 namespace {
37
38 namespace op = xla::testing::opcode_matchers;
39
40 class ConditionalSimplifierTest : public HloTestBase {
41 public:
42 // Makes a computation that contains a conditional with constant predicate.
43 HloComputation* MakeConditional(HloModule* module, bool is_constant = true);
44 };
45
MakeConditional(HloModule * module,bool is_constant)46 HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module,
47 bool is_constant) {
48 HloComputation::Builder builder(TestName());
49
50 // true_computation returns param+1.
51 HloComputation* true_computation;
52 {
53 HloComputation::Builder true_computation_builder(TestName() +
54 ".true_computation");
55 auto param =
56 true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
57 0, ShapeUtil::MakeShape(S32, {}), "param"));
58 auto one = true_computation_builder.AddInstruction(
59 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
60
61 true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
62 ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
63
64 true_computation =
65 module->AddEmbeddedComputation(true_computation_builder.Build());
66 }
67
68 // false_computation returns param+42.
69 HloComputation* false_computation;
70 {
71 HloComputation::Builder false_computation_builder(TestName() +
72 ".false_computation");
73 auto param = false_computation_builder.AddInstruction(
74 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
75 "param"));
76 auto forty_two = false_computation_builder.AddInstruction(
77 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42)));
78
79 false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
80 ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
81 false_computation =
82 module->AddEmbeddedComputation(false_computation_builder.Build());
83 }
84
85 auto false_instrn = builder.AddInstruction(
86 is_constant
87 ? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))
88 : HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}),
89 "cond"));
90 auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
91 0, ShapeUtil::MakeShape(S32, {}), "false_param"));
92 auto one = builder.AddInstruction(
93 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
94
95 builder.AddInstruction(HloInstruction::CreateConditional(
96 ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
97 false_param, false_computation));
98
99 return module->AddEntryComputation(builder.Build());
100 }
101
TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
  // With the constant-false predicate, the conditional is expected to be
  // replaced by the inlined false branch (param + constant).
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  const HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, op::Add(op::Parameter(), op::Constant()));
}
109
TEST_F(ConditionalSimplifierTest, BranchGetsInlined) {
  // With a runtime predicate, both branches are expected to be inlined and
  // the conditional turned into a select over their results.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      MakeConditional(module.get(), /*is_constant=*/false);
  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  auto true_result = op::Add(op::Constant(), op::Constant());
  auto false_result = op::Add(op::Parameter(0), op::Constant());
  EXPECT_THAT(entry->root_instruction(),
              op::Select(op::Parameter(1), true_result, false_result));
}
119
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
  // A conditional that is the target of a control edge is expected to be
  // left unchanged by the pass.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());

  HloInstruction* predecessor = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional = entry->root_instruction();
  TF_ASSERT_OK(predecessor->AddControlDependencyTo(conditional));

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
131
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* conditional = entry->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  // Put a send/send-done pair into the true branch; the pass is expected to
  // report no change.
  HloComputation* branch = conditional->true_computation();
  HloInstruction* token =
      branch->AddInstruction(HloInstruction::CreateToken());
  HloInstruction* payload = branch->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* send = branch->AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateSendDone(send));

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
147
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* conditional = entry->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  // Put a recv/recv-done pair into the true branch; the pass is expected to
  // report no change.
  HloComputation* branch = conditional->true_computation();
  HloInstruction* token =
      branch->AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = branch->AddInstruction(HloInstruction::CreateRecv(
      ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0));
  branch->AddInstruction(HloInstruction::CreateRecvDone(recv));

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
161
TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = MakeConditional(module.get());
  HloInstruction* conditional = entry->root_instruction();
  ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

  // Put an infeed into the false branch; the pass is expected to report no
  // change.
  HloComputation* branch = conditional->false_computation();
  HloInstruction* token =
      branch->AddInstruction(HloInstruction::CreateToken());
  branch->AddInstruction(HloInstruction::CreateInfeed(
      ShapeUtil::MakeShape(F32, {1}), token, "config"));

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
173
// Each branch reads only two of the four elements of its tuple operand; the
// unread elements are expected to be dropped from both conditional operands,
// while the `call` that shares `on_true` keeps its 4-element parameter.
TEST_F(ConditionalSimplifierTest, TrivalOperandsRemoved) {
  absl::string_view hlo_string =
      R"(
HloModule UnusedTupleOperands
on_false {
t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0)
lhs = f32[20,40] get-tuple-element(t), index=0
rhs = f32[40,40] get-tuple-element(t), index=1
dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT result = (f32[20,40]) tuple(dot)
}

on_true {
t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0)
lhs = f32[20,40] get-tuple-element(t), index=2
rhs = f32[40,40] get-tuple-element(t), index=3
dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT result = (f32[20,40]) tuple(dot)
}

ENTRY main {
c0_0 = f32[20,40] parameter(0)
c0_1 = f32[40,40] parameter(1)
c1_0 = f32[20,40] parameter(2)
c1_1 = f32[40,40] parameter(3)
p = pred[] parameter(4)
t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) tuple(c0_0, c0_1, c1_0, c1_1)
call = (f32[20,40]) call(t), to_apply=on_true
ROOT result = (f32[20,40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ConditionalSimplifier().Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(v.Run(module.get()).status());

  HloInstruction* conditional = module->entry_computation()->root_instruction();
  // Fatal check: the assertions below dereference `conditional`.
  ASSERT_NE(conditional, nullptr);
  EXPECT_EQ(conditional->operand(1)->shape().tuple_shapes().size(), 2);
  EXPECT_EQ(conditional->operand(2)->shape().tuple_shapes().size(), 2);

  // For the call operation, nothing should have changed.
  HloInstruction* call = FindInstruction(module.get(), "call");
  ASSERT_NE(call, nullptr);
  EXPECT_EQ(
      call->to_apply()->parameter_instruction(0)->shape().tuple_shapes().size(),
      4);
}
222
// conditional_1 is re-created after conditional_2, so the consumer precedes
// its operand in creation order; the pass must still simplify the module.
TEST_F(ConditionalSimplifierTest,
       TwoConditionalsCreatedInReversedLexicalOrder) {
  absl::string_view hlo_string = R"(
HloModule DeadConditional
computation.1 {
param.1 = s64[] parameter(0)
constant.1 = s64[] constant(1)
ROOT add.1 = s64[] add(param.1, constant.1)
}

computation.2 {
param.2 = s64[] parameter(0)
constant.2 = s64[] constant(2)
ROOT add.2 = s64[] add(param.2, constant.2)
}

computation.3 {
param.3 = s64[] parameter(0)
constant.3 = s64[] constant(3)
ROOT add.3 = s64[] add(param.3, constant.3)
}

computation.4 {
param.4 = s64[] parameter(0)
constant.4 = s64[] constant(4)
ROOT add.4 = s64[] add(param.4, constant.4)
}

ENTRY KernelEntry {
param.1 = s64[] parameter(0)
param.2 = s64[] parameter(1)
param.3 = s64[] parameter(2)
param.4 = pred[] parameter(3)

conditional_1 = s64[] conditional(param.4, param.3, param.2),
true_computation=computation.3, false_computation=computation.4
constant.1 = pred[] constant(false)
ROOT conditional_2 = s64[] conditional(constant.1, conditional_1,
param.1), true_computation=computation.1,
false_computation=computation.2
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());

  // Replace conditional_1 with a clone that is created after conditional_2.
  HloInstruction* conditional_1 =
      FindInstruction(module.get(), "conditional_1");
  ASSERT_NE(conditional_1, nullptr);
  HloComputation* entry = conditional_1->parent();
  HloInstruction* conditional_1_clone =
      entry->AddInstruction(conditional_1->Clone());
  TF_ASSERT_OK(conditional_1->ReplaceAllUsesWith(conditional_1_clone));
  TF_ASSERT_OK(entry->RemoveInstruction(conditional_1));

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ConditionalSimplifier().Run(module.get()));
  EXPECT_TRUE(changed);
}
280
// The conditional's result is unused; its tuple result is expected to shrink
// to the empty tuple (the outfeed in on_false keeps the conditional alive).
TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveDeadRoots
on_false {
t = (f32[20,40], f32[40,40]) parameter(0)
lhs = f32[20,40] get-tuple-element(t), index=0
rhs = f32[40,40] get-tuple-element(t), index=1
dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
after-all = token[] after-all()
outfeed = token[] outfeed(dot, after-all)
ROOT result = (f32[20,40]) tuple(dot)
}

on_true {
t = (f32[20,40], f32[40,40]) parameter(0)
lhs = f32[20,40] get-tuple-element(t), index=0
add = f32[20,40] add(lhs, lhs)
ROOT result = (f32[20,40]) tuple(add)
}

ENTRY main {
c0_0 = f32[20,40] parameter(0)
c0_1 = f32[40,40] parameter(1)
p = pred[] parameter(2)
t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1)
conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
ROOT result = () tuple()
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ConditionalSimplifier().Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(v.Run(module.get()).status());

  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
  ASSERT_NE(conditional, nullptr);
  // The conditional root should be replaced with an empty tuple.
  EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0);
}
323
TEST_F(ConditionalSimplifierTest, SecondTupleElementUnusedAndRemoved) {
  absl::string_view hlo_string =
      R"(
HloModule SecondTupleElementUnusedAndRemoved

on_true {
arg_tuple.7 = (f32[10,10]{1,0}) parameter(0)
get-tuple-element.9 = f32[10,10]{1,0} get-tuple-element(arg_tuple.7), index=0
copy = f32[10,10]{1,0} copy(get-tuple-element.9)
ROOT tuple.6 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(copy, get-tuple-element.9)
}

on_false {
constant.17 = f32[] constant(0)
constant.18 = f32[] constant(1)
rng.19 = f32[10,10]{1,0} rng(constant.17, constant.18), distribution=rng_uniform
arg_tuple.14 = (f32[10,10]{1,0}) parameter(0)
get-tuple-element.16 = f32[10,10]{1,0} get-tuple-element(arg_tuple.14), index=0
ROOT tuple.7 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(rng.19, get-tuple-element.16)
}

ENTRY main {
constant.38 = pred[] constant(true)
arg_tuple.30 = (s32[], f32[10,10]{1,0}) parameter(0)
get-tuple-element.21 = f32[10,10]{1,0} get-tuple-element(arg_tuple.30), index=1
tuple.1 = (f32[10,10]{1,0}) tuple(get-tuple-element.21)
conditional = (f32[10,10]{1,0}, f32[10,10]{1,0}) conditional(constant.38, tuple.1, tuple.1), true_computation=on_true, false_computation=on_false
get-first-index = f32[10,10]{1,0} get-tuple-element(conditional), index=0
ROOT result = (f32[10,10]{1,0}) tuple(get-first-index)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ConditionalSimplifier().Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(v.Run(module.get()).status());

  const HloInstruction* conditional =
      FindInstruction(module.get(), "conditional");
  ASSERT_NE(conditional, nullptr);
  // The second element of "conditional" result tuple (f32[10,10], f32[10,10])
  // should be removed since it is not referenced by any GTE instructions
  // (see "get-first-index" instruction in hlo_string).
  EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1);
}
369
TEST_F(ConditionalSimplifierTest, FirstTupleElementUnusedAndRemoved) {
  absl::string_view hlo_string =
      R"(
HloModule FirstTupleElementUnusedAndRemoved

on_true {
arg_tuple.7 = (f32[10,10]{1,0}) parameter(0)
get-tuple-element.9 = f32[10,10]{1,0} get-tuple-element(arg_tuple.7), index=0
copy = f32[10,10]{1,0} copy(get-tuple-element.9)
ROOT tuple.6 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(copy, get-tuple-element.9)
}

on_false {
constant.17 = f32[] constant(0)
constant.18 = f32[] constant(1)
rng.19 = f32[10,10]{1,0} rng(constant.17, constant.18), distribution=rng_uniform
arg_tuple.14 = (f32[10,10]{1,0}) parameter(0)
get-tuple-element.16 = f32[10,10]{1,0} get-tuple-element(arg_tuple.14), index=0
ROOT tuple.7 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(rng.19, get-tuple-element.16)
}

ENTRY main {
constant.38 = pred[] constant(true)
arg_tuple.30 = (s32[], f32[10,10]{1,0}) parameter(0)
get-tuple-element.21 = f32[10,10]{1,0} get-tuple-element(arg_tuple.30), index=1
tuple.1 = (f32[10,10]{1,0}) tuple(get-tuple-element.21)
conditional = (f32[10,10]{1,0}, f32[10,10]{1,0}) conditional(constant.38, tuple.1, tuple.1), true_computation=on_true, false_computation=on_false
get-second-index = f32[10,10]{1,0} get-tuple-element(conditional), index=1
ROOT result = (f32[10,10]{1,0}) tuple(get-second-index)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  TF_ASSERT_OK(v.Run(module.get()).status());

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ConditionalSimplifier().Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(v.Run(module.get()).status());

  const HloInstruction* conditional =
      FindInstruction(module.get(), "conditional");
  ASSERT_NE(conditional, nullptr);
  // The first element of "conditional" result tuple (f32[10,10], f32[10,10])
  // should be removed since it is not referenced by any GTE instructions (see
  // "get-second-index" instruction in hlo_string).
  EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1);
}
415
416 // Before:
417 // gte rng
418 // / \ / \
419 // | | | |
420 // on_true on_false
421 // (f32, f32) (f32, f32)
422 // | |
423 // \ /
424 // conditional
425 // (f32, f32)
426 //
427 // After:
428 // gte rng
429 // | |
430 // on_true on_false
431 // (f32) (f32)
432 // | |
433 // \ /
434 // conditional
435 // (f32)
436 //
437 // The 'rng' in on_false is to add side-effect so that conditional is not being
438 // simplified and replaced with 'select' instruction by TryRemoveConditional.
TEST_F(ConditionalSimplifierTest,MergeDuplicateTupleElements)439 TEST_F(ConditionalSimplifierTest, MergeDuplicateTupleElements) {
440 absl::string_view hlo_string =
441 R"(
442 HloModule MergeDuplicateTupleElements
443
444 on_true {
445 param-true = (f32[]) parameter(0)
446 gte-true = f32[] get-tuple-element(param-true), index=0
447 ROOT tuple-true = (f32[], f32[]) tuple(gte-true, gte-true)
448 }
449
450 on_false {
451 param-false = (f32[]) parameter(0)
452 constant.0 = f32[] constant(0)
453 constant.1 = f32[] constant(1)
454 rng = f32[] rng(constant.0, constant.1), distribution=rng_uniform
455 ROOT tuple-false = (f32[], f32[]) tuple(rng, rng)
456 }
457
458 ENTRY main {
459 comp = pred[] parameter(0)
460 arg = (f32[]) parameter(1)
461 conditional = (f32[], f32[]) conditional(comp, arg, arg), true_computation=on_true, false_computation=on_false
462 gte.0 = f32[] get-tuple-element(conditional), index=0
463 gte.1 = f32[] get-tuple-element(conditional), index=1
464 ROOT add = f32[] add(gte.0, gte.1)
465 }
466 )";
467 auto status = ParseAndReturnVerifiedModule(hlo_string);
468 TF_ASSERT_OK(status.status());
469 HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
470 TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
471 EXPECT_TRUE(
472 ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie());
473 TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
474 const HloInstruction* conditional =
475 FindInstruction(status.ValueOrDie().get(), "conditional");
476 EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1);
477 const HloInstruction* gte_0 =
478 FindInstruction(status.ValueOrDie().get(), "gte.0");
479 const HloInstruction* gte_1 =
480 FindInstruction(status.ValueOrDie().get(), "gte.1");
481 EXPECT_EQ(gte_0->tuple_index(), 0);
482 EXPECT_EQ(gte_1->tuple_index(), 0);
483 }
484
485 // Since select can only be used on arrays, use after-all for token types.
TEST_F(ConditionalSimplifierTest,SimplifyConditionalWithTokens)486 TEST_F(ConditionalSimplifierTest, SimplifyConditionalWithTokens) {
487 absl::string_view hlo_string =
488 R"(
489 HloModule SimplifyConditionalWithTokens
490
491 true_comp {
492 ROOT parameter.13 = (token[]) parameter(0)
493 }
494
495 false_comp {
496 ROOT parameter.21 = (token[]) parameter(0)
497 }
498
499 ENTRY entry {
500 parameter.29 = pred[] parameter(0)
501 token.1 = token[] after-all()
502 token.2 = token[] after-all()
503 tuple.3 = (token[]) tuple(token.1)
504 tuple.4 = (token[]) tuple(token.2)
505 ROOT conditional.5 = (token[]) conditional(parameter.29, tuple.3, tuple.4), true_computation=true_comp, false_computation=false_comp
506 }
507 )";
508 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
509 ParseAndReturnVerifiedModule(hlo_string));
510 HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
511 TF_ASSERT_OK(v.Run(module.get()).status());
512 EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
513 EXPECT_THAT(module->entry_computation()->root_instruction(),
514 op::Tuple(op::AfterAll(
515 op::GetTupleElement(op::Tuple(op::AfterAll()), 0),
516 op::GetTupleElement(op::Tuple(op::AfterAll()), 0))));
517 }
518
519 } // namespace
520
521 } // namespace xla
522