1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17
18 #include "absl/strings/match.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
24 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
25 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
26 #include "tensorflow/core/grappler/utils.h"
27 #include "tensorflow/core/grappler/utils/grappler_test.h"
28 #include "tensorflow/core/grappler/utils/topological_sort.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31
32 namespace tensorflow {
33 namespace grappler {
34 namespace {
35
36 class DependencyOptimizerTest : public GrapplerTest {};
37
VerifyGraphsEqual(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)38 void VerifyGraphsEqual(const GraphDef& original_graph,
39 const GraphDef& optimized_graph, const string& func) {
40 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
41 for (int i = 0; i < original_graph.node_size(); ++i) {
42 const NodeDef& original = original_graph.node(i);
43 const NodeDef& optimized = optimized_graph.node(i);
44 EXPECT_EQ(original.name(), optimized.name()) << func;
45 EXPECT_EQ(original.op(), optimized.op()) << func;
46 EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
47 for (int j = 0; j < original.input_size(); ++j) {
48 EXPECT_EQ(original.input(j), optimized.input(j)) << func;
49 }
50 }
51 }
52
// The optimizer must leave a graph untouched when there is nothing to rewrite.
TEST_F(DependencyOptimizerTest, NoOp) {
  // The trivial test graph is too basic to offer anything to optimize.
  TrivialTestGraphInputYielder yielder(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(yielder.NextItem(&item));

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
66
// Control dependencies whose source is a constant should be dropped, and a
// constant that is only used as a control input ("z") should disappear.
TEST_F(DependencyOptimizerTest, DependenciesDrivenByConstants) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(root.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output y = ops::Const(root.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output z = ops::Const(root.WithOpName("z"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(root.WithOpName("add"), x, y);
  Output id1 = ops::Identity(
      root.WithOpName("id1").WithControlDependencies(x), add);
  Output id2 = ops::Identity(root.WithOpName("id2")
                                 .WithControlDependencies(y)
                                 .WithControlDependencies(z),
                             add);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"id1", "id2"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));
  // Feed the result back in for a second pass to make sure the rewrite is
  // idempotent. After the swap, item.graph holds the first pass' result.
  item.graph.Swap(&optimized);
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  // The 'z' node should have been optimized away leaving only 5 nodes.
  EXPECT_EQ(5, optimized.node_size());

  for (const NodeDef& node : item.graph.node()) {
    const bool is_identity = node.name() == "id1" || node.name() == "id2";
    if (!is_identity) continue;
    // Only the data input from "add" should remain.
    EXPECT_EQ(1, node.input_size());
    EXPECT_EQ("add", node.input(0));
  }
}
103
// "add" is only consumed through control dependencies (by id1 and id2), so it
// should be turned into a NoOp and removed, leaving each identity with a data
// input and a control dependency on the other random op.
TEST_F(DependencyOptimizerTest, ChangeToNoop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(s.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
  Output id2 =
      ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("id1");
  item.fetch.push_back("id2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  int found = 0;
  for (const NodeDef& node : item.graph.node()) {
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() == "id1") {
      EXPECT_EQ("Identity", node.op());
      // Fatal: the indexed reads below are only valid with two inputs.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      ++found;
    } else if (node.name() == "id2") {
      EXPECT_EQ("Identity", node.op());
      // Fatal: the indexed reads below are only valid with two inputs.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      ++found;
    }
  }
  // Both identities must have been seen, otherwise the checks were vacuous.
  EXPECT_EQ(2, found);
}
150
// Same scenario as ChangeToNoop, but "add" reads the same tensor twice, so the
// surviving identity should end up with a single input.
TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), x, x);
  Output id1 =
      ops::Identity(root.WithOpName("id1").WithControlDependencies(add), x);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"id1"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));
  // A second pass over the first result checks that the rewrite is idempotent.
  item.graph.Swap(&optimized);
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(item.graph.node_size(), optimized.node_size());
  int num_matched = 0;
  for (const NodeDef& node : item.graph.node()) {
    // "add" should get turned into a NoOp and removed.
    EXPECT_NE("add", node.name());
    if (node.name() != "id1") continue;
    EXPECT_EQ("Identity", node.op());
    EXPECT_EQ(1, node.input_size());
    EXPECT_EQ("x", node.input(0));
    ++num_matched;
  }
  EXPECT_EQ(1, num_matched);
}
185
TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
  // This tests that we don't try to repeatedly add Identity nodes
  // with names like "ConstantFoldingCtrl/foo/bar/switch_$port" when
  // multiple nodes reading the same output of a Switch node get
  // optimized (e.g. constant folded or turned into NoOps).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
  // "neg" should be turned into a NoOp with a control dependency from
  // the existing Identity node "ConstantFoldingCtrl/switch_1" and
  // subsequently eliminated completely from the graph.
  Output neg = ops::Neg(scope.WithOpName("neg"), s.output_true);
  // c1 could be a result of constant folding some node fed by neg.
  Output c1 = ops::Const(scope.WithOpName("c1").WithControlDependencies(neg),
                         {1.0f, 2.0f}, {1, 2});
  Output ctrl_dep_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true);
  // c2 could be a result of constant folding a node fed by s, which also
  // added the ctrl_dep_id node.
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(ctrl_dep_id),
                 {1.0f, 2.0f}, {1, 2});
  Output neg1 = ops::Neg(scope.WithOpName("neg1"), s.output_false);
  Output neg2 = ops::Neg(scope.WithOpName("neg2"), ctrl_dep_id);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("c1");
  item.fetch.push_back("c2");
  item.fetch.push_back("neg1");
  item.fetch.push_back("neg2");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  bool found_c1 = false;
  for (const NodeDef& node : output.node()) {
    // "neg" should be eliminated.
    EXPECT_NE("neg", node.name());
    // A control dep from "^ConstantFoldingCtrl/switch_1"
    // should be attached to "c1".
    if (node.name() == "c1") {
      EXPECT_EQ("Const", node.op());
      // Fatal: input(0) below is only valid if the node has an input.
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
      found_c1 = true;
    }
  }
  // Without this, the test would pass vacuously if "c1" were missing.
  EXPECT_TRUE(found_c1);
}
238
239 // TODO(rmlarsen): Add test to make sure we skip Switch and Merge.
// Same graph as ChangeToNoop, but no fetch nodes are set on the item. The
// optimized graph is expected to equal the topologically sorted input.
TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(root.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), x, y);
  Output id1 =
      ops::Identity(root.WithOpName("id1").WithControlDependencies(add), x);
  Output id2 =
      ops::Identity(root.WithOpName("id2").WithControlDependencies(add), y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  // Sort the reference graph before the index-by-index comparison.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
261
// NoOps that have no inputs or no outputs: after optimization neither NoOp
// should have any inputs and the Identity should only read its data input.
TEST_F(DependencyOptimizerTest, RemoveNoOps_EmptyInputOrOutput) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output random = ops::RandomUniform(root, {1, 2}, DT_FLOAT);
  auto noop_without_inputs = ops::NoOp(root);
  auto noop_without_outputs = ops::NoOp(root.WithControlDependencies(random));
  Output identity = ops::Identity(
      root.WithControlDependencies({noop_without_inputs.operation}), random);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));
  // A second pass over the first result checks that the rewrite is idempotent.
  item.graph.Swap(&optimized);
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(item.graph.node_size(), optimized.node_size());
  for (const NodeDef& node : optimized.node()) {
    const auto& name = node.name();
    if (name == "NoOp" || name == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
      continue;
    }
    if (name == "Identity") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("RandomUniform", node.input(0));
    }
  }
}
292
// NoOps whose removal would add edges across devices must be left in place.
TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(root.WithOpName("y").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  // NoOp with one input dependency and two output dependencies.
  auto fan_out_noop =
      ops::NoOp(root.WithControlDependencies(x).WithDevice("/CPU:1"));
  // NoOp with two input dependencies and one output dependency.
  auto fan_in_noop = ops::NoOp(root.WithControlDependencies(x)
                                   .WithControlDependencies(y)
                                   .WithDevice("/CPU:0"));
  Output id = ops::Identity(
      root.WithControlDependencies({fan_out_noop.operation})
          .WithDevice("/CPU:1"),
      x);
  Output id_1 = ops::Identity(
      root.WithControlDependencies(
              {fan_out_noop.operation, fan_in_noop.operation})
          .WithDevice("/CPU:1"),
      y);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity", "Identity_1"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  // The optimization should be disabled to prevent increasing the number of
  // nodes crossing device boundaries, so the graph must be unchanged.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
327
// Identity nodes whose removal would add edges across devices must be left in
// place.
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_DeviceBoundaries) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(root.WithOpName("y").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  // Identity with one input dependency and two output dependencies.
  auto id_a = ops::Identity(root.WithOpName("id_a").WithDevice("/CPU:1"), x);
  // Identity with two input dependencies and one output dependency.
  auto id_b = ops::Identity(root.WithOpName("id_b")
                                .WithControlDependencies(y)
                                .WithDevice("/CPU:0"),
                            x);

  Output id = ops::Identity(
      root.WithControlDependencies(id_a).WithDevice("/CPU:1"), id_b);
  Output id_1 = ops::Identity(root.WithDevice("/CPU:1"), id_a);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity", "Identity_1"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  // The optimization should be disabled to prevent increasing the number of
  // nodes crossing device boundaries, so the graph must be unchanged.
  TF_CHECK_OK(TopologicalSort(&item.graph));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
359
// "id_a" sits on /CPU:1 between a producer and a consumer that are both on
// /CPU:0; it is expected to be removed, with the consumer reading "x" directly.
TEST_F(DependencyOptimizerTest, RemoveIdentityOps_IdenticalDevices) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x").WithDevice("/CPU:0"),
                                {1, 2}, DT_FLOAT);
  auto id_a = ops::Identity(root.WithOpName("id_a").WithDevice("/CPU:1"), x);
  Output id = ops::Identity(
      root.WithControlDependencies(id_a).WithDevice("/CPU:0"), id_a);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"Identity"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  // Exactly one node ("id_a") should be gone.
  EXPECT_EQ(item.graph.node_size() - 1, optimized.node_size());
  for (const NodeDef& node : optimized.node()) {
    EXPECT_NE(node.name(), "id_a");
    if (node.name() != "Identity") continue;
    EXPECT_EQ(node.input(0), "x");
  }
}
385
// NoOps with a single input or a single output dependency: after optimization
// the NoOps should have no inputs, and their consumers should depend directly
// on the nodes that used to feed the NoOps.
TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT);
  // NoOp with a single input- and two output dependencies.
  auto noop = ops::NoOp(s.WithControlDependencies(x));
  // NoOp with a two input- and a single output dependency.
  auto noop_1 =
      ops::NoOp(s.WithControlDependencies(x).WithControlDependencies(y));
  Output id = ops::Identity(s.WithControlDependencies({noop.operation}), x);
  Output id_1 = ops::Identity(
      s.WithControlDependencies({noop.operation, noop_1.operation}), y);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("Identity");
  item.fetch.push_back("Identity_1");

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "NoOp" || node.name() == "NoOp_1") {
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "Identity") {
      // Fatal guard: input(0) must exist before it is read.
      ASSERT_GE(node.input_size(), 1);
      EXPECT_EQ("x", node.input(0));
    } else if (node.name() == "Identity_1") {
      // Fatal guard: input(0) and input(1) must exist before they are read.
      ASSERT_GE(node.input_size(), 2);
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("^x", node.input(1));
    }
  }
}
425
// Three Identity nodes with different fan-in/fan-out shapes should all be
// removed, with their consumers rewired to the identities' own inputs.
TEST_F(DependencyOptimizerTest, RemoveIdentity) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x = ops::RandomUniform(root.WithOpName("x"), {1, 2}, DT_FLOAT);
  Output y = ops::RandomUniform(root.WithOpName("y"), {1, 2}, DT_FLOAT);
  Output z = ops::RandomUniform(root.WithOpName("z"), {1, 2}, DT_FLOAT);

  // Identity nodes to be removed.
  // Case a): a single input and multiple outputs.
  auto id_a = ops::Identity(root.WithOpName("id_a"), x);
  // Case b): multiple inputs and a single output.
  auto id_b = ops::Identity(root.WithOpName("id_b")
                                .WithControlDependencies(y)
                                .WithControlDependencies(z),
                            x);
  // Case c): two inputs and two outputs.
  auto id_c =
      ops::Identity(root.WithOpName("id_c").WithControlDependencies(y), x);

  // Consumers for case a.
  Output a_a = ops::Identity(root.WithOpName("a_a"), id_a);
  Output a_b = ops::Identity(root.WithOpName("a_b"), id_a);
  Output a_c =
      ops::Identity(root.WithOpName("a_c").WithControlDependencies(id_a), z);
  Output a_d =
      ops::Identity(root.WithOpName("a_d").WithControlDependencies(id_a), z);
  // Consumer for case b.
  Output b_a = ops::Identity(root.WithOpName("b_a"), id_b);
  // Consumers for case c.
  Output c_a = ops::Identity(root.WithOpName("c_a"), id_c);
  Output c_b =
      ops::Identity(root.WithOpName("c_b").WithControlDependencies(id_c), z);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  // id_a, id_b and id_c should all be gone.
  EXPECT_EQ(item.graph.node_size() - 3, optimized.node_size());
  int num_checked = 0;
  for (const NodeDef& node : optimized.node()) {
    const auto& name = node.name();
    EXPECT_NE("id_a", name);
    EXPECT_NE("id_b", name);
    EXPECT_NE("id_c", name);
    if (name == "a_a" || name == "a_b") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++num_checked;
    } else if (name == "a_c" || name == "a_d") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("z", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      ++num_checked;
    } else if (name == "b_a") {
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      EXPECT_EQ("^z", node.input(2));
      ++num_checked;
    } else if (name == "c_a") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^y", node.input(1));
      ++num_checked;
    } else if (name == "c_b") {
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("z", node.input(0));
      EXPECT_EQ("^x", node.input(1));
      EXPECT_EQ("^y", node.input(2));
      ++num_checked;
    }
  }
  // All seven fetch nodes must have been inspected.
  EXPECT_EQ(num_checked, 7);
}
506
TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
  // Corner cases with repeated inputs.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable x(scope.WithOpName("x"), {}, DT_BOOL);
  ops::Variable y(scope.WithOpName("y"), {}, DT_BOOL);
  ops::Switch sw(scope.WithOpName("switch"), x, x);
  // id0 should be removed.
  Output id0 = ops::Identity(scope.WithOpName("id0"), sw.output_true);
  // id1 should not be removed, since it would anchor a control dependency
  // on the switch.
  Output id1 = ops::Identity(scope.WithOpName("id1"), sw.output_false);
  Output or0 = ops::LogicalOr(scope.WithOpName("or0"), id0, id0);
  Output or1 = ops::LogicalOr(scope.WithOpName("or1"), id0, y);
  Output or2 = ops::LogicalOr(
      scope.WithOpName("or2").WithControlDependencies(id1), y, y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("or0");
  item.fetch.push_back("or1");
  item.fetch.push_back("or2");
  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_NE("id0", node.name());
    // The input-count checks below are fatal because the indexed reads that
    // follow them are only valid when the count matches.
    if (node.name() == "or0") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("switch:1", node.input(1));
      ++found;
    }
    if (node.name() == "or1") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("switch:1", node.input(0));
      EXPECT_EQ("y", node.input(1));
      ++found;
    }
    if (node.name() == "or2") {
      // or2 should be unchanged.
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("y", node.input(0));
      EXPECT_EQ("y", node.input(1));
      EXPECT_EQ("^id1", node.input(2));
      ++found;
    }
  }
  EXPECT_EQ(found, 3);
}
560
// neg2 has a control dependency on x and also reads x through neg1. After
// optimization neg2 is expected to keep only its data input from neg1.
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Square(s.WithOpName("x"), c);
  Output neg1 = ops::Neg(s.WithOpName("neg1"), x);
  Output neg2 =
      ops::Neg(s.WithOpName("neg2").WithControlDependencies({x}), neg1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("neg2");
  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal: node(3) below is only valid if the graph has four nodes.
  ASSERT_EQ(4, output.node_size());
  const NodeDef& last_node = output.node(3);
  EXPECT_EQ("neg2", last_node.name());
  // Fatal: input(0) below is only valid if the node has an input.
  ASSERT_EQ(1, last_node.input_size());
  EXPECT_EQ("neg1", last_node.input(0));
}
581
// Mix of Identity nodes around a Switch: "id0" and "id1" are expected to be
// removed, and "c1" should keep only the control dependency on the
// "ConstantFoldingCtrl/switch_1" identity.
TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  Output id_after_var = ops::Identity(scope.WithOpName("id_after_var"), v_in);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch sw(
      scope.WithOpName("switch").WithControlDependencies(id_after_var), v_in,
      v_ctrl);
  Output id0 = ops::Identity(scope.WithOpName("id0"), sw.output_true);
  Output grappler_added_id = ops::Identity(
      scope.WithOpName("ConstantFoldingCtrl/switch_1"), sw.output_true);
  Output c1 = ops::Const(scope.WithOpName("c1")
                             .WithControlDependencies(id_after_var)
                             .WithControlDependencies(grappler_added_id),
                         {1.0f, 2.0f}, {1, 2});
  Output id1 = ops::Identity(scope.WithOpName("id1"), c1);
  Output id2 = ops::Identity(scope.WithOpName("id2"), id0);
  Output fetch =
      ops::Identity(scope.WithOpName("fetch").WithControlDependencies(id1), c1);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"c1", "id2", "fetch"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(item.graph.node_size() - 2, optimized.node_size());
  bool saw_c1 = false;
  for (const NodeDef& node : optimized.node()) {
    // "id0" and "id1" should be eliminated; "ConstantFoldingCtrl/switch_1",
    // "id_after_var" and "id2" should not.
    EXPECT_NE("id0", node.name());
    EXPECT_NE("id1", node.name());
    if (node.name() != "c1") continue;
    EXPECT_EQ("Const", node.op());
    EXPECT_EQ(1, node.input_size());
    EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
    saw_c1 = true;
  }
  EXPECT_TRUE(saw_c1);
}
630
// Identity nodes reading the two outputs of a Switch: after optimization
// "out1" and "out2" are expected to read the Switch outputs directly.
TEST_F(DependencyOptimizerTest, IdentityInputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // Identity nodes to be removed.
  auto id_f = ops::Identity(scope.WithOpName("id_f"), s.output_false);
  auto id_t = ops::Identity(scope.WithOpName("id_t"), s.output_true);

  // Output
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal: node(4) and node(5) below are only valid with six nodes.
  ASSERT_EQ(6, output.node_size());

  const NodeDef& out1_node = output.node(4);
  EXPECT_EQ("out1", out1_node.name());
  // Fatal: input(0) below is only valid if the node has an input.
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  const NodeDef& out2_node = output.node(5);
  EXPECT_EQ("out2", out2_node.name());
  // Fatal: input(0) below is only valid if the node has an input.
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));
}
663
// IdentityN nodes reading the outputs of a Switch: after optimization each
// "outN" node is expected to read the matching Switch output directly.
TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
  auto s = ops::Switch(scope.WithOpName("s"), x, b);

  // IdentityN nodes to be removed.
  auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
  auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
  auto id_b =
      ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});

  // Outputs
  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f[0]);
  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t[0]);
  Output out3 = ops::Identity(scope.WithOpName("out3"), id_b[0]);
  Output out4 = ops::Identity(scope.WithOpName("out4"), id_b[1]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3", "out4"};

  DependencyOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal: the node(k) reads below (k up to 7) are only valid with eight nodes.
  ASSERT_EQ(8, output.node_size());

  // Nodes are bound by reference to avoid copying each NodeDef. Each
  // input-count check is fatal because input(0) is read right after it.
  const NodeDef& out1_node = output.node(7);
  EXPECT_EQ("out1", out1_node.name());
  ASSERT_EQ(1, out1_node.input_size());
  EXPECT_EQ("s", out1_node.input(0));

  const NodeDef& out2_node = output.node(4);
  EXPECT_EQ("out2", out2_node.name());
  ASSERT_EQ(1, out2_node.input_size());
  EXPECT_EQ("s:1", out2_node.input(0));

  const NodeDef& out3_node = output.node(5);
  EXPECT_EQ("out3", out3_node.name());
  ASSERT_EQ(1, out3_node.input_size());
  EXPECT_EQ("s", out3_node.input(0));

  const NodeDef& out4_node = output.node(6);
  EXPECT_EQ("out4", out4_node.name());
  ASSERT_EQ(1, out4_node.input_size());
  EXPECT_EQ("s:1", out4_node.input(0));
}
713
// An IdentityN that is also the source of a control dependency ("out3") must
// stay in the graph: all six nodes are expected to survive.
TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input1 = ops::Placeholder(root.WithOpName("input1"), DT_BOOL);
  Output input2 = ops::Const(root.WithOpName("input2"), {1, 2});

  auto id_n = ops::IdentityN(root.WithOpName("id_n"), {input1, input2});
  Output out1 = ops::Identity(root.WithOpName("out1"), id_n[0]);
  Output out2 = ops::Identity(root.WithOpName("out2"), id_n[1]);
  auto out3 =
      ops::NoOp(root.WithOpName("out3").WithControlDependencies(id_n[1]));

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"out1", "out2", "out3"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(6, optimized.node_size());
}
736
// The Identity lives on a different device than both its producer and its
// consumer; the graph is expected to come back unchanged.
TEST_F(DependencyOptimizerTest,
       Identity_DeviceCrossing_ConsumerOnDifferentDevice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x_on_1 =
      ops::Const(root.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one_on_3 =
      ops::Const(root.WithOpName("one_on_3").WithDevice("/gpu:3"), {1.0f}, {});
  Output x_on_2 =
      ops::Identity(root.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
  Output result = ops::Add(root.WithOpName("result").WithDevice("/gpu:3"),
                           x_on_2, one_on_3);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"result"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
759
// The Identity shares a device with its consumer; it is expected to be removed
// so that "result" reads "x_on_1" directly.
TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output x_on_1 =
      ops::Const(root.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
  Output one_on_2 =
      ops::Const(root.WithOpName("one_on_2").WithDevice("/gpu:2"), {1.0f}, {});
  Output x_on_2 =
      ops::Identity(root.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
  Output result = ops::Add(root.WithOpName("result").WithDevice("/gpu:2"),
                           x_on_2, one_on_2);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch = {"result"};

  GraphDef optimized;
  DependencyOptimizer dependency_optimizer;
  TF_EXPECT_OK(dependency_optimizer.Optimize(nullptr, item, &optimized));

  EXPECT_EQ(3, optimized.node_size());
  for (const NodeDef& node : optimized.node()) {
    EXPECT_NE("x_on_2", node.name());
    if (node.name() != "result") continue;
    EXPECT_EQ("x_on_1", node.input(0));
  }
}
786
// "z" depends on a NoOp that in turn depends on GreaterEqual(x, y). Of the
// five named nodes only x, y and z are expected to remain, with z reading
// just its two data inputs.
TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                              ops::Placeholder::Shape({}));
  auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
  auto noop =
      ops::NoOp(s.WithOpName("NoOp").WithControlDependencies(greaterequal));
  Output add = ops::Add(
      s.WithOpName("z").WithControlDependencies({noop.operation}), x, y);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  DependencyOptimizer optimizer;
  GraphDef output;
  item.fetch.push_back("z");
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Counts every named node still present. GreaterEqual and NoOp are counted
  // too, so their survival pushes the total above the expected value.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "x") {
      count++;
      EXPECT_EQ("Placeholder", node.op());
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "y") {
      count++;
      EXPECT_EQ("Placeholder", node.op());
      EXPECT_EQ(0, node.input_size());
    } else if (node.name() == "GreaterEqual") {
      count++;
    } else if (node.name() == "NoOp") {
      count++;
    } else if (node.name() == "z") {
      count++;
      EXPECT_EQ("Add", node.op());
      // Fatal: the indexed reads below are only valid with two inputs.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
  EXPECT_EQ(3, count);
}
830
// Input graph: Identity "f" on /GPU:0 reads "e" (/CPU:0) and carries control
// dependencies on "a" (/CPU:1), "b" (/CPU:2), "c" (/CPU:1) and "d" (/CPU:3).
// Expected graph: the two control edges coming from /CPU:1 ("a" and "c") are
// routed through one NoOp named "GroupCrossDeviceControlEdges_0/f" placed on
// /CPU:1, while "b" and "d" stay as direct control inputs of "f".
// The optimizer is then run a second time on its own output and must produce
// the same expected graph again.
TEST_F(DependencyOptimizerTest, GroupCrossDeviceControlDeps) {
  // Adds a RandomUniform node with the given name and device, fed the shape
  // input {1, 2} and producing DT_FLOAT.
  const auto random_on = [](const tensorflow::Scope& scope, const char* name,
                            const char* device) -> Output {
    return ops::RandomUniform(scope.WithOpName(name).WithDevice(device),
                              {1, 2}, DT_FLOAT);
  };

  GrapplerItem item;
  {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output a = random_on(scope, "a", "/CPU:1");
    Output b = random_on(scope, "b", "/CPU:2");
    Output c = random_on(scope, "c", "/CPU:1");
    Output d = random_on(scope, "d", "/CPU:3");
    Output e = random_on(scope, "e", "/CPU:0");
    // The fetch node: every control dependency comes from another device.
    auto fetch_node = ops::Identity(
        scope.WithOpName("f")
            .WithControlDependencies({a.op(), b.op(), c.op(), d.op()})
            .WithDevice("/GPU:0"),
        {e});

    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    item.fetch.push_back("f");
  }

  GraphDef expected;
  {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output a = random_on(scope, "a", "/CPU:1");
    Output b = random_on(scope, "b", "/CPU:2");
    Output c = random_on(scope, "c", "/CPU:1");
    Output d = random_on(scope, "d", "/CPU:3");
    Output e = random_on(scope, "e", "/CPU:0");
    // Single NoOp on /CPU:1 standing in for the control edges from a and c.
    auto cpu1_group =
        ops::NoOp(scope.WithOpName("GroupCrossDeviceControlEdges_0/f")
                      .WithDevice("/CPU:1")
                      .WithControlDependencies({a.op(), c.op()}));
    auto fetch_node =
        ops::Identity(scope.WithOpName("f")
                          .WithControlDependencies({b.op(), d.op(), cpu1_group})
                          .WithDevice("/GPU:0"),
                      {e});

    TF_CHECK_OK(scope.ToGraphDef(&expected));
  }

  DependencyOptimizer optimizer;
  GraphDef output;
  // Optimizes the current item.graph into `output` and checks it against the
  // expected graph.
  const auto optimize_and_compare = [&]() {
    TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
    CompareGraphs(expected, output);
  };

  optimize_and_compare();

  // Feed the optimized graph back in: a second pass must leave it unchanged.
  item.graph.Swap(&output);
  output.Clear();
  optimize_and_compare();
}
892
// Input graph: Identity "f" on /CPU:0 reads "a" (/CPU:0) and carries one
// control dependency per RandomUniform node "t<T>/c<C>", placed on
// "/task:<T>/device:TPU:<C>" for T in [0, 4) and C in [0, 8) -- 32 control
// edges in total.
//
// Expectations on the optimized graph:
//   * exactly four nodes were added;
//   * every NoOp is named "GroupCrossDeviceControlEdges..." and has 8 inputs;
//   * those NoOps sit on four distinct devices;
//   * "f" is present and has 5 inputs, each either "a" or a control edge to
//     one of the grouping NoOps.
TEST_F(DependencyOptimizerTest, GroupCrossHostControlDeps) {
  GrapplerItem item;
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    // Named so that it does not shadow the `ops::` namespace used below.
    std::vector<Operation> control_sources;
    Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:0"),
                                  {1, 2}, DT_FLOAT);
    for (int t = 0; t < 4; ++t) {
      for (int c = 0; c < 8; ++c) {
        string opname = absl::StrCat("t", t, "/c", c);
        string device = absl::StrCat("/task:", t, "/device:TPU:", c);
        Output output = ops::RandomUniform(
            s.WithOpName(opname).WithDevice(device), {1, 2}, DT_FLOAT);
        control_sources.push_back(output.op());
      }
    }
    // Node with cross-device dependencies.
    auto fetch = ops::Identity(s.WithOpName("f")
                                   .WithControlDependencies(control_sources)
                                   .WithDevice("/CPU:0"),
                               {a});

    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch.push_back("f");
  }

  DependencyOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);

  bool saw_fetch = false;
  std::set<string> noop_devices;
  for (const auto& n : output.node()) {
    if (n.op() == "NoOp") {
      EXPECT_TRUE(absl::StartsWith(n.name(), "GroupCrossDeviceControlEdges"));
      EXPECT_EQ(n.input_size(), 8);
      noop_devices.insert(n.device());
    }

    if (n.name() == "f") {
      saw_fetch = true;
      EXPECT_EQ(n.input_size(), 5);
      for (const auto& i : n.input()) {
        EXPECT_TRUE(i == "a" ||
                    absl::StartsWith(i, "^GroupCrossDeviceControlEdges"));
      }
    }
  }
  // Without this, the checks on "f" above would pass vacuously if the fetch
  // node were missing from the optimized graph.
  EXPECT_TRUE(saw_fetch);
  EXPECT_EQ(noop_devices.size(), 4);
}
947
948 } // namespace
949 } // namespace grappler
950 } // namespace tensorflow
951