1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/sendrecv_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/algorithm.h"
32 #include "tensorflow/core/graph/graph_def_builder.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace tensorflow {
37 namespace {
38
HasInputsWithMismatchingDeadness(const DeadnessAnalysis & deadness_analysis,const Node & n)39 se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
40 const DeadnessAnalysis& deadness_analysis, const Node& n) {
41 std::optional<DeadnessAnalysis::DeadnessPredicate> pred;
42 for (const Edge* edge : n.in_edges()) {
43 TF_ASSIGN_OR_RETURN(
44 DeadnessAnalysis::DeadnessPredicate this_pred,
45 deadness_analysis.GetPredicateFor(edge->src(), edge->src_output()));
46 if (pred && *pred != this_pred) {
47 return true;
48 }
49 pred = this_pred;
50 }
51
52 return false;
53 }
54
55 using deadness_analysis_internal::ComputePredicates;
56 using deadness_analysis_internal::PredicateMapTy;
57
AnalyzeDeadness(Graph * graph,std::unique_ptr<DeadnessAnalysis> * result)58 Status AnalyzeDeadness(Graph* graph,
59 std::unique_ptr<DeadnessAnalysis>* result) {
60 FixupSourceAndSinkEdges(graph);
61 return DeadnessAnalysis::Run(*graph, result);
62 }
63
CreateSwitch(const Scope & root,const string & prefix)64 ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
65 Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
66 Output predicate =
67 ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
68 return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
69 }
70
ControlOutputFor(const Output & o)71 TensorId ControlOutputFor(const Output& o) {
72 return {o.node()->name(), Graph::kControlSlot};
73 }
74
VLogGraphIfAsked(const Graph & graph)75 void VLogGraphIfAsked(const Graph& graph) {
76 if (VLOG_IS_ON(3)) {
77 GraphDef graph_def;
78 graph.ToGraphDef(&graph_def);
79 string serialized;
80 ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
81 LOG(INFO) << serialized;
82 }
83 }
84
// Handles to the interesting outputs of a while loop built by
// CreateInductionVariable below.
struct InductionVarInfo {
  // Output of the loop's Merge node: the induction variable's current value.
  Output induction_var;
  // Output of the loop condition (the Less node feeding the Switch).
  Output loop_cond;
};
89
90 // Creates an induction variable with the following structure (simplified for
91 // brevity):
92 //
93 // +---------------+
94 // | initial_value |
95 // +---------------+
96 // |
97 // |
98 // v
99 // +---------------+
100 // | Enter |
101 // +---------------+
102 // |
103 // |
104 // v
105 // +---------------+
106 // +> | Merge | -+
107 // | +---------------+ |
108 // | | |
109 // | | |
110 // | v |
111 // | +---------------+ |
112 // | | LessThan10 | |
113 // | +---------------+ |
114 // | | |
115 // | | |
116 // | v |
117 // | +---------------+ |
118 // +----+- | Switch | <+
119 // | | +---------------+
120 // | | |
121 // | | |
122 // | | v
123 // | | +---------------+
124 // | +- | AddOne |
125 // | +---------------+
126 // | +---------------+
127 // +-----> | Exit |
128 // +---------------+
// Builds the loop pictured above: Enter -> Merge -> Less -> Switch, with the
// Switch's true output incremented and fed back through NextIteration, and
// the false output routed to Exit.
InductionVarInfo CreateInductionVariable(const Scope& root,
                                         const string& prefix,
                                         const string& frame_name,
                                         const Output& initial_value) {
  Output enter_initial_value = ops::internal::Enter(
      root.WithOpName(prefix + "/enter"), initial_value, frame_name);

  // The second Merge input is a placeholder; it is rewired to the
  // NextIteration backedge via UpdateEdge below.
  ops::Merge iv(root.WithOpName(prefix + "/iv"),
                {enter_initial_value, enter_initial_value});
  Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
  Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
  Output loop_cond_expr =
      ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value);
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output,
                    loop_cond_expr);
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
                            latch.output_true, increment_by);
  Output next_iteration =
      ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next);

  // Close the loop: replace Merge input #1 with the backedge.
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  // Anchor the constants inside the loop frame by making them control
  // dependent on the Merge node.
  root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
  root.graph()->AddControlEdge(iv.output.node(), final_value.node());

  return {iv.output, loop_cond_expr};
}
159
CreateInductionVariable(const Scope & root,const string & prefix,const string & frame_name,int32_t init)160 InductionVarInfo CreateInductionVariable(const Scope& root,
161 const string& prefix,
162 const string& frame_name,
163 int32_t init) {
164 return CreateInductionVariable(
165 root, prefix, frame_name,
166 ops::Const(root.WithOpName(prefix + "/init"), init));
167 }
168
169 // Creates an induction variable with the following structure:
170 //
171 // +---------------+
172 // | initial_value |
173 // +---------------+
174 // |
175 // |
176 // v
177 // +---------------+
178 // | Enter |
179 // +---------------+
180 // |
181 // |
182 // v
183 // +---------------+
184 // | Merge | <+
185 // +---------------+ |
186 // | |
187 // | |
188 // v |
189 // +-----------+ +---------------+ |
190 // | loop_cond | --> | Switch | -+
191 // +-----------+ +---------------+
192 // |
193 // |
194 // v
195 // +---------------+
196 // | Exit |
197 // +---------------+
// Handles to the loop-invariant value loop built by
// CreateDependentLoopInvariantValue below.
struct DependentInductionVar {
  // Output of the loop's Merge node.
  Output induction_var;
  // The loop's Switch node, kept so tests can rewire its edges.
  ops::Switch latch;
};
202
// Builds the loop pictured above: a value that circulates unchanged through a
// frame whose liveness is controlled by the externally supplied `loop_cond`.
DependentInductionVar CreateDependentLoopInvariantValue(
    const Scope& root, const string& prefix, const string& frame_name,
    const Output& loop_cond, const Output& value) {
  Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
                                            value, frame_name);
  // Second Merge input is a placeholder, rewired to the backedge below.
  ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output next_iteration = ops::NextIteration(
      root.WithOpName(prefix + "/next_iteration"), latch.output_true);
  // Close the loop: replace Merge input #1 with the backedge.
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  return {iv.output, latch};
}
219
CreateDependentLoopInvariantValue(const Scope & root,const string & prefix,const string & frame_name,const Output & loop_cond,int32_t value)220 DependentInductionVar CreateDependentLoopInvariantValue(
221 const Scope& root, const string& prefix, const string& frame_name,
222 const Output& loop_cond, int32_t value) {
223 return CreateDependentLoopInvariantValue(
224 root, prefix, frame_name, loop_cond,
225 ops::Const(root.WithOpName(prefix + "/init"), value));
226 }
227
// The two outputs of a single Switch are never simultaneously live, so an Add
// consuming both must have inputs with mismatching deadness.
TEST(DeadnessAnalysisTest, BasicPositive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw = CreateSwitch(root, "0");
  Output add =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_TRUE(mismatch);
}
243
// Plain placeholders are always live, so an Add of two placeholders never has
// mismatching input deadness.
TEST(DeadnessAnalysisTest, BasicNegative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output lhs = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output rhs = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_FALSE(mismatch);
}
259
// The "and" of deadness predicates (induced by an op consuming several Switch
// outputs) must be commutative: A & B == B & A.
TEST(DeadnessAnalysisTest, AndIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  // a0 and a1 compute ~A & ~B in the two argument orders.
  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 =
      ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);

  // b0 and b1 compute ~A & B in the two argument orders.
  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
  Output b1 =
      ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);

  // Same predicate on both inputs => no mismatch expected.
  Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
  Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);

  // Different predicates on the inputs => mismatch expected.
  Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
  Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  bool has_inputs_with_mismatching_deadness;

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live0.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live1.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
307
// The "and" of deadness predicates must be associative:
// (A & B) & C == A & (B & C).
TEST(DeadnessAnalysisTest, AndIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left-associated: (sw_0 & sw_1) & sw_2.
  Output left_inner =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output left =
      ops::Add(root.WithOpName("a1"), left_inner, sw_2.output_false);

  // Right-associated: sw_0 & (sw_1 & sw_2).
  Output right_inner =
      ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
  Output right =
      ops::Add(root.WithOpName("b1"), sw_0.output_false, right_inner);

  Output add = ops::Add(root.WithOpName("add"), left, right);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
333
// The "or" of deadness predicates (induced by Merge) must be commutative:
// A | B == B | A.
TEST(DeadnessAnalysisTest, OrIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  // m0/m1 compute ~A | ~B in both orders; m2/m3 compute ~A | B in both orders.
  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
  ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});

  // Same predicate on both inputs => no mismatch expected.
  Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
  Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);

  // Different predicates on the inputs => mismatch expected.
  Output halfdead0 =
      ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
  Output halfdead1 =
      ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  bool has_inputs_with_mismatching_deadness;

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live0.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live1.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
378
// The "or" of deadness predicates must be associative:
// (A | B) | C == A | (B | C).
TEST(DeadnessAnalysisTest, OrIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left-associated merge chain: (sw_0 | sw_1) | sw_2.
  ops::Merge left_inner(root.WithOpName("m0"),
                        {sw_0.output_false, sw_1.output_false});
  ops::Merge left(root.WithOpName("m1"),
                  {left_inner.output, sw_2.output_false});

  // Right-associated merge chain: sw_0 | (sw_1 | sw_2).
  ops::Merge right_inner(root.WithOpName("m2"),
                         {sw_1.output_false, sw_2.output_false});
  ops::Merge right(root.WithOpName("m3"),
                   {sw_0.output_false, right_inner.output});

  Output add = ops::Add(root.WithOpName("add"), left.output, right.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
401
// Two structurally identical computations of (A|B) & (C|D) must get the same
// predicate, so combining them produces no deadness mismatch.
TEST(DeadnessAnalysisTest, AndOfOr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  ops::Merge or_01(root.WithOpName("m0"),
                   {sw_0.output_false, sw_1.output_false});
  ops::Merge or_23(root.WithOpName("m1"),
                   {sw_2.output_false, sw_3.output_false});

  // Two copies of the same "and of ors".
  Output and_copy_a =
      ops::Add(root.WithOpName("add0"), or_01.output, or_23.output);
  Output and_copy_b =
      ops::Add(root.WithOpName("add1"), or_01.output, or_23.output);

  Output combined = ops::Add(root.WithOpName("add2"), and_copy_a, and_copy_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *combined.node()));
  EXPECT_FALSE(mismatch);
}
426
// Two structurally identical computations of (A&B) | (C&D) must get the same
// predicate, so combining them produces no deadness mismatch.
TEST(DeadnessAnalysisTest, OrOfAnd) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  Output and_01 =
      ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
  Output and_23 =
      ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);

  // Two copies of the same "or of ands".
  ops::Merge or_copy_a(root.WithOpName("m0"), {and_01, and_23});
  ops::Merge or_copy_b(root.WithOpName("m1"), {and_01, and_23});

  Output combined =
      ops::Add(root.WithOpName("add2"), or_copy_a.output, or_copy_b.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *combined.node()));
  EXPECT_FALSE(mismatch);
}
453
// Checks that the predicate simplifier can reduce a distributed and/or
// expression all the way down to the constant #true.
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
  // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "A");
  ops::Switch sw_1 = CreateSwitch(root, "B");
  // and0 = ~A & B, and1 = ~A & ~B.
  Output add0 =
      ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
  Output add1 =
      ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
  // or2 = and0 | and1; and3 = or2 & ~A; or4 = and3 | A, which simplifies to
  // #true.
  ops::Merge or2(root.WithOpName("or2"), {add0, add1});
  Output add3 =
      ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
  ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true});

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
  EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
}
476
// Checks the distributive law: (A|B)&C must get the same predicate as
// (A&C)|(B&C).
TEST(DeadnessAnalysisTest, AndOrDistributive) {
  // (A|B)&C == (A&C)|(B&C)
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left-hand side: (A | B) & C.
  ops::Merge a_or_b(root.WithOpName("m0"),
                    {sw_0.output_false, sw_1.output_false});
  Output lhs =
      ops::Add(root.WithOpName("add0"), a_or_b.output, sw_2.output_false);

  // Right-hand side: (A & C) | (B & C).
  Output a_and_c =
      ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
  Output b_and_c =
      ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
  ops::Merge rhs(root.WithOpName("m1"), {a_and_c, b_and_c});

  Output both = ops::Add(root.WithOpName("add3"), lhs, rhs.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *both.node()));
  EXPECT_FALSE(mismatch);
}
504
// A Switch/Switch/Merge "ternary" (select) is always live regardless of the
// predicate's value, so adding an unconditioned placeholder to its output
// must not produce a deadness mismatch.
TEST(DeadnessAnalysisTest, Ternary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
  Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
  Output false_value =
      ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);

  ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
                              predicate);

  // Fixed: route `false_value` through the false-branch Switch.  Previously
  // this Switch was (evidently by mistake) fed `true_value`, leaving
  // `false_value` unused; deadness only depends on `predicate`, so the
  // expectation below is unchanged.
  ops::Switch predicated_false(root.WithOpName("predicated_false"), false_value,
                               predicate);
  ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
                                                predicated_false.output_false});
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), merge.output, addend);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
531
// Each _Recv gets its own deadness symbol, so an Add consuming two distinct
// receives conservatively has mismatching input deadness.
TEST(DeadnessAnalysisTest, Recv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
                             "sender", 0, "receiver");
  Output sum = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_TRUE(mismatch);
}
549
// Same as Recv above, but via _HostRecv: distinct host receives also get
// distinct deadness symbols.
TEST(DeadnessAnalysisTest, HostRecv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
                                 "tensor_a", "sender", 0, "receiver");
  Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
                                 "tensor_b", "sender", 0, "receiver");
  Output sum = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_TRUE(mismatch);
}
567
// Three independent loops in the same frame: each induction variable gets its
// own uninterpreted loop-condition symbol, so values from different loops
// always look mismatched.
TEST(DeadnessAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var;
  Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var;
  Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
  Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);

  // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
  // noticed that.  Today we are pessimistic here because we assign an
  // uninterpreted symbol to merges with backedges.

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    bool has_inputs_with_mismatching_deadness;

    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);

    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add1.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
  {
    // Also check the textual form of the computed predicates.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0
    // produce the same deadness. But we're not that smart today.
    EXPECT_EQ(predicate_map[ControlOutputFor(iv0)],
              "{#true,&,*iv0/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv1)],
              "{#true,&,*iv1/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv2)],
              "{#true,&,*iv2/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({#true,&,*iv0/cond:0}<fr0> & {#true,&,*iv1/cond:0}<fr0>)");
    EXPECT_EQ(predicate_map[ControlOutputFor(add1)],
              "({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv2/cond:0}<fr0>)");
  }
}
616
// Two loop-invariant values driven by the same loop condition are control
// equivalent; the optimistic analysis proves it, the pessimistic one falls
// back to per-merge symbols (but the symbols still agree).
TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  Output dependent_iv0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0)
          .induction_var;
  Output dependent_iv1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0)
          .induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    // Optimistic mode: all three values share the induction variable's exact
    // predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
  }
  {
    // Pessimistic mode: the dependent values pick up an extra uninterpreted
    // iv0/iv:0 symbol, but all three still agree with each other.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
  }
}
668
// Rewires a "loop" so its backedge does not depend on its own merge; the
// optimistic analysis still finds the invariant predicate, the pessimistic
// one degrades to the bare merge symbol.
TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
  // Create a merge that "looks like" a loop but isn't really.  It has a value
  // that does not depend on the merge on its backedge.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
  DependentInductionVar dependent_iv =
      CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());

  // Feed iv0's induction variable into div0's latch, replacing div0's own
  // merge output, so div0's backedge no longer depends on div0/iv.
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "{#true,&,*iv0/cond:0}<frame>");
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "div0/iv:0");
  }
}
700
// Nested loops: loop-invariant values threaded through both the outer and
// inner loop frames remain control equivalent to each other, in both the
// optimistic and pessimistic analyses.
TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  // A loop-invariant constant entering the outer frame, gated on the outer
  // loop being live; its true output seeds the inner loop.
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  // Two invariant values in the outer loop...
  Output dependent_outer_iv0 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv0",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;
  Output dependent_outer_iv1 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv1",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;

  // ...each threaded through the inner loop as well.
  Output dependent_inner_iv0 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv0", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv0)
                                   .induction_var;
  Output dependent_inner_iv1 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv1", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv1)
                                   .induction_var;

  Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0,
                         dependent_inner_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    // enable_optimistic = true or not should produce the same results because
    // of fallback. However, note that the order of iv_inner/cond:0 and
    // iv_inner/iv:0 is different because the optimistic approach does not
    // create predicates for all merges and it can change the predicate id and
    // hence the symbol order.
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
              "iv_inner/iv:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
}
794
TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();

  // Induction variables for two independent, structurally identical copies of
  // the same nested-loop shape; filled in by the loop below.
  std::array<Output, 2> outer_iv;
  std::array<Output, 2> inner_iv;

  for (int i : {0, 1}) {
    InductionVarInfo iv_outer =
        CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
    // Loop-invariant constant carried into the outer frame, then gated on the
    // outer loop's liveness to seed the inner loop's initial value.
    Output enter_constant_outer_loop = ops::internal::Enter(
        root.WithOpName("constant_enter_outer_loop"),
        ops::Const(root.WithOpName("constant"), 5), "outer_loop",
        ops::internal::Enter::Attrs().IsConstant(true));
    ops::Switch inner_value(root.WithOpName("outer_is_live"),
                            enter_constant_outer_loop, iv_outer.loop_cond);
    InductionVarInfo iv_inner = CreateInductionVariable(
        root, "iv_inner", "inner_loop", inner_value.output_true);

    outer_iv[i] = iv_outer.induction_var;
    inner_iv[i] = iv_inner.induction_var;
  }

  // The two copies use distinct Switch conditions (cond:0 vs cond_1:0), so
  // their predicates must NOT be considered equivalent.
  Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // The second copy's symbols carry the scope-uniquified names
    // (iv_outer/cond_1:0, iv_inner/cond_1:0), keeping the predicates distinct.
    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])],
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])],
              "{(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop> & {(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>)");
  }
}
855
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  // Loop-invariant constant carried into the outer frame, then gated on the
  // outer loop's liveness to seed the inner loop's initial value.
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  // Two loop-invariant recurrences per frame; div0_* and div1_* start out
  // structurally identical.
  DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
      root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
  DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
      root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);

  DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
      root, "div0_inner", "inner_loop", iv_inner.loop_cond,
      div0_outer.induction_var);
  DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
      root, "div1_inner", "inner_loop", iv_inner.loop_cond,
      div1_outer.induction_var);

  // A value captured from outside both loops, entered into each frame as a
  // loop-invariant constant.
  Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
                               "tensor_a", "sender", 0, "receiver");
  Output capture_enter_outer = ops::internal::Enter(
      root.WithOpName("capture_enter_outer"), captured, "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output capture_enter_inner = ops::internal::Enter(
      root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  // Rewire div1_inner's latch to carry mul0 (which mixes in the capture), so
  // div1_inner's recurrence is no longer equivalent to div0_inner's.
  Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
                         capture_enter_inner);
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      mul0.node(), 0, div1_inner.latch.output_true.node(), 0));

  Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
                         div1_inner.induction_var);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    // The capture breaks the symmetry, so add0's inputs must be reported as
    // having mismatching deadness.
    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
}
909
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
  // Builds two dependent loop-invariant values whose backedges are
  // cross-wired (div0 feeds div1's latch and vice versa), producing a cyclic
  // recurrence that the analysis must still handle.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  DependentInductionVar div0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
  DependentInductionVar div1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());
  // Cross-wire the backedges to create the div0 <-> div1 cycle.
  TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
                                        div0.latch.output_true.node(), 0));
  TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
                                        div1.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    // Optimistic mode resolves the cyclic merges to the loop predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");

    // This tests the rule {S,&,X} & ~X => S.
    TensorId switch_false_out = {div1.latch.output_false.node()->name(),
                                 div1.latch.output_false.index()};
    EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
  }
  {
    // Pessimistic mode falls back to opaque per-merge symbols for the cyclic
    // recurrences.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
  }
}
953
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
  // Builds the same and-recurrence {init,&,step} in two different frames.
  // If predicates did not incorporate the frame name, the two loops' exits
  // would wrongly compare equal.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
  InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9);

  Output init = CreateSwitch(root, "init").output_true;
  Output step = CreateSwitch(root, "step").output_true;

  std::array<Output, 2> exits;
  std::array<Output, 2> next_iterations;

  for (int i : {0, 1}) {
    Output init_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("init_enter_frame_", i)), init,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));
    Output step_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("step_enter_frame_", i)), step,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));

    // The merge's second input is a placeholder; it is rewired below to the
    // NextIteration node to close the loop.
    ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)),
                  {init_enter, init_enter});
    Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output,
                          step_enter);
    next_iterations[i] = ops::NextIteration(
        root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add);
    EXPECT_TRUE(
        root.graph()
            ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1)
            .ok());
    exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)),
                                   iv.output);
  }

  FixupSourceAndSinkEdges(root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // Despite identical structure, the two frames' predicates must be
    // distinct (and non-empty) because the frame names differ.
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])],
              predicate_map[ControlOutputFor(exits[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], "");

    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])],
              predicate_map[ControlOutputFor(next_iterations[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], "");
  }
}
1006
TEST(DeadnessAnalysisTest, ControlInputs) {
  // Control edges originating from opposite branches of a Switch propagate
  // that branch's deadness, so the consumers' deadness must mismatch.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output false_branch = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output true_branch = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output add = ops::Add(root.WithOpName("add"), lhs, rhs);

  Graph* graph = root.graph();
  graph->AddControlEdge(false_branch.node(), lhs.node());
  graph->AddControlEdge(true_branch.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(graph, &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_TRUE(mismatching_deadness);
}
1030
TEST(DeadnessAnalysisTest, ControlTrigger) {
  // A ControlTrigger resurrects a dead control input: even though the two
  // triggers hang off opposite Switch branches, the constants downstream of
  // them are both always-live, so `add` sees no deadness mismatch.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output false_branch = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output true_branch = ops::Identity(root.WithOpName("id1"), sw.output_true);

  ops::ControlTrigger trigger0(root.WithOpName("ctrl_trigger0"));
  ops::ControlTrigger trigger1(root.WithOpName("ctrl_trigger1"));

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output add = ops::Add(root.WithOpName("add"), lhs, rhs);

  Graph* graph = root.graph();
  graph->AddControlEdge(false_branch.node(), trigger0.operation.node());
  graph->AddControlEdge(trigger0.operation.node(), lhs.node());

  graph->AddControlEdge(true_branch.node(), trigger1.operation.node());
  graph->AddControlEdge(trigger1.operation.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(graph, &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(mismatching_deadness);
}
1060
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
  // Merge nodes are always-live regardless of dead *control* inputs, so two
  // merges gated (via control edges) on opposite Switch branches still feed
  // `add` with matching (live) deadness.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output constant = ops::Const(root.WithOpName("constant"), 5);
  ops::Merge m0(root.WithOpName("m0"), {constant});
  // Fixed: previously this also asked for op name "m0", silently relying on
  // the scope to uniquify it to "m0_1"; use the intended name "m1".
  ops::Merge m1(root.WithOpName("m1"), {constant});
  Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);

  root.graph()->AddControlEdge(id0.node(), m0.output.node());
  root.graph()->AddControlEdge(id1.node(), m1.output.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
1084
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
  // Demonstrates why we need the must_be_true bit on SymbolP: `recv` feeds
  // both an operand of the AND directly and the Switch predicate, so the two
  // inputs to the AND carry different predicates and must mismatch.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_tensor = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor",
                                  "sender", 0, "receiver");
  Output placeholder = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch branch(root.WithOpName("switch"), placeholder, recv_tensor);
  Output conjunction = ops::LogicalAnd(root.WithOpName("and"), recv_tensor,
                                       branch.output_true);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *conjunction.node()));
  EXPECT_TRUE(mismatching_deadness);
}
1104
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
  // Demonstrates why we need the must_be_true bit on SymbolP, checking the
  // textual predicate: the AND's control output combines the plain symbol
  // recv:0 with the must-be-true symbol *recv:0.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_tensor = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor",
                                  "sender", 0, "receiver");
  Output placeholder = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch branch(root.WithOpName("switch"), placeholder, recv_tensor);
  Output conjunction = ops::LogicalAnd(root.WithOpName("and"), recv_tensor,
                                       branch.output_true);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  TensorId and_control_out = {conjunction.node()->name(), Graph::kControlSlot};
  EXPECT_EQ(predicate_map[and_control_out], "(recv:0 & *recv:0)");
}
1126
TEST(DeadnessAnalysisTest, DeMorgan) {
  // Checks that the predicate simplifier applies De Morgan's laws: a Merge of
  // the two false branches is ~A | ~B == ~(A & B).
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL);
  Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0);
  ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1);

  // Add is live only when both inputs are live: predicate A & B.
  Output and_0_1 =
      ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true);

  // Merge is live when either input is live: predicate ~A | ~B.
  Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"),
                                   {sw_0.output_false, sw_1.output_false})
                            .output;

  // Predicate(should_always_be_dead) =
  //  (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False
  Output should_always_be_dead =
      ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1);

  // Predicate(should_always_be_alive) =
  //  (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True
  Output should_always_be_alive =
      ops::Merge(root.WithOpName("should_always_be_alive"),
                 {and_0_1, or_not0_not1})
          .output;

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true");
}
1165
TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) {
  // A Switch whose predicate is the constant `true` folds statically: the
  // false branch is dead (#false) and the true branch is live (#true).
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_true = ops::Const(root.WithOpName("const_true"), true);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch branch(root.WithOpName("switch"), data, always_true);

  Output dead_side =
      ops::Identity(root.WithOpName("id_false"), branch.output_false);
  Output live_side =
      ops::Identity(root.WithOpName("id_true"), branch.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(dead_side)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(live_side)], "#true");
}
1184
TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
  // Mirror of ConstantTrueSwitchCondition: a constant `false` predicate makes
  // the false branch statically live and the true branch statically dead.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_false = ops::Const(root.WithOpName("const_false"), false);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch branch(root.WithOpName("switch"), data, always_false);

  Output live_side =
      ops::Identity(root.WithOpName("id_false"), branch.output_false);
  Output dead_side =
      ops::Identity(root.WithOpName("id_true"), branch.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(live_side)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(dead_side)], "#false");
}
1203
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
  // A Switch predicate coming from a ref-typed boolean variable stays
  // symbolic: the branches get *cond_ref:0 and its negation.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_variable =
      ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch branch(root.WithOpName("switch"), data, cond_variable);

  Output false_side =
      ops::Identity(root.WithOpName("id_false"), branch.output_false);
  Output true_side =
      ops::Identity(root.WithOpName("id_true"), branch.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(false_side)], "~*cond_ref:0");
  EXPECT_EQ(predicate_map[ControlOutputFor(true_side)], "*cond_ref:0");
}
1223
// Builds a _SwitchN node (the multi-way generalization of Switch used by
// case/branch lowering) since there is no generated C++ wrapper for it.
//
// `data` is the value to route, `output_index` selects which of the
// `num_outs` outputs is live, and the resulting outputs are appended to
// `*outputs`. On failure the error is recorded in `scope`'s status and
// `*outputs` is left untouched.
void CreateSwitchN(const Scope& scope, Input data, Input output_index,
                   int64_t num_outs, OutputList* outputs) {
  // Bail out early (and after each fallible step below) per the Scope
  // error-propagation convention.
  if (!scope.ok()) return;
  auto _data = ops::AsNodeOut(scope, data);
  if (!scope.ok()) return;
  auto _output_index = ops::AsNodeOut(scope, output_index);
  if (!scope.ok()) return;
  Node* ret;
  const auto unique_name = scope.GetUniqueNameForOp("_SwitchN");
  auto builder = NodeBuilder(unique_name, "_SwitchN")
                     .Input(_data)
                     .Input(_output_index)
                     .Attr("num_outs", num_outs);
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return;
  scope.UpdateStatus(scope.DoShapeInference(ret));
  // Expose every data output of the finalized node to the caller.
  for (int32_t i = 0; i < ret->num_outputs(); ++i) {
    outputs->push_back(Output(ret, i));
  }
}
1245
TEST(DeadnessAnalysisTest, Constant1_SwitchN_2Branches_DoesNotFail) {
  // A two-output _SwitchN with a constant branch index of 1 folds statically:
  // only output 1 is live.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_1"), 1);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index, 2, &branches);

  Output branch0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output branch1 = ops::Identity(root.WithOpName("id_1"), branches[1]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(branch0)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch1)], "#true");
}
1265
TEST(DeadnessAnalysisTest, Constant7_SwitchN_3Branches) {
  // An out-of-range constant branch index (7 with 3 outputs) routes to the
  // last output, matching the runtime's clamping behavior.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_7"), 7);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index, 3, &branches);

  Output branch0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output branch1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output branch2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(branch0)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch1)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch2)], "#true");
}
1287
TEST(DeadnessAnalysisTest, RefInt_SwitchN_3Branches) {
  // A symbolic (ref-variable) branch index keeps the _SwitchN outputs
  // symbolic: branch k is live iff the index equals k and no earlier branch
  // matched, with the last branch as the catch-all.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index_var =
      ops::Variable(root.WithOpName("bidx"), TensorShape({}), DT_INT32);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index_var, 3,
                &branches);

  Output branch0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output branch1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output branch2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(branch0)], "bidx:0=0");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch1)], "(~bidx:0=0 & bidx:0=1)");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch2)], "(~bidx:0=0 & ~bidx:0=1)");
}
1311
1312 } // namespace
1313 } // namespace tensorflow
1314