xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17 
18 #include "absl/strings/match.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
24 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
25 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
26 #include "tensorflow/core/grappler/utils.h"
27 #include "tensorflow/core/grappler/utils/grappler_test.h"
28 #include "tensorflow/core/grappler/utils/topological_sort.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31 
32 namespace tensorflow {
33 namespace grappler {
34 namespace {
35 
36 class DependencyOptimizerTest : public GrapplerTest {};
37 
VerifyGraphsEqual(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)38 void VerifyGraphsEqual(const GraphDef& original_graph,
39                        const GraphDef& optimized_graph, const string& func) {
40   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
41   for (int i = 0; i < original_graph.node_size(); ++i) {
42     const NodeDef& original = original_graph.node(i);
43     const NodeDef& optimized = optimized_graph.node(i);
44     EXPECT_EQ(original.name(), optimized.name()) << func;
45     EXPECT_EQ(original.op(), optimized.op()) << func;
46     EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
47     for (int j = 0; j < original.input_size(); ++j) {
48       EXPECT_EQ(original.input(j), optimized.input(j)) << func;
49     }
50   }
51 }
52 
TEST_F(DependencyOptimizerTest, NoOp) {
  // The trivial test graph offers nothing to optimize, so the optimizer is
  // expected to hand it back unchanged.
  TrivialTestGraphInputYielder yielder(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(yielder.NextItem(&item));

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
66 
TEST_F(DependencyOptimizerTest, DependenciesDrivenByConstants) {
  // id1 and id2 both read "add" and carry control dependencies on the
  // constants x, y and z. Both are expected to end up with "add" as their
  // only input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output z = ops::Const(s.WithOpName("z"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(x), add);
  Output id2 = ops::Identity(
      s.WithOpName("id2").WithControlDependencies(y).WithControlDependencies(z),
      add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"id1", "id2"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  // Run the optimizer twice to make sure the rewrite is idempotent. After the
  // swap, item.graph holds the result of the first pass.
  item.graph.Swap(&output);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // The 'z' node should have been optimized away leaving only 5 nodes.
  EXPECT_EQ(5, output.node_size());

  for (const NodeDef& node : item.graph.node()) {
    if (node.name() != "id1" && node.name() != "id2") continue;
    // Fatal so that input(0) is never read from a node without inputs.
    ASSERT_EQ(1, node.input_size());
    EXPECT_EQ("add", node.input(0));
  }
}
103 
TEST_F(DependencyOptimizerTest, ChangeToNoop) {
  // "add" is only reachable from the fetch nodes through control
  // dependencies. id1 and id2 are expected to end up with a control
  // dependency on the other random input instead.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(s.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
  Output id2 =
      ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"id1", "id2"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  // Run the optimizer twice to make sure the rewrite is idempotent. After the
  // swap, item.graph holds the result of the first pass.
  item.graph.Swap(&output);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (const NodeDef& node : item.graph.node()) {
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      // Fatal so that input(1) is never read out of range.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      ++found;
    } else if (node.name() == "id2") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      ++found;
    }
  }
  EXPECT_EQ(2, found);
}
150 
TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) {
  // Same idea as ChangeToNoop, but "add" reads the same tensor twice. id1 is
  // expected to keep "x" as its single input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"id1"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  // Run the optimizer twice to make sure the rewrite is idempotent. After the
  // swap, item.graph holds the result of the first pass.
  item.graph.Swap(&output);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (const NodeDef& node : item.graph.node()) {
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      // Fatal so that input(0) is never read from a node without inputs.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(1, found);
}
185 
TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
  // This tests that we don't try to repeatedly add Identity nodes
  // with names like "ConstantFoldingCtrl/foo/bar/switch_$port" when
  // multiple nodes reading the same output of a Switch node get
  // optimized (e.g. constant folded or turned into NoOps).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
  // "neg" should be turned into a NoOp with a control dependency from
  // the existing Identity node "ConstantFoldingCtrl/switch_1" and
  // subsequently eliminated completely from the graph.
  Output neg = ops::Neg(scope.WithOpName("neg"), s.output_true);
  // c1 could be a result of constant folding some node fed by neg.
  Output c1 = ops::Const(scope.WithOpName("c1").WithControlDependencies(neg),
                         {1.0f, 2.0f}, {1, 2});
  Output ctrl_dep_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true);
  // c2 could be a result of constant folding a node fed by s, which also
  // added the ctrl_dep_id node.
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(ctrl_dep_id),
                 {1.0f, 2.0f}, {1, 2});
  Output neg1 = ops::Neg(scope.WithOpName("neg1"), s.output_false);
  Output neg2 = ops::Neg(scope.WithOpName("neg2"), ctrl_dep_id);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"c1", "c2", "neg1", "neg2"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  for (const NodeDef& node : output.node()) {
    // "neg" should be eliminated.
    EXPECT_NE("neg", node.name());
    // A control dep from "^ConstantFoldingCtrl/switch_1"
    // should be attached to "c1".
    if (node.name() == "c1") {
      EXPECT_EQ("Const", node.op());
      // Fatal so that input(0) is never read from a node without inputs.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
    }
  }
}
238 
239 // TODO(rmlarsen): Add test to make sure we skip Switch and Merge.
TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
  // Same graph as ChangeToNoop, but the item has no fetch nodes. The output
  // is expected to equal the topologically sorted input graph.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(root.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(root.WithOpName("id1").WithControlDependencies(add), x);
  Output id2 =
      ops::Identity(root.WithOpName("id2").WithControlDependencies(add), y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  // Sort the input so that both graphs list their nodes in the same order.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
261 
TEST_F(DependencyOptimizerTest, RemoveNoOps_EmptyInputOrOutput) {
  // noop1 has no inputs and noop2 has no consumers. The NoOps are expected to
  // end up without inputs and "Identity" with only its data input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s, {1, 2}, DT_FLOAT);
  auto noop1 = ops::NoOp(s);
  auto noop2 = ops::NoOp(s.WithControlDependencies(x));
  Output id = ops::Identity(s.WithControlDependencies({noop1.operation}), x);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"Identity"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "NoOp" || node.name() == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "Identity") {
      // Fatal so that input(0) is never read from a node without inputs.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("RandomUniform", node.input(0));
    }
  }
}
292 
TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(root.WithOpName("y").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  // NoOp on /CPU:1 with one control input (x); both Identity nodes below
  // depend on it.
  auto noop = ops::NoOp(root.WithControlDependencies(x).WithDevice("/CPU:1"));
  // NoOp on /CPU:0 with two control inputs (x and y); only the second
  // Identity node depends on it.
  auto noop_1 = ops::NoOp(
      root.WithControlDependencies(x).WithControlDependencies(y).WithDevice(
          "/CPU:0"));
  Output id = ops::Identity(
      root.WithControlDependencies({noop.operation}).WithDevice("/CPU:1"), x);
  Output id_1 = ops::Identity(
      root.WithControlDependencies({noop.operation, noop_1.operation})
          .WithDevice("/CPU:1"),
      y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity", "Identity_1"};

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  // The optimization should be disabled to prevent increasing the number of
  // nodes crossing device boundaries, so the output must equal the
  // (topologically sorted) input.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
327 
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_DeviceBoundaries) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(root.WithOpName("y").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  // Identity on /CPU:1 with a single input (x) and two consumers below.
  auto id_a = ops::Identity(root.WithOpName("id_a").WithDevice("/CPU:1"), x);
  // Identity on /CPU:0 with two inputs (x plus a control dependency on y) and
  // a single consumer below.
  auto id_b = ops::Identity(
      root.WithOpName("id_b").WithControlDependencies(y).WithDevice("/CPU:0"),
      x);

  Output id = ops::Identity(
      root.WithControlDependencies(id_a).WithDevice("/CPU:1"), id_b);
  Output id_1 = ops::Identity(root.WithDevice("/CPU:1"), id_a);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity", "Identity_1"};

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  // The optimization should be disabled to prevent increasing the number of
  // nodes crossing device boundaries, so the output must equal the
  // (topologically sorted) input.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
359 
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_IdenticalDevices) {
  // x and the fetched Identity are both on /CPU:0 while id_a in between is on
  // /CPU:1. id_a is expected to be removed and "Identity" to read x directly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x").WithDevice("/CPU:0"), {1, 2},
                                DT_FLOAT);
  auto id_a = ops::Identity(s.WithOpName("id_a").WithDevice("/CPU:1"), x);
  Output id =
      ops::Identity(s.WithControlDependencies(id_a).WithDevice("/CPU:0"), id_a);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"Identity"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.name(), "id_a");
    if (node.name() == "Identity") {
      // Fatal so that input(0) is never read from a node without inputs.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
  }
}
385 
TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  // NoOp with a single control input (x); both Identity nodes below depend
  // on it.
  auto noop = ops::NoOp(s.WithControlDependencies(x));
  // NoOp with two control inputs (x and y); only the second Identity node
  // depends on it.
  auto noop_1 =
      ops::NoOp(s.WithControlDependencies(x).WithControlDependencies(y));
  Output id = ops::Identity(s.WithControlDependencies({noop.operation}), x);
  Output id_1 = ops::Identity(
      s.WithControlDependencies({noop.operation, noop_1.operation}), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"Identity", "Identity_1"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "NoOp" || node.name() == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "Identity") {
      // Fatal so that input(0) is never read from a node without inputs.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x", node.input(0));
    } else if (node.name() == "Identity_1") {
      // Fatal so that input(1) is never read out of range.
      ASSERT_GE(node.input_size(), 2);
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
    }
  }
}
425 
TEST_F(DependencyOptimizerTest,RemoveIdentity)426 TEST_F(DependencyOptimizerTest, RemoveIdentity) {
427   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
428   Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
429   Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
430   Output z = ops::RandomUniform(s.WithOpName("z"), {1, 2}, DT_FLOAT);
431 
432   // Identity nodes to be removed.
433   // Case a) with a single input- and multiple outputs.
434   auto id_a = ops::Identity(s.WithOpName("id_a"), x);
435   // Case b) with multiple inputs and a single output.
436   auto id_b = ops::Identity(
437       s.WithOpName("id_b").WithControlDependencies(y).WithControlDependencies(
438           z),
439       x);
440   // Case c) with two inputs and two outputs.
441   auto id_c = ops::Identity(s.WithOpName("id_c").WithControlDependencies(y), x);
442 
443   // Output for Case a.
444   Output a_a = ops::Identity(s.WithOpName("a_a"), id_a);
445   Output a_b = ops::Identity(s.WithOpName("a_b"), id_a);
446   Output a_c =
447       ops::Identity(s.WithOpName("a_c").WithControlDependencies(id_a), z);
448   Output a_d =
449       ops::Identity(s.WithOpName("a_d").WithControlDependencies(id_a), z);
450   // Output for Case b.
451   Output b_a = ops::Identity(s.WithOpName("b_a"), id_b);
452   // Output for Case c.
453   Output c_a = ops::Identity(s.WithOpName("c_a"), id_c);
454   Output c_b =
455       ops::Identity(s.WithOpName("c_b").WithControlDependencies(id_c), z);
456 
457   GrapplerItem item;
458   TF_CHECK_OK(s.ToGraphDef(&item.graph));
459   item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"};
460 
461   DependencyOptimizer optimizer;
462   GraphDef output;
463   Status status = optimizer.Optimize(nullptr, item, &output);
464   TF_EXPECT_OK(status);
465 
466   EXPECT_EQ(item.graph.node_size() - 3, output.node_size());
467   int found = 0;
468   for (const NodeDef& node : output.node()) {
469     EXPECT_NE("id_a", node.name());
470     EXPECT_NE("id_b", node.name());
471     EXPECT_NE("id_c", node.name());
472     if (node.name() == "a_a" || node.name() == "a_b") {
473       ASSERT_EQ(1, node.input_size());
474       EXPECT_EQ("x", node.input(0));
475       ++found;
476     }
477     if (node.name() == "a_c" || node.name() == "a_d") {
478       ASSERT_EQ(2, node.input_size());
479       EXPECT_EQ("z", node.input(0));
480       EXPECT_EQ("^x", node.input(1));
481       ++found;
482     }
483     if (node.name() == "b_a") {
484       ASSERT_EQ(3, node.input_size());
485       EXPECT_EQ("x", node.input(0));
486       EXPECT_EQ("^y", node.input(1));
487       EXPECT_EQ("^z", node.input(2));
488       ++found;
489     }
490     if (node.name() == "c_a") {
491       ASSERT_EQ(2, node.input_size());
492       EXPECT_EQ("x", node.input(0));
493       EXPECT_EQ("^y", node.input(1));
494       ++found;
495     }
496     if (node.name() == "c_b") {
497       ASSERT_EQ(3, node.input_size());
498       EXPECT_EQ("z", node.input(0));
499       EXPECT_EQ("^x", node.input(1));
500       EXPECT_EQ("^y", node.input(2));
501       ++found;
502     }
503   }
504   EXPECT_EQ(found, 7);
505 }
506 
TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
  // Corner cases with repeated inputs.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable x(scope.WithOpName("x"), {}, DT_BOOL);
  ops::Variable y(scope.WithOpName("y"), {}, DT_BOOL);
  ops::Switch sw(scope.WithOpName("switch"), x, x);
  // id0 should be removed.
  Output id0 = ops::Identity(scope.WithOpName("id0"), sw.output_true);
  // id1 should not be removed, since it would anchor a control dependency
  // on the switch.
  Output id1 = ops::Identity(scope.WithOpName("id1"), sw.output_false);
  Output or0 = ops::LogicalOr(scope.WithOpName("or0"), id0, id0);
  Output or1 = ops::LogicalOr(scope.WithOpName("or1"), id0, y);
  Output or2 = ops::LogicalOr(
      scope.WithOpName("or2").WithControlDependencies(id1), y, y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"or0", "or1", "or2"};
  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_NE("id0", node.name());
    // The input-count checks are fatal so that the indexed reads that follow
    // them never go out of range.
    if (node.name() == "or0") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("switch:1", node.input(1));
      ++found;
    }
    if (node.name() == "or1") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("y", node.input(1));
      ++found;
    }
    if (node.name() == "or2") {
      // or2 should be unchanged.
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("y", node.input(1));
      EXPECT_EQ("^id1", node.input(2));
      ++found;
    }
  }
  EXPECT_EQ(found, 3);
}
560 
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
  // neg2 reads neg1 (which reads x) and also has a direct control dependency
  // on x. neg2 is expected to end up with neg1 as its only input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Square(s.WithOpName("x"), c);
  Output neg1 = ops::Neg(s.WithOpName("neg1"), x);
  Output neg2 =
      ops::Neg(s.WithOpName("neg2").WithControlDependencies({x}), neg1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"neg2"};
  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  // Fatal so that output.node(3) below is never read out of range.
  ASSERT_EQ(4, output.node_size());
  const NodeDef& last = output.node(3);
  EXPECT_EQ("neg2", last.name());
  // Fatal so that input(0) is never read from a node without inputs.
  ASSERT_EQ(1, last.input_size());
  EXPECT_EQ("neg1", last.input(0));
}
581 
TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  Output id_after_var = ops::Identity(scope.WithOpName("id_after_var"), v_in);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s(
      scope.WithOpName("switch").WithControlDependencies(id_after_var), v_in,
      v_ctrl);
  Output id0 = ops::Identity(scope.WithOpName("id0"), s.output_true);
  Output grappler_added_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true);
  Output c1 = ops::Const(scope.WithOpName("c1")
                             .WithControlDependencies(id_after_var)
                             .WithControlDependencies(grappler_added_id),
                         {1.0f, 2.0f}, {1, 2});
  Output id1 = ops::Identity(scope.WithOpName("id1"), c1);
  Output id2 = ops::Identity(scope.WithOpName("id2"), id0);
  Output fetch =
      ops::Identity(scope.WithOpName("fetch").WithControlDependencies(id1), c1);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"c1", "id2", "fetch"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(item.graph.node_size() - 2, output.node_size());
  bool found = false;
  for (const NodeDef& node : output.node()) {
    // Only "id0" and "id1" should be eliminated;
    // "ConstantFoldingCtrl/switch_1", "id_after_var" and "id2" should remain.
    EXPECT_NE("id0", node.name());
    EXPECT_NE("id1", node.name());
    if (node.name() == "c1") {
      EXPECT_EQ("Const", node.op());
      // Fatal so that input(0) is never read from a node without inputs.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      found = true;
    }
  }
  EXPECT_TRUE(found);
}
630 
TEST_F(DependencyOptimizerTest, IdentityInputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // Identity nodes to be removed.
  auto id_f = ops::Identity(scope.WithOpName("id_f"), s.output_false);
  auto id_t = ops::Identity(scope.WithOpName("id_t"), s.output_true);

  // Output
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Fatal so that output.node(4) and output.node(5) below are never read out
  // of range.
  ASSERT_EQ(6, output.node_size());

  // out1 is expected to read the Switch's first output directly.
  const NodeDef& out1_node = output.node(4);
  EXPECT_EQ("out1", out1_node.name());
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  // out2 is expected to read the Switch's second output directly.
  const NodeDef& out2_node = output.node(5);
  EXPECT_EQ("out2", out2_node.name());
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));
}
663 
TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // IdentityN nodes to be removed.
  auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
  auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
  auto id_b =
      ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});

  // Outputs
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f[0]);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t[0]);
  Output out3 = ops::Identity(scope.WithOpName("out3"), id_b[0]);
  Output out4 = ops::Identity(scope.WithOpName("out4"), id_b[1]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3", "out4"};

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Fatal so that the fixed-index node lookups below are never out of range.
  ASSERT_EQ(8, output.node_size());

  // Each out* node is expected to read the matching Switch output directly.
  // The single-input checks are fatal so that input(0) is always valid.
  const NodeDef& out1_node = output.node(7);
  EXPECT_EQ("out1", out1_node.name());
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  const NodeDef& out2_node = output.node(4);
  EXPECT_EQ("out2", out2_node.name());
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));

  const NodeDef& out3_node = output.node(5);
  EXPECT_EQ("out3", out3_node.name());
  ASSERT_EQ(1, out3_node.input_size());
  EXPECT_EQ("s", out3_node.input(0));

  const NodeDef& out4_node = output.node(6);
  EXPECT_EQ("out4", out4_node.name());
  ASSERT_EQ(1, out4_node.input_size());
  EXPECT_EQ("s:1", out4_node.input(0));
}
713 
TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) {
  // id_n feeds out1 and out2 as data and out3 through a control dependency.
  // All six nodes built here are expected to survive optimization.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input1 = ops::Placeholder(root.WithOpName("input1"), DT_BOOL);
  Output input2 = ops::Const(root.WithOpName("input2"), {1, 2});

  auto id_n = ops::IdentityN(root.WithOpName("id_n"), {input1, input2});
  Output out1 = ops::Identity(root.WithOpName("out1"), id_n[0]);
  Output out2 = ops::Identity(root.WithOpName("out2"), id_n[1]);
  auto out3 =
      ops::NoOp(root.WithOpName("out3").WithControlDependencies(id_n[1]));

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3"};

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(6, optimized.node_size());
}
736 
TEST_F(DependencyOptimizerTest,
       Identity_DeviceCrossing_ConsumerOnDifferentDevice) {
  // The Identity sits on /gpu:2 between a producer on /gpu:1 and a consumer
  // on /gpu:3. The graph is expected to come back unchanged.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x_on_1 =
      ops::Const(root.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one_on_3 =
      ops::Const(root.WithOpName("one_on_3").WithDevice("/gpu:3"), {1.0f}, {});
  Output x_on_2 =
      ops::Identity(root.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
  Output result = ops::Add(root.WithOpName("result").WithDevice("/gpu:3"),
                           x_on_2, one_on_3);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"result"};

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
759 
// An Identity placed on /gpu:2 reads a constant from /gpu:1 and is consumed by
// an Add that is also on /gpu:2. The test expects the Identity to disappear
// (three nodes left) and "result" to read "x_on_1" directly.
TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x_on_1 =
      ops::Const(scope.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one_on_2 =
      ops::Const(scope.WithOpName("one_on_2").WithDevice("/gpu:2"), {1.0f}, {});
  Output x_on_2 =
      ops::Identity(scope.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
  Output result = ops::Add(scope.WithOpName("result").WithDevice("/gpu:2"),
                           x_on_2, one_on_2);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"result"};

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(3, optimized.node_size());
  for (const NodeDef& node : optimized.node()) {
    // The Identity node must not survive under its original name.
    EXPECT_NE("x_on_2", node.name());
    if (node.name() != "result") {
      continue;
    }
    // The consumer is rewired to the Identity's producer.
    EXPECT_EQ("x_on_1", node.input(0));
  }
}
786 
// Graph: z = Add(x, y) with a control dependency on a NoOp, which in turn has
// a control dependency on GreaterEqual(x, y). After optimization only three of
// the five named nodes are expected to remain, and "z" must take exactly the
// two data inputs "x" and "y".
TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  Output y = ops::Placeholder(scope.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  auto greaterequal =
      ops::GreaterEqual(scope.WithOpName("GreaterEqual"), x, y);
  auto noop =
      ops::NoOp(scope.WithOpName("NoOp").WithControlDependencies(greaterequal));
  Output add = ops::Add(
      scope.WithOpName("z").WithControlDependencies({noop.operation}), x, y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("z");

  GraphDef optimized;
  DependencyOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized));

  // Counts every surviving node that carries one of the five original names.
  int matched = 0;
  for (const NodeDef& node : optimized.node()) {
    const auto& name = node.name();
    if (name == "x" || name == "y") {
      ++matched;
      EXPECT_EQ("Placeholder", node.op());
      EXPECT_EQ(0, node.input_size());
    } else if (name == "GreaterEqual" || name == "NoOp") {
      // Counted without further checks; the final total of 3 only holds if
      // neither of these nodes is present.
      ++matched;
    } else if (name == "z") {
      ++matched;
      EXPECT_EQ("Add", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
  EXPECT_EQ(3, matched);
}
830 
TEST_F(DependencyOptimizerTest,GroupCrossDeviceControlDeps)831 TEST_F(DependencyOptimizerTest, GroupCrossDeviceControlDeps) {
832   GrapplerItem item;
833   {
834     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
835     Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"),
836                                   {1, 2}, DT_FLOAT);
837     Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"),
838                                   {1, 2}, DT_FLOAT);
839     Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"),
840                                   {1, 2}, DT_FLOAT);
841     Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"),
842                                   {1, 2}, DT_FLOAT);
843     Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"),
844                                   {1, 2}, DT_FLOAT);
845     // Node with cross-device dependencies.
846     auto fetch = ops::Identity(
847         s.WithOpName("f")
848             .WithControlDependencies({a.op(), b.op(), c.op(), d.op()})
849             .WithDevice("/GPU:0"),
850         {e});
851 
852     TF_CHECK_OK(s.ToGraphDef(&item.graph));
853     item.fetch.push_back("f");
854   }
855 
856   GraphDef expected;
857   {
858     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
859     Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"),
860                                   {1, 2}, DT_FLOAT);
861     Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"),
862                                   {1, 2}, DT_FLOAT);
863     Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"),
864                                   {1, 2}, DT_FLOAT);
865     Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"),
866                                   {1, 2}, DT_FLOAT);
867     Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"),
868                                   {1, 2}, DT_FLOAT);
869     auto noop = ops::NoOp(s.WithOpName("GroupCrossDeviceControlEdges_0/f")
870                               .WithDevice("/CPU:1")
871                               .WithControlDependencies({a.op(), c.op()}));
872     auto fetch =
873         ops::Identity(s.WithOpName("f")
874                           .WithControlDependencies({b.op(), d.op(), noop})
875                           .WithDevice("/GPU:0"),
876                       {e});
877 
878     TF_CHECK_OK(s.ToGraphDef(&expected));
879   }
880 
881   DependencyOptimizer optimizer;
882   GraphDef output;
883   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
884   CompareGraphs(expected, output);
885 
886   // Run the optimizer again to verify idempotence.
887   item.graph.Swap(&output);
888   output.Clear();
889   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
890   CompareGraphs(expected, output);
891 }
892 
// Node "f" on /CPU:0 has control dependencies on 32 ops spread over 4 tasks
// with 8 TPU devices each. The test expects the optimizer to add exactly four
// NoOp nodes, each named with the "GroupCrossDeviceControlEdges" prefix and
// each with 8 inputs, and expects "f" to end up with 5 inputs: the data input
// "a" plus control edges from the grouping nodes.
TEST_F(DependencyOptimizerTest, GroupCrossHostControlDeps) {
  GrapplerItem item;
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    // Named so that it does not shadow the `ops` namespace used below.
    std::vector<Operation> control_inputs;
    Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:0"),
                                  {1, 2}, DT_FLOAT);
    for (int t = 0; t < 4; ++t) {
      for (int c = 0; c < 8; ++c) {
        string opname = absl::StrCat("t", t, "/c", c);
        string device = absl::StrCat("/task:", t, "/device:TPU:", c);
        Output output = ops::RandomUniform(
            s.WithOpName(opname).WithDevice(device), {1, 2}, DT_FLOAT);
        control_inputs.push_back(output.op());
      }
    }
    // Node with cross-device dependencies.
    auto fetch = ops::Identity(s.WithOpName("f")
                                   .WithControlDependencies(control_inputs)
                                   .WithDevice("/CPU:0"),
                               {a});

    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch.push_back("f");
  }

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // One grouping NoOp is expected per task.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);
  // Distinct device strings of the grouping NoOps.
  std::set<string> tasks;
  for (const auto& n : output.node()) {
    if (n.op() == "NoOp") {
      EXPECT_TRUE(absl::StartsWith(n.name(), "GroupCrossDeviceControlEdges"));
      EXPECT_EQ(n.input_size(), 8);
      tasks.insert(n.device());
    }

    if (n.name() == "f") {
      EXPECT_EQ(n.input_size(), 5);
      for (const auto& i : n.input()) {
        EXPECT_TRUE(i == "a" ||
                    absl::StartsWith(i, "^GroupCrossDeviceControlEdges"));
      }
    }
  }
  EXPECT_EQ(tasks.size(), 4);
}
947 
948 }  // namespace
949 }  // namespace grappler
950 }  // namespace tensorflow
951