xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/deadness_analysis_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/sendrecv_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/algorithm.h"
32 #include "tensorflow/core/graph/graph_def_builder.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace tensorflow {
37 namespace {
38 
HasInputsWithMismatchingDeadness(const DeadnessAnalysis & deadness_analysis,const Node & n)39 se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
40     const DeadnessAnalysis& deadness_analysis, const Node& n) {
41   std::optional<DeadnessAnalysis::DeadnessPredicate> pred;
42   for (const Edge* edge : n.in_edges()) {
43     TF_ASSIGN_OR_RETURN(
44         DeadnessAnalysis::DeadnessPredicate this_pred,
45         deadness_analysis.GetPredicateFor(edge->src(), edge->src_output()));
46     if (pred && *pred != this_pred) {
47       return true;
48     }
49     pred = this_pred;
50   }
51 
52   return false;
53 }
54 
55 using deadness_analysis_internal::ComputePredicates;
56 using deadness_analysis_internal::PredicateMapTy;
57 
// Normalizes `graph` (connects dangling nodes to SOURCE/SINK) and then runs
// deadness analysis over it, populating `result`.
Status AnalyzeDeadness(Graph* graph,
                       std::unique_ptr<DeadnessAnalysis>* result) {
  FixupSourceAndSinkEdges(graph);
  return DeadnessAnalysis::Run(*graph, result);
}
63 
CreateSwitch(const Scope & root,const string & prefix)64 ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
65   Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
66   Output predicate =
67       ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
68   return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
69 }
70 
ControlOutputFor(const Output & o)71 TensorId ControlOutputFor(const Output& o) {
72   return {o.node()->name(), Graph::kControlSlot};
73 }
74 
VLogGraphIfAsked(const Graph & graph)75 void VLogGraphIfAsked(const Graph& graph) {
76   if (VLOG_IS_ON(3)) {
77     GraphDef graph_def;
78     graph.ToGraphDef(&graph_def);
79     string serialized;
80     ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
81     LOG(INFO) << serialized;
82   }
83 }
84 
// The two interesting outputs of a synthesized counting loop.
struct InductionVarInfo {
  Output induction_var;  // current value of the loop's induction variable
  Output loop_cond;      // boolean loop-continuation condition
};
89 
90 // Creates an induction variable with the following structure (simplified for
91 // brevity):
92 //
93 //            +---------------+
94 //            | initial_value |
95 //            +---------------+
96 //              |
97 //              |
98 //              v
99 //            +---------------+
100 //            |     Enter     |
101 //            +---------------+
102 //              |
103 //              |
104 //              v
105 //            +---------------+
106 //         +> |     Merge     | -+
107 //         |  +---------------+  |
108 //         |    |                |
109 //         |    |                |
110 //         |    v                |
111 //         |  +---------------+  |
112 //         |  |  LessThan10   |  |
113 //         |  +---------------+  |
114 //         |    |                |
115 //         |    |                |
116 //         |    v                |
117 //         |  +---------------+  |
118 //    +----+- |    Switch     | <+
119 //    |    |  +---------------+
120 //    |    |    |
121 //    |    |    |
122 //    |    |    v
123 //    |    |  +---------------+
124 //    |    +- |    AddOne     |
125 //    |       +---------------+
126 //    |       +---------------+
127 //    +-----> |     Exit      |
128 //            +---------------+
InductionVarInfo CreateInductionVariable(const Scope& root,
                                         const string& prefix,
                                         const string& frame_name,
                                         const Output& initial_value) {
  // Bring the initial value into the loop frame.
  Output enter_initial_value = ops::internal::Enter(
      root.WithOpName(prefix + "/enter"), initial_value, frame_name);

  // Merge needs both inputs up front, so the backedge slot temporarily
  // aliases the Enter; it is rewired to NextIteration below.
  ops::Merge iv(root.WithOpName(prefix + "/iv"),
                {enter_initial_value, enter_initial_value});
  Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
  Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
  // Loop condition: iv < 10.
  Output loop_cond_expr =
      ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value);
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output,
                    loop_cond_expr);
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
                            latch.output_true, increment_by);
  Output next_iteration =
      ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next);

  // Close the loop: replace Merge input 1 with the NextIteration backedge.
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  // Anchor the constants inside the loop via control edges from the Merge.
  // NOTE(review): presumably this places them inside the loop frame for the
  // analysis -- confirm against deadness_analysis.cc.
  root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
  root.graph()->AddControlEdge(iv.output.node(), final_value.node());

  return {iv.output, loop_cond_expr};
}
159 
CreateInductionVariable(const Scope & root,const string & prefix,const string & frame_name,int32_t init)160 InductionVarInfo CreateInductionVariable(const Scope& root,
161                                          const string& prefix,
162                                          const string& frame_name,
163                                          int32_t init) {
164   return CreateInductionVariable(
165       root, prefix, frame_name,
166       ops::Const(root.WithOpName(prefix + "/init"), init));
167 }
168 
169 // Creates an induction variable with the following structure:
170 //
171 //                           +---------------+
172 //                           | initial_value |
173 //                           +---------------+
174 //                             |
175 //                             |
176 //                             v
177 //                           +---------------+
178 //                           |     Enter     |
179 //                           +---------------+
180 //                             |
181 //                             |
182 //                             v
183 //                           +---------------+
184 //                           |     Merge     | <+
185 //                           +---------------+  |
186 //                             |                |
187 //                             |                |
188 //                             v                |
189 //         +-----------+     +---------------+  |
190 //         | loop_cond | --> |    Switch     | -+
191 //         +-----------+     +---------------+
192 //                             |
193 //                             |
194 //                             v
195 //                           +---------------+
196 //                           |     Exit      |
197 //                           +---------------+
// A loop value whose liveness depends on an externally supplied loop
// condition (see the diagram above).
struct DependentInductionVar {
  Output induction_var;  // output of the value's Merge node
  ops::Switch latch;     // the Switch gating the value on the loop condition
};
202 
// Builds the Enter/Merge/Switch/Exit skeleton pictured above: a value that is
// loop-invariant (its backedge feeds its own Merge unchanged) but whose
// liveness is controlled by `loop_cond`.
DependentInductionVar CreateDependentLoopInvariantValue(
    const Scope& root, const string& prefix, const string& frame_name,
    const Output& loop_cond, const Output& value) {
  Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
                                            value, frame_name);
  // Second Merge input is a placeholder; rewired to the backedge below.
  ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output next_iteration = ops::NextIteration(
      root.WithOpName(prefix + "/next_iteration"), latch.output_true);
  // Close the loop: Merge input 1 now comes from NextIteration.
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  return {iv.output, latch};
}
219 
CreateDependentLoopInvariantValue(const Scope & root,const string & prefix,const string & frame_name,const Output & loop_cond,int32_t value)220 DependentInductionVar CreateDependentLoopInvariantValue(
221     const Scope& root, const string& prefix, const string& frame_name,
222     const Output& loop_cond, int32_t value) {
223   return CreateDependentLoopInvariantValue(
224       root, prefix, frame_name, loop_cond,
225       ops::Const(root.WithOpName(prefix + "/init"), value));
226 }
227 
// Adding the two mutually exclusive outputs of a single Switch must be
// reported as having inputs with mismatching deadness.
TEST(DeadnessAnalysisTest, BasicPositive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw = CreateSwitch(root, "0");
  Output add =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_TRUE(mismatch);
}
243 
// Two always-live placeholders feeding an Add never mismatch.
TEST(DeadnessAnalysisTest, BasicNegative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), a, b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
259 
// The conjunction of two Switch predicates must not depend on operand order.
TEST(DeadnessAnalysisTest, AndIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  // a0/a1: the same pair of predicates (~0 & ~1), in both operand orders.
  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 =
      ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);

  // b0/b1: (~0 & 1), again in both operand orders.
  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
  Output b1 =
      ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);

  // Operands with identical predicates -> no mismatch expected.
  Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
  Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);

  // Operands mixing the a- and b-predicates -> mismatch expected.
  Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
  Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  bool has_inputs_with_mismatching_deadness;

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live0.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live1.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
307 
// ((~0 & ~1) & ~2) must equal (~0 & (~1 & ~2)).
TEST(DeadnessAnalysisTest, AndIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left association.
  Output left_inner =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output left = ops::Add(root.WithOpName("a1"), left_inner, sw_2.output_false);

  // Right association.
  Output right_inner =
      ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
  Output right =
      ops::Add(root.WithOpName("b1"), sw_0.output_false, right_inner);

  Output add = ops::Add(root.WithOpName("add"), left, right);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
333 
// The disjunction introduced by Merge must not depend on operand order.
TEST(DeadnessAnalysisTest, OrIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  // m0/m1: (~0 | ~1) with operands in both orders; m2/m3: (~0 | 1) likewise.
  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
  ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});

  // Same-predicate operands -> no mismatch.
  Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
  Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);

  // Different-predicate operands -> mismatch.
  Output halfdead0 =
      ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
  Output halfdead1 =
      ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  bool has_inputs_with_mismatching_deadness;

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live0.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live1.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
378 
// ((~0 | ~1) | ~2) must equal (~0 | (~1 | ~2)).
TEST(DeadnessAnalysisTest, OrIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left association.
  ops::Merge left_inner(root.WithOpName("m0"),
                        {sw_0.output_false, sw_1.output_false});
  ops::Merge left(root.WithOpName("m1"),
                  {left_inner.output, sw_2.output_false});

  // Right association.
  ops::Merge right_inner(root.WithOpName("m2"),
                         {sw_1.output_false, sw_2.output_false});
  ops::Merge right(root.WithOpName("m3"),
                   {sw_0.output_false, right_inner.output});

  Output add = ops::Add(root.WithOpName("add"), left.output, right.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
401 
// Two identical conjunctions of disjunctions must compare equal.
TEST(DeadnessAnalysisTest, AndOfOr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  // or_a = (~0 | ~1), or_b = (~2 | ~3).
  ops::Merge or_a(root.WithOpName("m0"),
                  {sw_0.output_false, sw_1.output_false});
  ops::Merge or_b(root.WithOpName("m1"),
                  {sw_2.output_false, sw_3.output_false});

  // Two copies of (or_a & or_b).
  Output and_a = ops::Add(root.WithOpName("add0"), or_a.output, or_b.output);
  Output and_b = ops::Add(root.WithOpName("add1"), or_a.output, or_b.output);

  Output combined = ops::Add(root.WithOpName("add2"), and_a, and_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *combined.node()));
  EXPECT_FALSE(mismatch);
}
426 
// Two identical disjunctions of conjunctions must compare equal.
TEST(DeadnessAnalysisTest, OrOfAnd) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  // and_a = (~0 & ~1), and_b = (~2 & ~3).
  Output and_a =
      ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
  Output and_b =
      ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);

  // Two copies of (and_a | and_b).
  ops::Merge or_a(root.WithOpName("m0"), {and_a, and_b});
  ops::Merge or_b(root.WithOpName("m1"), {and_a, and_b});

  Output combined =
      ops::Add(root.WithOpName("add2"), or_a.output, or_b.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *combined.node()));
  EXPECT_FALSE(mismatch);
}
453 
// (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) must simplify all the way to
// #true.
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "A");
  ops::Switch sw_1 = CreateSwitch(root, "B");
  Output and0 =
      ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
  Output and1 =
      ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
  ops::Merge or2(root.WithOpName("or2"), {and0, and1});
  Output and3 =
      ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
  ops::Merge or4(root.WithOpName("or4"), {and3, sw_0.output_true});

  // AnalyzeDeadness also normalizes source/sink edges before we compute the
  // predicate map below.
  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
  EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
}
476 
// Distributivity: (A|B)&C == (A&C)|(B&C).
TEST(DeadnessAnalysisTest, AndOrDistributive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left-hand side: (~0 | ~1) & ~2.
  ops::Merge lhs_or(root.WithOpName("m0"),
                    {sw_0.output_false, sw_1.output_false});
  Output lhs =
      ops::Add(root.WithOpName("add0"), lhs_or.output, sw_2.output_false);

  // Right-hand side: (~0 & ~2) | (~1 & ~2).
  Output rhs_and0 =
      ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
  Output rhs_and1 =
      ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
  ops::Merge rhs(root.WithOpName("m1"), {rhs_and0, rhs_and1});

  Output check = ops::Add(root.WithOpName("add3"), lhs, rhs.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *check.node()));
  EXPECT_FALSE(mismatch);
}
504 
// Models a ternary select: a Merge of the true branch of one Switch and the
// false branch of another Switch, both driven by the same predicate.  Exactly
// one branch is live, so the Merge output is unconditionally live and adding
// it to an unpredicated placeholder must not look mismatched.
TEST(DeadnessAnalysisTest, Ternary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
  Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
  Output false_value =
      ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);

  ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
                              predicate);

  // Feed `false_value` (previously `true_value`, a copy-paste slip that left
  // `false_value` unused).  Deadness depends only on `predicate`, so the
  // test's expectations are unchanged; the graph now matches the ternary it
  // is meant to model.
  ops::Switch predicated_false(root.WithOpName("predicated_false"), false_value,
                               predicate);
  ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
                                                predicated_false.output_false});
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), merge.output, addend);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
531 
// Distinct _Recv nodes get distinct deadness predicates, so combining two of
// them is flagged as mismatched.
TEST(DeadnessAnalysisTest, Recv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
                             "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_TRUE(mismatch);
}
549 
// Same as the Recv test, but for _HostRecv nodes.
TEST(DeadnessAnalysisTest, HostRecv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
                                 "tensor_a", "sender", 0, "receiver");
  Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
                                 "tensor_b", "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_TRUE(mismatch);
}
567 
// Three structurally identical loops over the same frame.  Because each
// backedge Merge gets its own uninterpreted symbol, even iv0 and iv1 (both
// starting at 0) are treated as different.
TEST(DeadnessAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var;
  Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var;
  Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
  Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);

  // NB!  iv0 and iv1 are equivalent and a smarter deadness analysis would have
  // noticed that.  Today we are pessimistic here because we assign an
  // uninterpreted symbol to merges with backedges.

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    bool has_inputs_with_mismatching_deadness;

    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);

    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add1.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
  {
    // Check the exact textual form of each induction variable's predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0
    // produce the same deadness.  But we're not that smart today.
    EXPECT_EQ(predicate_map[ControlOutputFor(iv0)],
              "{#true,&,*iv0/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv1)],
              "{#true,&,*iv1/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv2)],
              "{#true,&,*iv2/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({#true,&,*iv0/cond:0}<fr0> & {#true,&,*iv1/cond:0}<fr0>)");
    EXPECT_EQ(predicate_map[ControlOutputFor(add1)],
              "({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv2/cond:0}<fr0>)");
  }
}
616 
// Two loop-invariant values gated by the same loop condition: the optimistic
// analysis proves they share the induction variable's predicate exactly,
// while the pessimistic analysis keeps a symbolic merge term.
TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  Output dependent_iv0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0)
          .induction_var;
  Output dependent_iv1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0)
          .induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    // Optimistic mode: all values collapse to the induction variable's
    // predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
  }
  {
    // Pessimistic mode: the backedge merge contributes a symbolic
    // "iv0/iv:0" conjunct.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
  }
}
668 
// Create a merge that "looks like" a loop but isn't really.  It has a value
// that does not depend on the merge on its backedge.
TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
  DependentInductionVar dependent_iv =
      CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());

  // Rewire div0's backedge so it is fed by iv0's induction variable rather
  // than by its own latch -- the merge's value no longer depends on itself.
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    // Optimistic mode still recovers the loop-shaped predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "{#true,&,*iv0/cond:0}<frame>");
  }
  {
    // Pessimistic mode leaves the merge as an uninterpreted symbol.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "div0/iv:0");
  }
}
700 
TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
  // Threads two loop-invariant values through an inner loop nested inside an
  // outer loop.  Both values are governed by the same (control-equivalent)
  // loop structure, so adding them must not be reported as mixing deadness.
  // Both the optimistic and the pessimistic predicate computations are
  // checked against golden predicate strings below.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  // A constant made loop-invariant in the outer frame via an Enter node with
  // is_constant=true.
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  // The inner loop's input is gated on the outer loop's condition, so the
  // inner loop only runs while the outer loop is live.
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  // Two independent loop-invariant values in the outer frame ...
  Output dependent_outer_iv0 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv0",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;
  Output dependent_outer_iv1 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv1",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;

  // ... each threaded into its own loop-invariant value in the inner frame.
  Output dependent_inner_iv0 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv0", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv0)
                                   .induction_var;
  Output dependent_inner_iv1 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv1", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv1)
                                   .induction_var;

  Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0,
                         dependent_inner_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    // The two addends are control equivalent, so no mismatch is expected.
    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    // enable_optimistic = true or not should produce the same results because
    // of fallback.  However, note that the order of iv_inner/cond:0 and
    // iv_inner/iv:0 is different because the optimistic approach does not
    // create predicates for all merges and it can change the predicate id and
    // hence the symbol order.
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
              "iv_inner/iv:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
}
794 
TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
  // Builds two structurally identical but independent nested loop nests.  As
  // the expected predicate strings below show, the second nest's ops end up
  // with "_1"-suffixed unique names (e.g. "iv_outer/cond_1"), so the two
  // nests' predicates are distinct symbols and mixing values across nests
  // must be reported as mismatching deadness.
  Scope root = Scope::NewRootScope().ExitOnError();

  std::array<Output, 2> outer_iv;
  std::array<Output, 2> inner_iv;

  for (int i : {0, 1}) {
    InductionVarInfo iv_outer =
        CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
    Output enter_constant_outer_loop = ops::internal::Enter(
        root.WithOpName("constant_enter_outer_loop"),
        ops::Const(root.WithOpName("constant"), 5), "outer_loop",
        ops::internal::Enter::Attrs().IsConstant(true));
    // The inner loop's input is gated on the outer loop's condition.
    ops::Switch inner_value(root.WithOpName("outer_is_live"),
                            enter_constant_outer_loop, iv_outer.loop_cond);
    InductionVarInfo iv_inner = CreateInductionVariable(
        root, "iv_inner", "inner_loop", inner_value.output_true);

    outer_iv[i] = iv_outer.induction_var;
    inner_iv[i] = iv_inner.induction_var;
  }

  // add0 mixes induction variables from the two unrelated loop nests.
  Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])],
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])],
              "{(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop> & {(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>)");
  }
}
855 
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
  // Nested loops with two dependent loop-invariant values, div0 and div1.
  // div1's inner-loop latch is rewired below so that its recurrence
  // multiplies in a value captured from outside the loop nest; this breaks
  // the equivalence with div0, so adding the two values must be reported as
  // having mismatching deadness.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  // The inner loop's input is gated on the outer loop's condition.
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
      root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
  DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
      root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);

  DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
      root, "div0_inner", "inner_loop", iv_inner.loop_cond,
      div0_outer.induction_var);
  DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
      root, "div1_inner", "inner_loop", iv_inner.loop_cond,
      div1_outer.induction_var);

  // A value produced outside the loops, made available inside both frames
  // through is_constant Enter nodes.
  Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
                               "tensor_a", "sender", 0, "receiver");
  Output capture_enter_outer = ops::internal::Enter(
      root.WithOpName("capture_enter_outer"), captured, "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output capture_enter_inner = ops::internal::Enter(
      root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  // Reroute div1_inner's backedge through mul0 so its recurrence depends on
  // the captured value.
  Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
                         capture_enter_inner);
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      mul0.node(), 0, div1_inner.latch.output_true.node(), 0));

  Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
                         div1_inner.induction_var);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
}
909 
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
  // Creates two dependent loop-invariant values in the same frame and then
  // cross-wires their backedges (div1 feeds div0's latch and vice versa),
  // producing a cyclic recurrence between the two merges.  The optimistic
  // analysis resolves both to the frame's and-recurrence; the pessimistic
  // analysis leaves each as its own merge symbol ("div0/iv:0", "div1/iv:0").
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  DependentInductionVar div0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
  DependentInductionVar div1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());
  TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
                                        div0.latch.output_true.node(), 0));
  TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
                                        div1.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");

    // This tests the rule {S,&,X} & ~X => S.
    TensorId switch_false_out = {div1.latch.output_false.node()->name(),
                                 div1.latch.output_false.index()};
    EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
  }
}
953 
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
  // Builds the same and-recurrence expression in two different frames
  // ("frame_0" and "frame_1") from the same init/step values.  The exits and
  // next-iteration values of the two frames must get distinct, non-empty
  // predicates, i.e. the frame name must be part of the and-recurrence.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
  InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9);

  Output init = CreateSwitch(root, "init").output_true;
  Output step = CreateSwitch(root, "step").output_true;

  std::array<Output, 2> exits;
  std::array<Output, 2> next_iterations;

  for (int i : {0, 1}) {
    // Bring the shared init/step values into frame_<i> as loop invariants.
    Output init_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("init_enter_frame_", i)), init,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));
    Output step_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("step_enter_frame_", i)), step,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));

    ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)),
                  {init_enter, init_enter});
    Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output,
                          step_enter);
    next_iterations[i] = ops::NextIteration(
        root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add);
    // Close the loop: route the NextIteration back into the Merge's second
    // input.  TF_ASSERT_OK is consistent with the other tests in this file
    // and reports the status message on failure, unlike
    // EXPECT_TRUE(status.ok()).
    TF_ASSERT_OK(root.graph()->UpdateEdge(next_iterations[i].node(), 0,
                                          iv.output.node(), 1));
    exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)),
                                   iv.output);
  }

  FixupSourceAndSinkEdges(root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // Same expression, different frames => different (and non-trivial)
    // predicates.
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])],
              predicate_map[ControlOutputFor(exits[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], "");

    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])],
              predicate_map[ControlOutputFor(next_iterations[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], "");
  }
}
1006 
TEST(DeadnessAnalysisTest, ControlInputs) {
  // Two constants, each control-dependent on a different side of the same
  // Switch: exactly one control edge is live at runtime, so adding the
  // constants must be flagged as mixing deadness.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch branch = CreateSwitch(root, "0");

  Output on_false = ops::Identity(root.WithOpName("id0"), branch.output_false);
  Output on_true = ops::Identity(root.WithOpName("id1"), branch.output_true);

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  // Gate each constant on one side of the switch via a control edge.
  root.graph()->AddControlEdge(on_false.node(), lhs.node());
  root.graph()->AddControlEdge(on_true.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_TRUE(has_mismatch);
}
1030 
TEST(DeadnessAnalysisTest, ControlTrigger) {
  // Same wiring as the ControlInputs test, except each control edge passes
  // through a ControlTrigger node; the analysis is expected to report no
  // mismatching deadness in that case.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch branch = CreateSwitch(root, "0");

  Output on_false = ops::Identity(root.WithOpName("id0"), branch.output_false);
  Output on_true = ops::Identity(root.WithOpName("id1"), branch.output_true);

  ops::ControlTrigger trigger0(root.WithOpName("ctrl_trigger0"));
  ops::ControlTrigger trigger1(root.WithOpName("ctrl_trigger1"));

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  // Each switch side reaches its constant only through a ControlTrigger.
  root.graph()->AddControlEdge(on_false.node(), trigger0.operation.node());
  root.graph()->AddControlEdge(trigger0.operation.node(), lhs.node());

  root.graph()->AddControlEdge(on_true.node(), trigger1.operation.node());
  root.graph()->AddControlEdge(trigger1.operation.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_FALSE(has_mismatch);
}
1060 
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
  // Two Merge nodes, each with a control edge from a different side of the
  // same Switch.  Adding the Merge outputs is expected to show no
  // mismatching deadness despite the dead/alive control inputs.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output constant = ops::Const(root.WithOpName("constant"), 5);
  ops::Merge m0(root.WithOpName("m0"), {constant});
  // Was mistakenly also named "m0" (copy-paste); use a name matching the
  // variable.  The assertions below do not reference the op name, so this
  // only affects readability of logged/dumped graphs.
  ops::Merge m1(root.WithOpName("m1"), {constant});
  Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);

  root.graph()->AddControlEdge(id0.node(), m0.output.node());
  root.graph()->AddControlEdge(id1.node(), m1.output.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
1084 
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
  // Demonstrates why we need the must_be_true bit on SymbolP: the raw recv
  // output and the true side of a Switch predicated on that recv must be
  // treated as different predicates.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch gate(root.WithOpName("switch"), value, recv);
  Output conjunction =
      ops::LogicalAnd(root.WithOpName("and"), recv, gate.output_true);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *conjunction.node()));
  EXPECT_TRUE(has_mismatch);
}
1104 
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
  // Textual counterpart of RecvVsSwitch: the printed predicate keeps the raw
  // recv symbol ("recv:0") and the must-be-true symbol ("*recv:0") distinct.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch gate(root.WithOpName("switch"), value, recv);
  Output conjunction =
      ops::LogicalAnd(root.WithOpName("and"), recv, gate.output_true);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  TensorId and_control_output = {conjunction.node()->name(),
                                 Graph::kControlSlot};
  EXPECT_EQ(predicate_map[and_control_output], "(recv:0 & *recv:0)");
}
1126 
TEST(DeadnessAnalysisTest, DeMorgan) {
  // Checks that the predicate simplifier applies De Morgan's laws to the
  // combination of a conjunction (Add over switch-true outputs) and a
  // disjunction (Merge over switch-false outputs).
  Scope root = Scope::NewRootScope().ExitOnError();

  Output pred_a = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL);
  Output pred_b = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_0"), data, pred_a);
  ops::Switch switch_b(root.WithOpName("switch_1"), data, pred_b);

  // Live iff A & B.
  Output a_and_b = ops::Add(root.WithOpName("and_0_1"), switch_a.output_true,
                            switch_b.output_true);

  // Live iff ~A | ~B.
  Output not_a_or_not_b = ops::Merge(root.WithOpName("or_not0_not1"),
                                     {switch_a.output_false,
                                      switch_b.output_false})
                              .output;

  // Predicate(should_always_be_dead) =
  // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False
  Output should_always_be_dead =
      ops::Add(root.WithOpName("should_always_be_dead"), a_and_b,
               not_a_or_not_b);

  // Predicate(should_always_be_alive) =
  // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True
  Output should_always_be_alive =
      ops::Merge(root.WithOpName("should_always_be_alive"),
                 {a_and_b, not_a_or_not_b})
          .output;

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true");
}
1165 
TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) {
  // A Switch whose predicate is the constant `true` folds to literal
  // predicates: the false output is statically dead, the true output
  // statically alive.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_true = ops::Const(root.WithOpName("const_true"), true);
  Output input = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch gate(root.WithOpName("switch"), input, always_true);

  Output dead_side = ops::Identity(root.WithOpName("id_false"),
                                   gate.output_false);
  Output live_side = ops::Identity(root.WithOpName("id_true"),
                                   gate.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(dead_side)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(live_side)], "#true");
}
1184 
TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
  // Mirror of ConstantTrueSwitchCondition: with a constant `false` predicate
  // the false output is statically alive and the true output statically
  // dead.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_false = ops::Const(root.WithOpName("const_false"), false);
  Output input = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch gate(root.WithOpName("switch"), input, always_false);

  Output live_side = ops::Identity(root.WithOpName("id_false"),
                                   gate.output_false);
  Output dead_side = ops::Identity(root.WithOpName("id_true"),
                                   gate.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(live_side)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(dead_side)], "#false");
}
1203 
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
  // With a non-constant (ref-variable) predicate the Switch outputs stay
  // symbolic: the true output gets the symbol, the false output its
  // negation.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_var =
      ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
  Output input = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch gate(root.WithOpName("switch"), input, cond_var);

  Output on_false = ops::Identity(root.WithOpName("id_false"),
                                  gate.output_false);
  Output on_true = ops::Identity(root.WithOpName("id_true"),
                                 gate.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(on_false)], "~*cond_ref:0");
  EXPECT_EQ(predicate_map[ControlOutputFor(on_true)], "*cond_ref:0");
}
1223 
CreateSwitchN(const Scope & scope,Input data,Input output_index,int64_t num_outs,OutputList * outputs)1224 void CreateSwitchN(const Scope& scope, Input data, Input output_index,
1225                    int64_t num_outs, OutputList* outputs) {
1226   if (!scope.ok()) return;
1227   auto _data = ops::AsNodeOut(scope, data);
1228   if (!scope.ok()) return;
1229   auto _output_index = ops::AsNodeOut(scope, output_index);
1230   if (!scope.ok()) return;
1231   Node* ret;
1232   const auto unique_name = scope.GetUniqueNameForOp("_SwitchN");
1233   auto builder = NodeBuilder(unique_name, "_SwitchN")
1234                      .Input(_data)
1235                      .Input(_output_index)
1236                      .Attr("num_outs", num_outs);
1237   scope.UpdateBuilder(&builder);
1238   scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
1239   if (!scope.ok()) return;
1240   scope.UpdateStatus(scope.DoShapeInference(ret));
1241   for (int32_t i = 0; i < ret->num_outputs(); ++i) {
1242     outputs->push_back(Output(ret, i));
1243   }
1244 }
1245 
TEST(DeadnessAnalysisTest, Constant1_SwitchN_2Branches_DoesNotFail) {
  // _SwitchN driven by a constant branch index of 1 with two branches:
  // branch 0 is statically dead, branch 1 statically alive.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_1"), 1);
  Output input = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), input, branch_index, 2, &branches);

  Output id_0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output id_1 = ops::Identity(root.WithOpName("id_1"), branches[1]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_0)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_1)], "#true");
}
1265 
TEST(DeadnessAnalysisTest, Constant7_SwitchN_3Branches) {
  // _SwitchN with a constant branch index (7) past the last branch: per the
  // expectations below, only the final branch is considered alive.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_7"), 7);
  Output input = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), input, branch_index, 3, &branches);

  Output id_0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output id_1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output id_2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_0)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_1)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_2)], "#true");
}
1287 
TEST(DeadnessAnalysisTest, RefInt_SwitchN_3Branches) {
  // With a non-constant branch index, each _SwitchN output's predicate is an
  // equality on the index conjoined with the negated equalities of the
  // earlier branches; the last branch acts as the catch-all.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index =
      ops::Variable(root.WithOpName("bidx"), TensorShape({}), DT_INT32);
  Output input = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), input, branch_index, 3, &branches);

  Output id_0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output id_1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output id_2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_0)], "bidx:0=0");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_1)], "(~bidx:0=0 & bidx:0=1)");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_2)], "(~bidx:0=0 & ~bidx:0=1)");
}
1311 
1312 }  // namespace
1313 }  // namespace tensorflow
1314