1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/call_graph.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29
30 namespace xla {
31 namespace {
32
33 using ::testing::UnorderedElementsAre;
34
35 class CallGraphTest : public HloTestBase {
36 protected:
37 // Build and return a trivial computation taking and returning a scalar.
MakeScalarComputation(HloOpcode opcode=HloOpcode::kNegate)38 std::unique_ptr<HloComputation> MakeScalarComputation(
39 HloOpcode opcode = HloOpcode::kNegate) {
40 HloComputation::Builder builder(TestName() + ".ScalarComputation");
41 HloInstruction* param0 = builder.AddInstruction(
42 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
43 builder.AddInstruction(
44 HloInstruction::CreateUnary(kScalarShape, opcode, param0));
45 return builder.Build();
46 }
47
48 // Build and return a computation which takes a scalar and maps (kMap) the
49 // given computation to the value 'callsites' number of times.
MakeMappingComputation(HloComputation * map_computation,int64_t callsites)50 std::unique_ptr<HloComputation> MakeMappingComputation(
51 HloComputation* map_computation, int64_t callsites) {
52 HloComputation::Builder builder(TestName() + ".MappingComputation");
53 HloInstruction* param0 = builder.AddInstruction(
54 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
55 HloInstruction* last_value = param0;
56 for (int64_t i = 0; i < callsites; ++i) {
57 last_value = builder.AddInstruction(HloInstruction::CreateMap(
58 kScalarShape, {last_value}, map_computation));
59 }
60 return builder.Build();
61 }
62
63 // Build and return a computation which takes a scalar and calls (kCall) the
64 // given computation with value 'callsites' number of times.
MakeCallingComputation(HloComputation * callee_computation,int64_t callsites,const std::string & suffix=".CallingComputation")65 std::unique_ptr<HloComputation> MakeCallingComputation(
66 HloComputation* callee_computation, int64_t callsites,
67 const std::string& suffix = ".CallingComputation") {
68 HloComputation::Builder builder(TestName() + suffix);
69 HloInstruction* param0 = builder.AddInstruction(
70 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
71 HloInstruction* last_value = param0;
72 for (int64_t i = 0; i < callsites; ++i) {
73 last_value = builder.AddInstruction(HloInstruction::CreateCall(
74 kScalarShape, {last_value}, callee_computation));
75 }
76 return builder.Build();
77 }
78
79 // Build and return a computation which takes a scalar and returns a PRED
80 // value.
MakeConditionComputation()81 std::unique_ptr<HloComputation> MakeConditionComputation() {
82 HloComputation::Builder builder(TestName() + ".ConditionComputation");
83 HloInstruction* param0 = builder.AddInstruction(
84 HloInstruction::CreateParameter(0, kScalarShape, "param0"));
85 HloInstruction* zero = builder.AddInstruction(
86 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
87 builder.AddInstruction(
88 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
89 zero, ComparisonDirection::kGt));
90 return builder.Build();
91 }
92
93 const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
94 };
95
TEST_F(CallGraphTest, SingletonComputation) {
  // A module whose only computation is the entry computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(MakeScalarComputation());

  auto graph = CallGraph::Build(module.get());
  EXPECT_EQ(1, graph->nodes().size());
  EXPECT_TRUE(graph->IsFlattened());

  // The lone node has no edges in either direction.
  const CallGraphNode& entry_node = graph->GetNode(entry);
  EXPECT_EQ(entry, entry_node.computation());
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_TRUE(entry_node.callsites().empty());
  EXPECT_TRUE(entry_node.callees().empty());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(entry_node.callers().empty());
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());
}
114
TEST_F(CallGraphTest, UnreachableComputation) {
  // A module with an entry computation plus an embedded computation that
  // nothing calls.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(MakeScalarComputation());
  HloComputation* orphan =
      module->AddEmbeddedComputation(MakeScalarComputation());

  auto graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  const CallGraphNode& entry_node = graph->GetNode(entry);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry, entry_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());

  // The uncalled computation is expected to look just like an entry node:
  // depth zero and control-flow context.
  const CallGraphNode& orphan_node = graph->GetNode(orphan);
  EXPECT_EQ(orphan_node.depth(), 0);
  EXPECT_EQ(orphan, orphan_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, orphan_node.context());
}
138
TEST_F(CallGraphTest, ParallelComputation) {
  // The entry computation invokes a second computation five times through
  // kMap, i.e. in a parallel context.
  auto module = CreateNewVerifiedModule();
  HloComputation* mapped =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry = module->AddEntryComputation(
      MakeMappingComputation(mapped, /*callsites=*/5));

  auto graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Caller side: five callsites, all targeting the same single callee.
  const CallGraphNode& entry_node = graph->GetNode(entry);
  EXPECT_EQ(entry, entry_node.computation());
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());
  EXPECT_EQ(5, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(graph->GetComputationCallers(entry).empty());
  EXPECT_TRUE(entry_node.callers().empty());

  // Callee side: five incoming callsites from one distinct caller, and an
  // embedded context because it is only reached through kMap.
  const CallGraphNode& mapped_node = graph->GetNode(mapped);
  EXPECT_EQ(mapped, mapped_node.computation());
  EXPECT_EQ(mapped_node.depth(), 1);
  EXPECT_EQ(CallContext::kEmbedded, mapped_node.context());
  EXPECT_TRUE(mapped_node.callsites().empty());
  EXPECT_TRUE(mapped_node.callees().empty());
  EXPECT_EQ(5, mapped_node.caller_callsites().size());
  EXPECT_EQ(5, graph->GetComputationCallers(mapped).size());
  EXPECT_EQ(1, mapped_node.callers().size());
}
171
TEST_F(CallGraphTest, SequentialComputations) {
  // The entry computation invokes a second computation three times through
  // kCall, i.e. in a sequential context.
  auto module = CreateNewVerifiedModule();
  HloComputation* callee =
      module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* entry = module->AddEntryComputation(
      MakeCallingComputation(callee, /*callsites=*/3));

  auto graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, graph->nodes().size());

  // Only one computation calls the callee, but it does so from several
  // callsites, so the graph is not considered flattened.
  EXPECT_FALSE(graph->IsFlattened());

  // Caller side.
  const CallGraphNode& entry_node = graph->GetNode(entry);
  EXPECT_EQ(entry, entry_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());
  EXPECT_EQ(3, entry_node.callsites().size());
  EXPECT_EQ(1, entry_node.callees().size());
  EXPECT_TRUE(entry_node.caller_callsites().empty());
  EXPECT_TRUE(graph->GetComputationCallers(entry).empty());
  EXPECT_TRUE(entry_node.callers().empty());

  // Callee side.
  const CallGraphNode& callee_node = graph->GetNode(callee);
  EXPECT_EQ(callee, callee_node.computation());
  EXPECT_EQ(CallContext::kControlFlow, callee_node.context());
  EXPECT_TRUE(callee_node.callsites().empty());
  EXPECT_TRUE(callee_node.callees().empty());
  EXPECT_EQ(3, callee_node.caller_callsites().size());
  EXPECT_EQ(3, graph->GetComputationCallers(callee).size());
  EXPECT_EQ(1, callee_node.callers().size());
}
206
TEST_F(CallGraphTest, ContextBothComputations) {
  // Test a call graph of a module with an entry computation which calls another
  // computation in both a parallel (kMap) and sequential (kCall) context.
  auto module = CreateNewVerifiedModule();
  HloComputation* subcomputation =
      module->AddEmbeddedComputation(MakeScalarComputation());

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, kScalarShape, "param0"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  EXPECT_EQ(2, call_graph->nodes().size());

  EXPECT_FALSE(call_graph->IsFlattened());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_computation, entry_node.computation());
  // Fatal assertion: callsites()[0] and callsites()[1] are indexed below, so
  // the test must stop here rather than read out of bounds if the size is
  // wrong.
  ASSERT_EQ(2, entry_node.callsites().size());

  // The kCall callsite gives the subcomputation a control-flow context.
  const CallSite& call_callsite = entry_node.callsites()[0];
  EXPECT_EQ(call, call_callsite.instruction());
  EXPECT_THAT(call_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kControlFlow, call_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);

  // The kMap callsite gives the same subcomputation an embedded context.
  const CallSite& map_callsite = entry_node.callsites()[1];
  EXPECT_EQ(map, map_callsite.instruction());
  EXPECT_THAT(map_callsite.called_computations(),
              UnorderedElementsAre(subcomputation));
  EXPECT_EQ(CallContext::kEmbedded, map_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);

  // Being called in both contexts, the subcomputation's node reports kBoth.
  const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
  EXPECT_EQ(sub_node.depth(), 1);
  EXPECT_EQ(CallContext::kBoth, sub_node.context());
}
251
TEST_F(CallGraphTest, ComputationWithConditional) {
  // Test a call graph of a module with a conditional whose two branches are
  // separate embedded computations.
  auto module = CreateNewVerifiedModule();
  HloComputation* true_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor));

  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          kScalarShape, pred, const1, true_computation, const2,
          false_computation));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  EXPECT_EQ(3, call_graph->nodes().size());

  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(entry_computation, entry_node.computation());
  // Fatal assertion: callsites()[0] is indexed immediately below.
  ASSERT_EQ(1, entry_node.callsites().size());

  // The single conditional callsite references both branch computations.
  const CallSite& conditional_callsite = entry_node.callsites()[0];
  EXPECT_EQ(conditional, conditional_callsite.instruction());
  EXPECT_THAT(conditional_callsite.called_computations(),
              UnorderedElementsAre(true_computation, false_computation));
  EXPECT_EQ(CallContext::kControlFlow, conditional_callsite.context());
  EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite);

  const CallGraphNode& true_node = call_graph->GetNode(true_computation);
  EXPECT_EQ(true_node.depth(), 1);
  EXPECT_TRUE(true_node.callees().empty());
  // Fatal assertion: callers()[0] is indexed on the next line.
  ASSERT_EQ(1, true_node.callers().size());
  EXPECT_EQ(entry_computation, true_node.callers()[0]);

  const CallGraphNode& false_node = call_graph->GetNode(false_computation);
  EXPECT_EQ(false_node.depth(), 1);
  EXPECT_TRUE(false_node.callees().empty());
  // Fatal assertion: callers()[0] is indexed on the next line.
  ASSERT_EQ(1, false_node.callers().size());
  EXPECT_EQ(entry_computation, false_node.callers()[0]);
}
302
TEST_F(CallGraphTest, ComplexGraph) {
  // Exercises a module in which several computations are called in various
  // contexts. The call graph has this shape:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // The edges come from kCall, kWhile, and kMap instructions.
  auto module = CreateNewVerifiedModule();
  HloComputation* cond = module->AddEmbeddedComputation(
      MakeConditionComputation());
  HloComputation* c = module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b = module->AddEmbeddedComputation(
      MakeMappingComputation(c, /*callsites=*/1));

  // 'a' calls 'c' directly and then loops with 'cond' as the while condition
  // and 'b' as the while body.
  HloComputation* a;
  {
    HloComputation::Builder a_builder(TestName() + ".a");
    HloInstruction* a_param = a_builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    HloInstruction* call_c = a_builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {a_param}, c));
    a_builder.AddInstruction(
        HloInstruction::CreateWhile(kScalarShape, cond, b, call_c));
    a = module->AddEmbeddedComputation(a_builder.Build());
  }

  // The entry loops with 'cond' as the while condition and 'a' as the body.
  HloComputation* entry;
  {
    HloComputation::Builder entry_builder(TestName() + ".entry");
    HloInstruction* entry_param = entry_builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    entry_builder.AddInstruction(
        HloInstruction::CreateWhile(kScalarShape, cond, a, entry_param));
    entry = module->AddEntryComputation(entry_builder.Build());
  }

  auto graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, graph->nodes().size());
  EXPECT_FALSE(graph->IsFlattened());

  const CallGraphNode& entry_node = graph->GetNode(entry);
  const CallGraphNode& a_node = graph->GetNode(a);
  const CallGraphNode& b_node = graph->GetNode(b);
  const CallGraphNode& c_node = graph->GetNode(c);
  const CallGraphNode& cond_node = graph->GetNode(cond);

  // Depths.
  EXPECT_EQ(entry_node.depth(), 0);
  EXPECT_EQ(a_node.depth(), 1);
  EXPECT_EQ(b_node.depth(), 2);
  EXPECT_EQ(c_node.depth(), 3);
  EXPECT_EQ(cond_node.depth(), 2);

  // The entry's single while instruction references two computations: 'cond'
  // and 'a'.
  ASSERT_EQ(1, entry_node.callsites().size());
  auto entry_callees = entry_node.callsites()[0].called_computations();
  EXPECT_THAT(entry_callees, UnorderedElementsAre(cond, a));
  EXPECT_EQ(CallContext::kControlFlow, entry_node.context());

  // 'c' is a leaf called from both 'a' (kCall) and 'b' (kMap).
  EXPECT_TRUE(c_node.callsites().empty());
  EXPECT_THAT(c_node.callers(), UnorderedElementsAre(a, b));
  EXPECT_EQ(CallContext::kBoth, c_node.context());

  // Walk the graph, recording the order in which computations are visited, to
  // check that callees come before their callers.
  std::vector<const HloComputation*> visit_order;
  auto record_visit = [&visit_order](const CallGraphNode& node) {
    visit_order.push_back(node.computation());
    return OkStatus();
  };
  TF_ASSERT_OK(graph->VisitNodes(record_visit));
  EXPECT_EQ(visit_order.size(), 5);
  // Every computation must have been visited exactly once.
  absl::flat_hash_set<const HloComputation*> distinct_visits(
      visit_order.begin(), visit_order.end());
  EXPECT_EQ(distinct_visits.size(), 5);

  // Position of a computation within the recorded visitation order.
  auto position_of = [&visit_order](const HloComputation* computation) {
    auto found = absl::c_find(visit_order, computation);
    EXPECT_NE(found, visit_order.end());
    return std::distance(visit_order.begin(), found);
  };
  EXPECT_EQ(4, position_of(entry));
  EXPECT_LT(position_of(cond), position_of(a));
  EXPECT_LT(position_of(c), position_of(b));
  EXPECT_LT(position_of(b), position_of(a));

  // Dominance relations.

  // The entry dominates every computation, and nothing but the entry itself
  // dominates the entry.
  EXPECT_TRUE(graph->Dominates(entry, entry));
  for (HloComputation* other : {a, b, c, cond}) {
    EXPECT_TRUE(graph->Dominates(entry, other)) << other->name();
    EXPECT_FALSE(graph->Dominates(other, entry)) << other->name();
  }

  // Besides itself, 'a' dominates only 'b' and 'c'.
  EXPECT_TRUE(graph->Dominates(a, a));
  EXPECT_TRUE(graph->Dominates(a, b));
  EXPECT_TRUE(graph->Dominates(a, c));
  EXPECT_FALSE(graph->Dominates(b, a));
  EXPECT_FALSE(graph->Dominates(c, a));
  EXPECT_FALSE(graph->Dominates(a, cond));

  // 'b' dominates only itself.
  EXPECT_TRUE(graph->Dominates(b, b));
  EXPECT_FALSE(graph->Dominates(b, c));
  EXPECT_FALSE(graph->Dominates(b, cond));

  // 'c' and 'cond' dominate only themselves.
  EXPECT_TRUE(graph->Dominates(c, c));
  EXPECT_FALSE(graph->Dominates(c, cond));
  EXPECT_FALSE(graph->Dominates(cond, c));

  EXPECT_TRUE(graph->Dominates(cond, cond));
}
432
TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
  // Exercises NearestAncestorsInSameComputation on a module in which several
  // computations are called in various contexts. The call graph has this
  // shape:
  //
  //      entry
  //      /  |
  //     a   |
  //   / | \ |
  //  b  |  cond
  //   \ |
  //    c
  //
  // The edges come from kCall, kWhile, and kMap instructions.
  using InstructionPair = std::pair<HloInstruction*, HloInstruction*>;

  auto module = CreateNewVerifiedModule();
  HloComputation* cond = module->AddEmbeddedComputation(
      MakeConditionComputation());
  HloComputation* c = module->AddEmbeddedComputation(MakeScalarComputation());
  HloComputation* b = module->AddEmbeddedComputation(
      MakeMappingComputation(c, /*callsites=*/1));
  HloInstruction* b_map = b->root_instruction();

  // 'a' calls 'c' directly and then loops with 'cond' as the while condition
  // and 'b' as the while body.
  HloComputation* a;
  HloInstruction* a_call;
  HloInstruction* a_while;
  {
    HloComputation::Builder a_builder(TestName() + ".a");
    HloInstruction* a_param = a_builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    a_call = a_builder.AddInstruction(
        HloInstruction::CreateCall(kScalarShape, {a_param}, c));
    a_while = a_builder.AddInstruction(
        HloInstruction::CreateWhile(kScalarShape, cond, b, a_call));
    a = module->AddEmbeddedComputation(a_builder.Build());
  }

  // The entry loops with 'cond' as the while condition and 'a' as the body.
  HloInstruction* entry_while;
  {
    HloComputation::Builder entry_builder(TestName() + ".entry");
    HloInstruction* entry_param = entry_builder.AddInstruction(
        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
    entry_while = entry_builder.AddInstruction(
        HloInstruction::CreateWhile(kScalarShape, cond, a, entry_param));
    module->AddEntryComputation(entry_builder.Build());
  }

  auto graph = CallGraph::Build(module.get());
  EXPECT_EQ(5, graph->nodes().size());

  // An instruction paired with itself is its own nearest ancestor.
  EXPECT_EQ(graph->NearestAncestorsInSameComputation(a_call, a_call),
            InstructionPair(a_call, a_call));

  // 'c' has more than one callsite, so the query gives up and yields a pair of
  // nullptrs.
  const InstructionPair no_ancestors = {nullptr, nullptr};
  EXPECT_EQ(
      graph->NearestAncestorsInSameComputation(b_map, c->root_instruction()),
      no_ancestors);

  // Remaining combinations of instructions living in 'entry', 'a' and 'b'.
  EXPECT_EQ(graph->NearestAncestorsInSameComputation(b_map, entry_while),
            InstructionPair(entry_while, entry_while));
  EXPECT_EQ(graph->NearestAncestorsInSameComputation(b_map, a_call),
            InstructionPair(a_while, a_call));
  EXPECT_EQ(graph->NearestAncestorsInSameComputation(a_while, a_call),
            InstructionPair(a_while, a_call));
  EXPECT_EQ(graph->NearestAncestorsInSameComputation(a_while, b_map),
            InstructionPair(a_while, a_while));
}
504
TEST_F(CallGraphTest, VisitSingletonComputation) {
  // The visitor must see exactly the one node of a single-computation graph.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(MakeScalarComputation());
  auto graph = CallGraph::Build(module.get());

  std::vector<HloComputation*> seen;
  auto record_visit = [&seen](const CallGraphNode& node) {
    seen.push_back(node.computation());
    return OkStatus();
  };
  TF_ASSERT_OK(graph->VisitNodes(record_visit));
  EXPECT_THAT(seen, UnorderedElementsAre(entry));
}
519
TEST_F(CallGraphTest, VisitUnreachableComputation) {
  // Test the call graph visitor with a call graph with an unreachable node.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(MakeScalarComputation());
  HloComputation* unreachable_computation =
      module->AddEmbeddedComputation(MakeScalarComputation());
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());

  // Test visitation of only reachable nodes.
  {
    std::vector<const HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return OkStatus();
        },
        /*visit_unreachable_nodes=*/false));
    // Fatal assertion: visited[0] is read on the next line, so stop here
    // instead of indexing an empty vector if nothing was visited.
    ASSERT_EQ(visited.size(), 1);
    EXPECT_EQ(visited[0], entry_computation);
  }

  // Test visitation of all nodes (reachable and unreachable).
  {
    std::vector<HloComputation*> visited;
    TF_ASSERT_OK(call_graph->VisitNodes(
        [&visited](const CallGraphNode& node) {
          visited.push_back(node.computation());
          return OkStatus();
        },
        /*visit_unreachable_nodes=*/true));
    EXPECT_EQ(visited.size(), 2);
    EXPECT_THAT(visited, UnorderedElementsAre(entry_computation,
                                              unreachable_computation));
  }
}
556
TEST_F(CallGraphTest, VisitWithError) {
  // An error returned by the visitor callback must surface from VisitNodes.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(MakeScalarComputation());
  auto graph = CallGraph::Build(module.get());

  // Visitor that fails unconditionally.
  auto failing_visitor = [](const CallGraphNode&) {
    return InternalError("Visitation failed");
  };
  const Status visit_status = graph->VisitNodes(failing_visitor);

  ASSERT_FALSE(visit_status.ok());
  ASSERT_EQ(visit_status.code(), tensorflow::error::INTERNAL);
  ASSERT_THAT(visit_status.error_message(),
              ::testing::HasSubstr("Visitation failed"));
}
571
572 } // namespace
573 } // namespace xla
574