xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_dce_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dce.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_utils.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace xla {
37 namespace {
38 
39 class HloDceTest : public HloTestBase {
40  protected:
HloDceTest()41   HloDceTest() {}
42 
43   // Returns whether the given instruction exists in the given computation.
HasInstruction(const HloComputation & computation,const HloInstruction * instruction)44   bool HasInstruction(const HloComputation& computation,
45                       const HloInstruction* instruction) {
46     return absl::c_linear_search(computation.instructions(), instruction);
47   }
48 };
49 
TEST_F(HloDceTest, NoDeadCode) {
  // Both constants feed the root add, so there is nothing for DCE to delete:
  // the pass must report no change and keep all three instructions.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(3, entry->instruction_count());

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_EQ(3, entry->instruction_count());
}
70 
TEST_F(HloDceTest, InstructionsWithSideEffect) {
  // A Send/SendDone pair is side-effecting, so DCE must keep it (and the
  // operands it consumes) even though the empty tuple root does not use it.
  HloComputation::Builder builder(TestName());
  HloInstruction* payload = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  builder.AddInstruction(HloInstruction::CreateSendDone(send));
  // Added last, so this becomes the root.
  builder.AddInstruction(HloInstruction::CreateTuple({}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(5, entry->instruction_count());

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_EQ(5, entry->instruction_count());
}
92 
TEST_F(HloDceTest, CustomCallInstructionsWithSideEffect) {
  // A custom call explicitly flagged as side-effecting must survive DCE even
  // though nothing uses its result.
  HloComputation::Builder builder(TestName());
  HloInstruction* custom_call = builder.AddInstruction(
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo"));
  Cast<HloCustomCallInstruction>(custom_call)
      ->set_custom_call_has_side_effect(true);
  // The empty tuple is added last and becomes the root.
  builder.AddInstruction(HloInstruction::CreateTuple({}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloDCE dce;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&dce, module.get()));
  EXPECT_FALSE(changed);
}
110 
TEST_F(HloDceTest, CustomCallInstructionsWithoutSideEffect) {
  // A custom call that is not flagged as side-effecting and has no users is
  // ordinary dead code: DCE is expected to remove it and report a change.
  HloComputation::Builder builder(TestName());
  builder.AddInstruction(
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo"));
  // The empty tuple is added last and becomes the root.
  builder.AddInstruction(HloInstruction::CreateTuple({}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloDCE dce;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&dce, module.get()));
  EXPECT_TRUE(changed);
}
127 
TEST_F(HloDceTest, DeadParameters) {
  // Unused parameters are never removed, but dead instructions that use them
  // are.
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* live_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "live_param"));
  HloInstruction* dead_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "dead_param1"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar, "dead_param2"));

  // Dead: nothing consumes this negate.
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, dead_param1));
  // Live: added last, so it is the root.
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, live_param));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(5, entry->instruction_count());
  EXPECT_EQ(1, dead_param1->user_count());

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Only the dead negate is gone; all three parameters remain.
  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_EQ(0, dead_param1->user_count());
}
159 
TEST_F(HloDceTest, ControlDependencies) {
  // Dead instructions are normally deleted, but dead instructions linked by a
  // control dependency must survive DCE.
  HloComputation::Builder builder(TestName());
  HloInstruction* c42 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* c123 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
  const Shape scalar = c42->shape();

  // Plain dead instructions; DCE should delete both.
  HloInstruction* dead_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, c42));
  HloInstruction* dead_add = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, c42, c123));

  // The same two instructions again; a control edge between them is added
  // after the computation is built.
  HloInstruction* dead_negate_with_control_dep = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, c42));
  HloInstruction* dead_add_with_control_dep = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, c42, c123));

  // Added last, so it is the root and the negate/add instructions above have
  // no data users.
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, c42, c123));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK(dead_negate_with_control_dep->AddControlDependencyTo(
      dead_add_with_control_dep));

  EXPECT_EQ(7, entry->instruction_count());
  EXPECT_TRUE(HasInstruction(*entry, dead_negate));
  EXPECT_TRUE(HasInstruction(*entry, dead_add));
  EXPECT_TRUE(HasInstruction(*entry, dead_negate_with_control_dep));
  EXPECT_TRUE(HasInstruction(*entry, dead_add_with_control_dep));

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Only the two unconstrained dead instructions were removed.
  EXPECT_EQ(5, entry->instruction_count());
  EXPECT_FALSE(HasInstruction(*entry, dead_negate));
  EXPECT_FALSE(HasInstruction(*entry, dead_add));
  EXPECT_TRUE(HasInstruction(*entry, dead_negate_with_control_dep));
  EXPECT_TRUE(HasInstruction(*entry, dead_add_with_control_dep));
}
209 
210 // Tests that a dead call instruction is removed.
TEST_F(HloDceTest,DeadInstructionWithCalledComputation)211 TEST_F(HloDceTest, DeadInstructionWithCalledComputation) {
212   auto module = CreateNewVerifiedModule();
213   Shape shape = ShapeUtil::MakeShape(F32, {});
214 
215   // Called computation for the call instruction.
216   auto callee_builder = HloComputation::Builder(TestName() + "-callee");
217   {
218     auto param = callee_builder.AddInstruction(
219         HloInstruction::CreateParameter(0, shape, "param"));
220     callee_builder.AddInstruction(
221         HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
222   }
223   auto called_computation =
224       module->AddEmbeddedComputation(callee_builder.Build());
225 
226   // Entry computation with a call instruction.
227   auto builder = HloComputation::Builder(TestName());
228   auto param = builder.AddInstruction(
229       HloInstruction::CreateParameter(0, shape, "param"));
230   auto dead_call = builder.AddInstruction(
231       HloInstruction::CreateCall(shape, {param}, called_computation));
232   builder.AddInstruction(
233       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
234   auto computation = module->AddEntryComputation(builder.Build());
235 
236   EXPECT_EQ(3, computation->instruction_count());
237   EXPECT_EQ(2, param->user_count());
238   EXPECT_EQ(0, dead_call->user_count());
239   EXPECT_TRUE(HasInstruction(*computation, dead_call));
240 
241   HloDCE dce;
242   EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
243 
244   EXPECT_EQ(2, computation->instruction_count());
245   EXPECT_EQ(1, param->user_count());
246   EXPECT_FALSE(HasInstruction(*computation, dead_call));
247 }
248 
249 // Tests that a while instruction with an infeed (effectul instruction) in its
250 // body is not removed, even its user count is 0.
TEST_F(HloDceTest,CalledComputationWithSideEffect)251 TEST_F(HloDceTest, CalledComputationWithSideEffect) {
252   auto module = CreateNewVerifiedModule();
253   Shape shape = ShapeUtil::MakeShape(F32, {});
254 
255   // Condition computation of a while instruction.
256   auto cond_builder = HloComputation::Builder(TestName() + "-cond");
257   {
258     auto param = cond_builder.AddInstruction(
259         HloInstruction::CreateParameter(0, shape, "cond_param"));
260     auto constant = cond_builder.AddInstruction(
261         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
262     cond_builder.AddInstruction(
263         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
264                                       constant, ComparisonDirection::kLt));
265   }
266   auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
267 
268   // Body computation of a while instruction.
269   auto body_builder = HloComputation::Builder(TestName() + "-body");
270   {
271     auto param = body_builder.AddInstruction(
272         HloInstruction::CreateParameter(0, shape, "param"));
273     auto token = body_builder.AddInstruction(HloInstruction::CreateToken());
274     auto infeed = body_builder.AddInstruction(
275         HloInstruction::CreateInfeed(shape, token, ""));
276     auto infeed_data = body_builder.AddInstruction(
277         HloInstruction::CreateGetTupleElement(shape, infeed, 0));
278     body_builder.AddInstruction(HloInstruction::CreateBinary(
279         shape, HloOpcode::kAdd, param, infeed_data));
280   }
281   auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
282 
283   // Entry computation with a while instruction and a negate on the parameter.
284   auto builder = HloComputation::Builder(TestName());
285   auto param = builder.AddInstruction(
286       HloInstruction::CreateParameter(0, shape, "param"));
287   auto live_while = builder.AddInstruction(HloInstruction::CreateWhile(
288       shape, cond_computation, body_computation, param));
289   builder.AddInstruction(
290       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
291   auto computation = module->AddEntryComputation(builder.Build());
292 
293   // Check the while instruction is not removed even if its user count is 0.
294   EXPECT_EQ(3, computation->instruction_count());
295   EXPECT_EQ(2, param->user_count());
296   EXPECT_EQ(0, live_while->user_count());
297   EXPECT_TRUE(HasInstruction(*computation, live_while));
298 
299   HloDCE dce;
300   EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
301 
302   EXPECT_EQ(3, computation->instruction_count());
303   EXPECT_EQ(2, param->user_count());
304   EXPECT_EQ(0, live_while->user_count());
305   EXPECT_TRUE(HasInstruction(*computation, live_while));
306 }
307 
308 // Tests that a nested call instruction with a side effect is not removed.
TEST_F(HloDceTest,CalledComputationWithNestedSideEffect)309 TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
310   auto module = CreateNewVerifiedModule();
311   Shape shape = ShapeUtil::MakeShape(F32, {});
312 
313   // Nested called computation with a side effect.
314   auto nested_callee_builder =
315       HloComputation::Builder(TestName() + "-nested_callee");
316   {
317     auto param = nested_callee_builder.AddInstruction(
318         HloInstruction::CreateParameter(0, shape, "param"));
319     auto token =
320         nested_callee_builder.AddInstruction(HloInstruction::CreateToken());
321     nested_callee_builder.AddInstruction(
322         HloInstruction::CreateOutfeed(shape, param, token, ""));
323   }
324   auto nested_called_computation =
325       module->AddEmbeddedComputation(nested_callee_builder.Build());
326 
327   // Outer called computation that calls the nested computation.
328   auto callee_builder = HloComputation::Builder(TestName() + "-callee");
329   {
330     auto param = callee_builder.AddInstruction(
331         HloInstruction::CreateParameter(0, shape, "param"));
332     callee_builder.AddInstruction(HloInstruction::CreateCall(
333         ShapeUtil::MakeTokenShape(), {param}, nested_called_computation));
334   }
335   auto called_computation =
336       module->AddEmbeddedComputation(callee_builder.Build());
337 
338   // Entry computation with a call instruction.
339   auto builder = HloComputation::Builder(TestName());
340   auto param = builder.AddInstruction(
341       HloInstruction::CreateParameter(0, shape, "param"));
342   auto live_call = builder.AddInstruction(HloInstruction::CreateCall(
343       ShapeUtil::MakeTokenShape(), {param}, called_computation));
344   auto computation = module->AddEntryComputation(builder.Build());
345 
346   EXPECT_EQ(2, computation->instruction_count());
347   EXPECT_EQ(1, param->user_count());
348   EXPECT_EQ(0, live_call->user_count());
349   EXPECT_TRUE(HasInstruction(*computation, live_call));
350 
351   HloDCE dce;
352   EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
353 
354   EXPECT_EQ(2, computation->instruction_count());
355   EXPECT_EQ(1, param->user_count());
356   EXPECT_EQ(0, live_call->user_count());
357   EXPECT_TRUE(HasInstruction(*computation, live_call));
358 }
359 
TEST_F(HloDceTest, RemoveDeadSubcomputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());

  // Scalar add used as the reduction computation.
  HloComputation::Builder subcomp_builder("reduction_subcomp");
  {
    auto* param0 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/0, scalar, "param0"));
    auto* param1 =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/1, scalar, "param1"));
    subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
        scalar, HloOpcode::kAdd, param0, param1));
  }
  auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build());

  // The reduce operands are created in separate statements so their order in
  // the computation is well defined (the evaluation order of function
  // arguments is unspecified in C++).
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  // Create a dead reduce instruction. Reducing f32[100] over its only
  // dimension yields a scalar, so the result shape is f32[].
  builder.AddInstruction(HloInstruction::CreateReduce(
      scalar, input, init, /*dimensions_to_reduce=*/{0}, reduce_subcomp));

  // Add another instruction as the root of the computation.
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));

  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);

  HloDCE dce;
  EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());

  // We should have DCE'ed the reduction computation along with the reduction
  // instruction.
  EXPECT_EQ(module->MakeComputationPostOrder().size(), 1);
}
400 
TEST_F(HloDceTest, KeepUsedSubcomputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  const Shape vector100 = ShapeUtil::MakeShape(F32, {100});

  // Scalar add shared by two reduce instructions below.
  HloComputation::Builder subcomp_builder("reduction_subcomp");
  {
    HloInstruction* lhs =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/0, scalar, "param0"));
    HloInstruction* rhs =
        subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
            /*parameter_number=*/1, scalar, "param1"));
    subcomp_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, lhs, rhs));
  }
  HloComputation* reduce_subcomp =
      module->AddEmbeddedComputation(subcomp_builder.Build());

  HloComputation::Builder builder(TestName());

  // First reduce: dead, since nothing uses it.
  HloInstruction* dead_input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, vector100, "param0"));
  HloInstruction* dead_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(HloInstruction::CreateReduce(
      scalar, dead_input, dead_init, /*dimensions_to_reduce=*/{0},
      reduce_subcomp));

  // Second reduce: added last, so it is the root and keeps reduce_subcomp
  // reachable.
  HloInstruction* live_input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, vector100, "param1"));
  HloInstruction* live_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(HloInstruction::CreateReduce(
      scalar, live_input, live_init, /*dimensions_to_reduce=*/{0},
      reduce_subcomp));

  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);

  HloDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // One user of reduce_subcomp was removed, but the other still needs it, so
  // the module must keep both computations.
  EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
}
447 
TEST_F(HloDceTest,RemovedNestedDeadComputations)448 TEST_F(HloDceTest, RemovedNestedDeadComputations) {
449   auto module = CreateNewVerifiedModule();
450   Shape shape = ShapeUtil::MakeShape(F32, {});
451 
452   HloComputation::Builder called_subcomp_builder("called_dead_add");
453   {
454     auto* param0 =
455         called_subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
456             /*parameter_number=*/0, shape, "param0"));
457     auto* param1 =
458         called_subcomp_builder.AddInstruction(HloInstruction::CreateParameter(
459             /*parameter_number=*/1, shape, "param1"));
460     called_subcomp_builder.AddInstruction(HloInstruction::CreateBinary(
461         ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1));
462   }
463   auto called_subcomp =
464       module->AddEmbeddedComputation(called_subcomp_builder.Build());
465 
466   // Creates a module with unflattened control flow with two dead computations
467   // that both call the same subcomputation, which becomes dead after the two
468   // callers are removed.
469   {
470     HloComputation::Builder dead_subcomp_builder("dead_caller0");
471     auto* param0 = dead_subcomp_builder.AddInstruction(
472         HloInstruction::CreateParameter(0, shape, "param0"));
473     auto* param1 = dead_subcomp_builder.AddInstruction(
474         HloInstruction::CreateParameter(1, shape, "param1"));
475     dead_subcomp_builder.AddInstruction(
476         HloInstruction::CreateCall(shape, {param0, param1}, called_subcomp));
477     module->AddEmbeddedComputation(dead_subcomp_builder.Build());
478   }
479 
480   {
481     HloComputation::Builder dead_subcomp_builder("dead_caller1");
482     auto* param0 = dead_subcomp_builder.AddInstruction(
483         HloInstruction::CreateParameter(0, shape, "param0"));
484     auto* param1 = dead_subcomp_builder.AddInstruction(
485         HloInstruction::CreateParameter(1, shape, "param1"));
486     dead_subcomp_builder.AddInstruction(
487         HloInstruction::CreateCall(shape, {param0, param1}, called_subcomp));
488     module->AddEmbeddedComputation(dead_subcomp_builder.Build());
489   }
490 
491   HloComputation::Builder builder(TestName());
492 
493   // Adds a constant instruction as the root of the computation.
494   builder.AddInstruction(
495       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
496 
497   module->AddEntryComputation(builder.Build());
498   EXPECT_EQ(module->MakeComputationPostOrder().size(), 4);
499 
500   HloDCE dce;
501   auto changed = dce.Run(module.get());
502   ASSERT_TRUE(changed.ok());
503   EXPECT_TRUE(*changed);
504 
505   // Only the entry computation should be left after eliminating the dead caller
506   // and callee subcomputations.
507   EXPECT_EQ(module->MakeComputationPostOrder().size(), 1);
508 }
509 
510 }  // namespace
511 }  // namespace xla
512