/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

using ::testing::UnorderedElementsAre;

class CallGraphTest : public HloTestBase {
 protected:
  // Build and return a trivial computation taking and returning a scalar.
  std::unique_ptr<HloComputation> MakeScalarComputation(
      HloOpcode opcode = HloOpcode::kNegate) {
    HloComputation::Builder builder(TestName() + ".ScalarComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(
        HloInstruction::CreateUnary(kScalarShape, opcode, param0));
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and maps (kMap) the
  // given computation over the value 'callsites' times.
  std::unique_ptr<HloComputation> MakeMappingComputation(
      HloComputation* map_computation, int64_t callsites) {
    HloComputation::Builder builder(TestName() + ".MappingComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64_t i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateMap(
          kScalarShape, {last_value}, map_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and calls (kCall) the
  // given computation on the value 'callsites' times.
  std::unique_ptr<HloComputation> MakeCallingComputation(
      HloComputation* callee_computation, int64_t callsites,
      const std::string& suffix = ".CallingComputation") {
    HloComputation::Builder builder(TestName() + suffix);
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* last_value = param0;
    for (int64_t i = 0; i < callsites; ++i) {
      last_value = builder.AddInstruction(HloInstruction::CreateCall(
          kScalarShape, {last_value}, callee_computation));
    }
    return builder.Build();
  }

  // Build and return a computation which takes a scalar and returns a PRED
  // value indicating whether the scalar is greater than zero.
  std::unique_ptr<HloComputation> MakeConditionComputation() {
    HloComputation::Builder builder(TestName() + ".ConditionComputation");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* zero = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
    builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
                                      zero, ComparisonDirection::kGt));
    return builder.Build();
  }

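  // Scalar (rank-0) F32 shape used by all computations built in these tests.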
  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
};

TEST_F(CallGraphTest, SingletonComputation) {
  // Test the call graph of a module with a single computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(1, call_graph->nodes().size());
  EXPECT_TRUE(call_graph->IsFlattened());

  const CallGraphNode& node = call_graph->GetNode(computation);
  EXPECT_EQ(computation, node.computation());
  EXPECT_EQ(node.depth(), 0);
  EXPECT_TRUE(node.callsites().empty());
  EXPECT_TRUE(node.callees().empty());
  EXPECT_TRUE(node.caller_callsites().empty());
  EXPECT_TRUE(node.callers().empty());
  EXPECT_EQ(CallContext::kControlFlow, node.context());
}

TEST_F(CallGraphTest, UnreachableComputation) {
  // Test the call graph of a module with an entry computation and an
  // unreachable computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());

  const CallGraphNode& unreachable_node =
      call_graph->GetNode(unreachable_computation);
  EXPECT_EQ(unreachable_node.depth(), 0);
  EXPECT_EQ(unreachable_computation, unreachable_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, unreachable_node.context());
}

TEST_F(CallGraphTest, ParallelComputation) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in a parallel context via kMap.
  auto module = CreateNewVerifiedModule();
  HloComputation* map_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry_computation = module->AddEntryComputation(
      MakeMappingComputation(map_computation, /*callsites=*/5));

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());
  EXPECT_EQ(5, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(call_graph->GetComputationCallers(entry_computation).empty());
  EXPECT_TRUE(entry_node.callers().empty());

  const CallGraphNode& map_node = call_graph->GetNode(map_computation);
  EXPECT_EQ(map_computation, map_node.computation());
  EXPECT_EQ(map_node.depth(), 1);
  EXPECT_EQ(CallContext::kEmbedded, map_node.context());
  EXPECT_TRUE(map_node.callsites().empty());
  EXPECT_TRUE(map_node.callees().empty());
  EXPECT_EQ(5, map_node.caller_callsites().size());
  EXPECT_EQ(5, call_graph->GetComputationCallers(map_computation).size());
  EXPECT_EQ(1, map_node.callers().size());
}

TEST_F(CallGraphTest, SequentialComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in a sequential context via kCall.
  auto module = CreateNewVerifiedModule();
  HloComputation* called_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry_computation = module->AddEntryComputation(
      MakeCallingComputation(called_computation, /*callsites=*/3));

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  // The called computation is only called from one other computation, but there
  // are multiple callsites.
  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());
  EXPECT_EQ(3, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(call_graph->GetComputationCallers(entry_computation).empty());
  EXPECT_TRUE(entry_node.callers().empty());

  const CallGraphNode& called_node = call_graph->GetNode(called_computation);
  EXPECT_EQ(called_computation, called_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, called_node.context());
  EXPECT_TRUE(called_node.callsites().empty());
  EXPECT_TRUE(called_node.callees().empty());
  EXPECT_EQ(3, called_node.caller_callsites().size());
  EXPECT_EQ(3, call_graph->GetComputationCallers(called_computation).size());
  EXPECT_EQ(1, called_node.callers().size());
}

TEST_F(CallGraphTest, ContextBothComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in both a parallel and sequential context.
  auto module = CreateNewVerifiedModule();
  HloComputation* subcomputation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, kScalarShape, "param0"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(2, entry_node.callsites().size());

  const CallSite& call_callsite = entry_node.callsites()[0];
  EXPECT_EQ(call, call_callsite.instruction());
  EXPECT_THAT(call_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kControlFlow, call_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);

  const CallSite& map_callsite = entry_node.callsites()[1];
  EXPECT_EQ(map, map_callsite.instruction());
  EXPECT_THAT(map_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kEmbedded, map_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);

  const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
  EXPECT_EQ(sub_node.depth(), 1);
  EXPECT_EQ(CallContext::kBoth, sub_node.context());
}

TEST_F(CallGraphTest, ComputationWithConditional) {
  // Test a call graph of a module with a conditional.
  auto module = CreateNewVerifiedModule();
  HloComputation* true_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor));

  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          kScalarShape, pred, const1, true_computation, const2,
          false_computation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  EXPECT_EQ(3, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry_computation, entry_node.computation());
  EXPECT_EQ(1, entry_node.callsites().size());

  const CallSite& conditional_callsite = entry_node.callsites()[0];
  EXPECT_EQ(conditional, conditional_callsite.instruction());
  EXPECT_THAT(conditional_callsite.called_computations(),
              UnorderedElementsAre(true_computation, false_computation));
  EXPECT_EQ(CallContext::kControlFlow, conditional_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite);

  const CallGraphNode& true_node = call_graph->GetNode(true_computation);
  EXPECT_EQ(true_node.depth(), 1);
  EXPECT_TRUE(true_node.callees().empty());
  EXPECT_EQ(1, true_node.callers().size());
  EXPECT_EQ(entry_computation, true_node.callers()[0]);

  const CallGraphNode& false_node = call_graph->GetNode(false_computation);
  EXPECT_EQ(false_node.depth(), 1);
  EXPECT_TRUE(false_node.callees().empty());
  EXPECT_EQ(1, false_node.callers().size());
  EXPECT_EQ(entry_computation, false_node.callers()[0]);
}

TEST_F(CallGraphTest, ComplexGraph) {
  // Test a call graph of a module with several computations called in various
  // contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
  auto module = CreateNewVerifiedModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));

  HloComputation* a_computation;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

  HloComputation* entry_computation;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, call_graph->nodes().size());
  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  const CallGraphNode& a_node = call_graph->GetNode(a_computation);
  const CallGraphNode& b_node = call_graph->GetNode(b_computation);
  const CallGraphNode& c_node = call_graph->GetNode(c_computation);
  const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);

  // Verify depths.
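  // cond_computation is called both directly from entry and from 'a', so its
  // depth reflects the longer of those two call chains.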
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(a_node.depth(), 1);
  EXPECT_EQ(b_node.depth(), 2);
  EXPECT_EQ(c_node.depth(), 3);
  EXPECT_EQ(cond_node.depth(), 2);

  // Entry computation has one while instruction calling two computations
  // (cond_computation and a_computation).
  ASSERT_EQ(1, entry_node.callsites().size());
  auto called_computations = entry_node.callsites()[0].called_computations();
  EXPECT_THAT(called_computations,
              UnorderedElementsAre(cond_computation, a_computation));
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());

  EXPECT_TRUE(c_node.callsites().empty());
  EXPECT_THAT(c_node.callers(),
              UnorderedElementsAre(a_computation, b_computation));
  EXPECT_EQ(CallContext::kBoth, c_node.context());

  // Visit the graph and verify nodes were visited in callee-before-caller
  // order.
  std::vector<const HloComputation*> visited;
  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
    visited.push_back(node.computation());
    return OkStatus();
  }));
  EXPECT_EQ(visited.size(), 5);
  // All values in visited should be unique.
  EXPECT_EQ(
      absl::flat_hash_set<const HloComputation*>(visited.begin(), visited.end())
          .size(),
      5);

  // Verify visitation order of some computations in the graph.
  auto index_of = [&visited](const HloComputation* comp) {
    auto it = absl::c_find(visited, comp);
    EXPECT_NE(it, visited.end());
    return std::distance(visited.begin(), it);
  };
  EXPECT_EQ(4, index_of(entry_computation));
  EXPECT_LT(index_of(cond_computation), index_of(a_computation));
  EXPECT_LT(index_of(c_computation), index_of(b_computation));
  EXPECT_LT(index_of(b_computation), index_of(a_computation));

  // Verify dominance relations between computations in the graph.

  // Entry dominates everybody, and is dominated by no one except itself.
  EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation));
  EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation));
  EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation));
  EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation));

  // 'a' dominates only itself, 'b', and 'c'.
  EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation));
  EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation));
  EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation));
  EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation));

  EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation));

  EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation));
  EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation));
  EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation));

  EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation));
}

TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
  // Test NearestAncestorsInSameComputation on a call graph of a module with
  // several computations called in various contexts. The call graph looks like:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // Calls are made via kCall, kWhile, and kMap instructions.
  auto module = CreateNewVerifiedModule();
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(MakeConditionComputation());
  HloComputation* c_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b_computation = module->AddEmbeddedComputation(
      MakeMappingComputation(c_computation, /*callsites=*/1));
  HloInstruction* b_map = b_computation->root_instruction();

  HloComputation* a_computation;
  HloInstruction* a_call;
  HloInstruction* a_while;
  {
    HloComputation::Builder builder(TestName() + ".a");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    a_call = builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
    a_while = builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, b_computation, a_call));
    a_computation = module->AddEmbeddedComputation(builder.Build());
  }

  HloComputation* entry_computation;
  HloInstruction* entry_while;
  {
    HloComputation::Builder builder(TestName() + ".entry");
    HloInstruction* param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    entry_while = builder.AddInstruction(HloInstruction::CreateWhile(
        kScalarShape, cond_computation, a_computation, param0));
    entry_computation = module->AddEntryComputation(builder.Build());
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, call_graph->nodes().size());

  // Verify NearestAncestorsInSameComputation for various instructions in the
  // module.
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_call, a_call),
            std::make_pair(a_call, a_call));

  // c_computation is called from more than one site, so
  // NearestAncestorsInSameComputation bails and returns nullptrs.
  std::pair<HloInstruction*, HloInstruction*> null_pair = {nullptr, nullptr};
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(
                b_map, c_computation->root_instruction()),
            null_pair);

  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, entry_while),
            std::make_pair(entry_while, entry_while));
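  // b_map lives in b_computation, which is called from the kWhile in 'a', so
  // its nearest ancestor within a_computation is that while instruction.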
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, a_call),
            std::make_pair(a_while, a_call));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, a_call),
            std::make_pair(a_while, a_call));
  EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, b_map),
            std::make_pair(a_while, a_while));
}

TEST_F(CallGraphTest, VisitSingletonComputation) {
  // Test the call graph visitor on a call graph with a single node.
  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  std::vector<HloComputation*> visited;
  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
    visited.push_back(node.computation());
    return OkStatus();
  }));
  EXPECT_THAT(visited, UnorderedElementsAre(computation));
}

TEST_F(CallGraphTest, VisitUnreachableComputation) {
  // Test the call graph visitor on a call graph with an unreachable node.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  // Test visitation of only reachable nodes.
  {
    std::vector<const HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return OkStatus();
        },
        /*visit_unreachable_nodes=*/false));
    EXPECT_EQ(visited.size(), 1);
    EXPECT_EQ(visited[0], entry_computation);
  }

  // Test visitation of all nodes (reachable and unreachable).
  {
    std::vector<HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return OkStatus();
        },
        /*visit_unreachable_nodes=*/true));
    EXPECT_EQ(visited.size(), 2);
    EXPECT_THAT(visited, UnorderedElementsAre(entry_computation,
                                              unreachable_computation));
  }
}

TEST_F(CallGraphTest, VisitWithError) {
  // Test that the call graph visitor properly propagates errors.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  Status status = call_graph->VisitNodes(
      [](const CallGraphNode&) { return InternalError("Visitation failed"); });

  ASSERT_FALSE(status.ok());
  ASSERT_EQ(status.code(), tensorflow::error::INTERNAL);
  ASSERT_THAT(status.error_message(),
              ::testing::HasSubstr("Visitation failed"));
}

}  // namespace
}  // namespace xla