/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
16
17 #include <unordered_set>
18
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/framework/tensor_testutil.h"
25 #include "tensorflow/core/graph/testlib.h"
26 #include "tensorflow/core/grappler/grappler_item.h"
27 #include "tensorflow/core/grappler/op_types.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/protobuf/config.pb.h"
34 #include "tensorflow/core/public/session.h"
35 #include "tensorflow/core/public/session_options.h"
36
37 namespace tensorflow {
38 namespace grappler {
39 namespace {
40
41 class ScopedAllocatorOptimizerTest : public ::testing::Test {
42 public:
CreateSession(const GraphDef & graph,const ConfigProto & config)43 std::unique_ptr<Session> CreateSession(const GraphDef& graph,
44 const ConfigProto& config) {
45 SessionOptions options;
46 options.config = config;
47 (*options.config.mutable_device_count())["CPU"] = 2;
48 Session* session = NewSession(options);
49 TF_CHECK_OK(session->Create(graph));
50 return std::unique_ptr<Session>(session);
51 }
52
EvaluateNodes(const GraphDef & graph,const std::vector<string> & fetch)53 std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
54 const std::vector<string>& fetch) {
55 SessionOptions options;
56 std::unique_ptr<Session> session(NewSession(options));
57 TF_CHECK_OK(session->Create(graph));
58 RunOptions run_options;
59 std::vector<Tensor> output_tensors;
60 TF_CHECK_OK(
61 session->Run(run_options, {}, fetch, fetch, &output_tensors, nullptr));
62 TF_CHECK_OK(session->Close());
63 return output_tensors;
64 }
65
66 // Constructs the following graph.
67 // (Flow is top to bottom, like nature intends.)
68 //
69 // The intended optimization is to have s1 and s2 allocate from
70 // a new ScopedAllocator, then replace a1 and a2 with a3 that
71 // reads from the backing buffer.
72 /*
73 a b c
74 \ / \ /
75 s1 s2
76 | |
77 (i1) (i2) if forward is true
78 | |
79 a1 a2
80 | |
81 r1 r2
82 */
BuildAbsGraph(GraphDef * graph_def,bool forward)83 void BuildAbsGraph(GraphDef* graph_def, bool forward) {
84 Scope s = Scope::NewRootScope();
85 s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
86
87 Output a =
88 ops::Const<float>(s.WithOpName("a"), {1.0, 0.0, 0.0, -1.0}, {2, 2});
89 Output b =
90 ops::Const<float>(s.WithOpName("b"), {1.0, -2.0, 3.0, 4.0}, {2, 2});
91 Output c =
92 ops::Const<float>(s.WithOpName("c"), {-5.0, -2.0, 0.0, -2.0}, {2, 2});
93 Output s1 = ops::Add(s.WithOpName("s1"), a, b);
94 Output s2 = ops::Add(s.WithOpName("s2"), b, c);
95 Output int1, int2;
96 if (forward) {
97 int1 = ops::Identity(s.WithOpName("i1"), s1);
98 int2 = ops::Identity(s.WithOpName("i2"), s2);
99 } else {
100 int1 = s1;
101 int2 = s2;
102 }
103 Output a1 = ops::Abs(s.WithOpName("a1"), int1);
104 Output a2 = ops::Abs(s.WithOpName("a2"), int2);
105 Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
106 Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
107 TF_CHECK_OK(s.ToGraphDef(graph_def));
108 }
109
110 // Constructs the following graph.
111 // (Flow is top to bottom, like nature intends.)
112 //
113 // a, b, and c are placeholders. s is an Add op. a1, a2, and a3 are Abs ops.
114 // r1, r2, and r3 are Reshape ops.
115 //
116 // After this graph undergoes SA optimization, we expect a, b, and s to be
117 // allocated from a new ScopedAllocator. There will be control edges from the
118 // ScopedAllocator node to a, b, and s, to ensure that we allocate the
119 // backing tensor before we need it. There will also be a control edge from c
120 // to ScopedAllocator node, so that we delay allocation as much as possible.
121 // There should be no edge from b to ScopedAllocator node, because that would
122 // imply a cycle in the graph.
123 /*
124 a b c
125 | / \ /
126 | / \ /
127 | | s1
128 | | |
129 a1 a2 a3
130 | | |
131 r1 r2 r3
132 */
BuildAbsGraphWithInputDependencies(GraphDef * graph_def)133 void BuildAbsGraphWithInputDependencies(GraphDef* graph_def) {
134 Scope s = Scope::NewRootScope();
135 s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
136
137 Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
138 ops::Placeholder::Shape({2, 2}));
139 Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
140 ops::Placeholder::Shape({2, 2}));
141 Output c = ops::Placeholder(s.WithOpName("c"), DT_FLOAT,
142 ops::Placeholder::Shape({2, 2}));
143 Output s1 = ops::Add(s.WithOpName("s1"), b, c);
144 Output a1 = ops::Abs(s.WithOpName("a1"), a);
145 Output a2 = ops::Abs(s.WithOpName("a2"), b);
146 Output a3 = ops::Abs(s.WithOpName("a3"), s1);
147 Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
148 Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
149 Output r3 = ops::Reshape(s.WithOpName("r3"), a3, {4, 1});
150 TF_CHECK_OK(s.ToGraphDef(graph_def));
151 }
152
153 // Constructs the following graph.
154 //
155 // a and b are data inputs. ctl1 and ctl2 are control inputs. a1 and a2 are
156 // Abs ops. o1 and o2 are data outputs. a1 -> ctl3 and a2 -> ctl4 are
157 // control edges.
158 //
159 // After the optimizer runs, we expect the ctl1 and ctl2 to be connected to
160 // the SAConcat node, and ctl3 and ctl4 to be connected to SASplit node.
161 /*
162 a ctl1 b ctl2
163 \ / \ /
164 a1 a2
165 / \ / \
166 o1 ctl3 o2 ctl4
167 */
BuildAbsGraphWithInputAndOutputControlEdges(GraphDef * graph_def)168 void BuildAbsGraphWithInputAndOutputControlEdges(GraphDef* graph_def) {
169 Scope s = Scope::NewRootScope();
170 s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
171
172 Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
173 ops::Placeholder::Shape({2, 2}));
174 Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
175 ops::Placeholder::Shape({2, 2}));
176 Output ctl1 = ops::Placeholder(s.WithOpName("ctl1"), DT_FLOAT,
177 ops::Placeholder::Shape({2, 2}));
178 Output ctl2 = ops::Placeholder(s.WithOpName("ctl2"), DT_FLOAT,
179 ops::Placeholder::Shape({2, 2}));
180 Output a1 = ops::Abs(s.WithOpName("a1").WithControlDependencies({ctl1}), a);
181 Output a2 = ops::Abs(s.WithOpName("a2").WithControlDependencies({ctl2}), b);
182 Output o1 = ops::Reshape(s.WithOpName("o1"), a1, {1, 4});
183 Output o2 = ops::Reshape(s.WithOpName("o2"), a2, {4, 1});
184 Output ctl3 =
185 ops::Const<float>(s.WithOpName("ctl3").WithControlDependencies({a1}),
186 {0.0, 0.0, 0.0, 0.0}, {2, 2});
187 Output ctl4 =
188 ops::Const<float>(s.WithOpName("ctl4").WithControlDependencies({a2}),
189 {0.0, 0.0, 0.0, 0.0}, {2, 2});
190 TF_CHECK_OK(s.ToGraphDef(graph_def));
191 }
192
193 // Constructs the following graph.
194 //
195 // We have 2 different name scopes in this graph. s3, a3, a4, r3, and r4 are
196 // all under "sub" scope. All other nodes are in the root scope.
197 //
198 // The intention is to test that ScopedAllocatorOptimizer works well with a
199 // graph that has multiple name scopes. In particular, it should work when a
200 // node (in this case s2) is an input to two nodes in different name scopes
201 // (a2 and sub/a3) which may be scope allocated.
202 /*
203 a b c a b
204 \ / \ / \ /
205 s1 s2------ sub/s3
206 | | | |
207 a1 a2 sub/a4 sub/a3
208 | | | |
209 r1 r2 sub/r4 sub/r3
210 */
BuildGraphWithMultipleScopes(GraphDef * graph_def)211 void BuildGraphWithMultipleScopes(GraphDef* graph_def) {
212 Scope root_scope = Scope::NewRootScope();
213 root_scope =
214 root_scope.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
215
216 Output a = ops::Const<float>(root_scope.WithOpName("a"),
217 {1.0, 0.0, 0.0, -1.0}, {2, 2});
218 Output b = ops::Const<float>(root_scope.WithOpName("b"),
219 {1.0, -2.0, 3.0, 4.0}, {2, 2});
220 Output c = ops::Const<float>(root_scope.WithOpName("c"),
221 {-5.0, -2.0, 0.0, -2.0}, {2, 2});
222
223 // Root scope ops.
224 Output s1 = ops::Add(root_scope.WithOpName("s1"), a, b);
225 Output s2 = ops::Add(root_scope.WithOpName("s2"), b, c);
226 Output a1 = ops::Abs(root_scope.WithOpName("a1"), s1);
227 Output a2 = ops::Abs(root_scope.WithOpName("a2"), s2);
228 Output r1 = ops::Reshape(root_scope.WithOpName("r1"), a1, {1, 4});
229 Output r2 = ops::Reshape(root_scope.WithOpName("r2"), a2, {4, 1});
230
231 // Sub scope ops.
232 Scope sub_scope = root_scope.NewSubScope("sub");
233 Output s3 = ops::Add(sub_scope.WithOpName("s3"), a, b);
234 Output a3 = ops::Abs(sub_scope.WithOpName("a3"), s3);
235 Output a4 = ops::Abs(sub_scope.WithOpName("a4"), s2);
236 Output r3 = ops::Reshape(sub_scope.WithOpName("r3"), a3, {1, 4});
237 Output r4 = ops::Reshape(sub_scope.WithOpName("r4"), a4, {4, 1});
238
239 TF_CHECK_OK(root_scope.ToGraphDef(graph_def));
240 }
241
242 // Constructs the following graph.
243 //
244 // c1 and c2 are Const ops. a1 and a2 are Abs ops.
245 // We expect the optimizer to succeed and insert Identity between ci and ai.
246 // This will ensure that we will still be able use ScopedAllocator with Const
247 // inputs.
248 /*
249 c1 c2
250 | |
251 a1 a2
252 | |
253 r1 r2
254 */
BuildConstGraph(GraphDef * graph_def,bool forward)255 void BuildConstGraph(GraphDef* graph_def, bool forward) {
256 Scope s = Scope::NewRootScope();
257 s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
258
259 Output c1 =
260 ops::Const<float>(s.WithOpName("c1"), {1.0, 0.0, 0.0, -1.0}, {2, 2});
261 Output c2 =
262 ops::Const<float>(s.WithOpName("c2"), {1.0, -2.0, 3.0, 4.0}, {2, 2});
263 Output a1 = ops::Abs(s.WithOpName("a1"), c1);
264 Output a2 = ops::Abs(s.WithOpName("a2"), c2);
265 Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
266 Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
267 TF_CHECK_OK(s.ToGraphDef(graph_def));
268 }
269
SetShapes(GraphDef * graph_def)270 void SetShapes(GraphDef* graph_def) {
271 TensorShapeProto shape_proto;
272 shape_proto.add_dim()->set_size(2);
273 shape_proto.add_dim()->set_size(2);
274
275 for (NodeDef& n : *graph_def->mutable_node()) {
276 if (n.op() == "Add" || n.op() == "Abs") {
277 AddNodeAttr("_output_shapes", {shape_proto}, &n);
278 }
279 }
280 }
281
282 // Invokes ScopedAllocatorOptimizer on `graph_def`, then executes it and
283 // returns the outputs specified by `output_names` in `outputs`.
ExecuteGraph(const GraphDef & graph_def,const std::vector<string> & output_names,std::vector<Tensor> * outputs)284 void ExecuteGraph(const GraphDef& graph_def,
285 const std::vector<string>& output_names,
286 std::vector<Tensor>* outputs) {
287 // Turn off all optimization except the ScopedAllocatorOptimizer
288 // to avoid anything that would alter the expected graph input/output,
289 // e.g. by constant folding away all calculations.
290 ConfigProto config;
291 GraphOptions* gopt = config.mutable_graph_options();
292 OptimizerOptions* opts = gopt->mutable_optimizer_options();
293 opts->set_do_common_subexpression_elimination(false);
294 opts->set_do_constant_folding(false);
295 opts->set_do_function_inlining(false);
296 opts->set_opt_level(OptimizerOptions::L0);
297 RewriterConfig* rwcfg = gopt->mutable_rewrite_options();
298 rwcfg->clear_optimizers();
299 (*rwcfg->add_optimizers()) = "scoped_allocator";
300 rwcfg->mutable_scoped_allocator_opts()->add_enable_op("Abs");
301 std::unique_ptr<Session> session(CreateSession(graph_def, config));
302
303 std::vector<std::pair<string, Tensor>> inputs;
304 std::vector<string> target_nodes = {};
305 Status s = session->Run(inputs, output_names, target_nodes, outputs);
306 TF_ASSERT_OK(s);
307 ASSERT_EQ(outputs->size(), output_names.size());
308 }
309
310 // Validates that outputs match expected.
ValidateValues(const std::vector<Tensor> & outputs,const std::vector<std::vector<float>> & expected)311 void ValidateValues(const std::vector<Tensor>& outputs,
312 const std::vector<std::vector<float>>& expected) {
313 for (int i = 0; i < expected.size(); ++i) {
314 EXPECT_EQ(expected[i].size(), outputs[i].NumElements());
315 for (int j = 0; j < expected[i].size(); ++j) {
316 EXPECT_EQ(expected[i][j], outputs[i].flat<float>()(j));
317 }
318 }
319 }
320
GetNode(NodeMap * node_map,const string & node_name,NodeDef ** node_def)321 void GetNode(NodeMap* node_map, const string& node_name, NodeDef** node_def) {
322 *node_def = node_map->GetNode(node_name);
323 ASSERT_TRUE(*node_def);
324 }
325
326 // Validate that a node has a single control input from scoped allocator node.
327 // Return the scoped allocator node.
ValidateSAControlInput(GraphDef * graph,NodeMap * node_map,const string & node_name)328 NodeDef* ValidateSAControlInput(GraphDef* graph, NodeMap* node_map,
329 const string& node_name) {
330 NodeDef* node = nullptr;
331 GetNode(node_map, node_name, &node);
332 int num_control_inputs = 0;
333 string control_input_name;
334 for (const auto& input : node->input()) {
335 if (IsControlInput(input)) {
336 ++num_control_inputs;
337 control_input_name = input;
338 }
339 }
340 EXPECT_EQ(num_control_inputs, 1);
341 NodeDef* control_input_node = nullptr;
342 GetNode(node_map, control_input_name, &control_input_node);
343 EXPECT_EQ(control_input_node->op(), "_ScopedAllocator");
344 return control_input_node;
345 }
346
NumControlInputs(NodeMap * node_map,const string & node_name)347 int NumControlInputs(NodeMap* node_map, const string& node_name) {
348 NodeDef* node = nullptr;
349 GetNode(node_map, node_name, &node);
350 int num_control_inputs = 0;
351 for (const auto& input : node->input()) {
352 if (IsControlInput(input)) {
353 ++num_control_inputs;
354 }
355 }
356 return num_control_inputs;
357 }
358 };
359 #ifndef ENABLE_MKL
360
TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
  // Tests that Rewrite of program with parallel unary Ops is done as
  // anticipated.
  GrapplerItem item;
  BuildAbsGraph(&item.graph, false);
  SetShapes(&item.graph);

  ScopedAllocatorOptions opts;
  opts.add_enable_op("Abs");
  ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
  ScopedAllocatorOptimizer::OpNameSet ons;
  ons.insert("Abs");

  GraphDef optimized_graph;
  TF_ASSERT_OK(sao.Optimize(nullptr /*cluster*/, item, &optimized_graph));

  // Examine the resulting graph def.
  NodeMap node_map(&optimized_graph);
  NodeDef* nd = nullptr;
  GetNode(&node_map, "scoped_allocator_1_1", &nd);
  {
    // The ScopedAllocator feeds the concat plus both producers s1 and s2.
    auto& nd_set = node_map.GetOutputs(nd->name());
    ASSERT_EQ(3, nd_set.size());
    std::unordered_set<string> expected = {"scoped_allocator_concat_1_1", "s1",
                                           "s2"};
    for (auto it : nd_set) {
      ASSERT_NE(expected.find(it->name()), expected.end())
          << "Failed to find " << it->name();
    }
  }
  {
    // The concat's only consumer is the fused Abs.
    auto& nd_set = node_map.GetOutputs("scoped_allocator_concat_1_1");
    ASSERT_EQ(1, nd_set.size());
    for (auto it : nd_set) {
      ASSERT_EQ("scoped_allocator_1_1_Abs", it->name());
    }
  }
  {
    // The fused Abs feeds only the split.
    auto& nd_set = node_map.GetOutputs("scoped_allocator_1_1_Abs");
    ASSERT_EQ(1, nd_set.size());
    for (auto it : nd_set) {
      ASSERT_EQ("scoped_allocator_split_1_1", it->name());
    }
  }
  {
    // The split fans back out to the original consumers r1 and r2.
    auto& nd_set = node_map.GetOutputs("scoped_allocator_split_1_1");
    ASSERT_EQ(2, nd_set.size());
    std::unordered_set<string> name_set;
    for (auto it : nd_set) {
      name_set.insert(it->name());
    }
    ASSERT_TRUE(name_set.find("r1") != name_set.end());
    ASSERT_TRUE(name_set.find("r2") != name_set.end());
  }
}
416
TEST_F(ScopedAllocatorOptimizerTest, UnaryExecute) {
  // Builds the same graph as UnaryRewriteOnly but also executes it and
  // validates the output.
  GraphDef graph_def;
  BuildAbsGraph(&graph_def, /*forward=*/false);
  SetShapes(&graph_def);
  std::vector<Tensor> outputs;
  ExecuteGraph(graph_def,
               /*output_names=*/{"r1:0", "r2:0"}, &outputs);
  // Expected values are Abs() of the sums below:
  // a + b == 2, -2, 3, 3
  // b + c == -4, -4, 3, 2
  ValidateValues(outputs, /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}});
}
430
TEST_F(ScopedAllocatorOptimizerTest, MultipleScopes) {
  // Executes a graph whose candidate ops live in two different name scopes
  // and validates the numeric results survive the rewrite.
  GraphDef graph_def;
  BuildGraphWithMultipleScopes(&graph_def);
  SetShapes(&graph_def);
  std::vector<Tensor> outputs;
  ExecuteGraph(graph_def,
               /*output_names=*/{"r1:0", "r2:0", "sub/r3:0", "sub/r4:0"},
               &outputs);
  ValidateValues(
      outputs,
      /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}, {2, 2, 3, 3}, {4, 4, 3, 2}});
}
443
444 // Tests static ScopedAllocatorOptimizer::ExtendNodeAttr.
445 // Maybe this should be moved elsewhere?
TEST_F(ScopedAllocatorOptimizerTest,Extend)446 TEST_F(ScopedAllocatorOptimizerTest, Extend) {
447 NodeDef nd;
448 ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {0, 2}, &nd);
449 ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {6, 7}, &nd);
450 ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {2, 3}, &nd);
451 VLOG(0) << "nd: " << nd.DebugString();
452 std::vector<int> scoped_allocator_attrs;
453 AttrSlice slice(nd);
454 Status sa_status =
455 GetNodeAttr(slice, "_scoped_allocator", &scoped_allocator_attrs);
456 for (int i : scoped_allocator_attrs) {
457 VLOG(0) << "extracted: " << i;
458 }
459 NodeDef nd2;
460 AddNodeAttr("_scoped_allocator", {0, 2}, &nd2);
461 AddNodeAttr("_scoped_allocator", {6, 7}, &nd2);
462 AddNodeAttr("_scoped_allocator", {2, 3}, &nd2);
463 VLOG(0) << "nd2: " << nd2.DebugString();
464 }
465
TEST_F(ScopedAllocatorOptimizerTest, ForwardInputToOutput) {
  // Test that kernels that forward the input to output using `set_output` work
  // well with scoped allocator optimization.
  GraphDef graph_def;
  BuildAbsGraph(&graph_def, /*forward=*/true);
  SetShapes(&graph_def);
  std::vector<Tensor> outputs;
  ExecuteGraph(graph_def, /*output_names=*/{"r1:0", "r2:0"}, &outputs);
  // Expected values are Abs() of the sums below:
  // a + b == 2, -2, 3, 3
  // b + c == -4, -4, 3, 2
  ValidateValues(outputs, /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}});
}
478
479 // Test that graphs with a dependency upstream from the inputs, such as the one
480 // produced by `BuildAbsGraphWithInputDependencies`, are handled well by this
481 // optimizer. In particular, the optimizer should not create cycles.
TEST_F(ScopedAllocatorOptimizerTest,InputDependencies)482 TEST_F(ScopedAllocatorOptimizerTest, InputDependencies) {
483 GrapplerItem item;
484 BuildAbsGraphWithInputDependencies(&item.graph);
485 SetShapes(&item.graph);
486
487 ScopedAllocatorOptions opts;
488 opts.add_enable_op("Abs");
489 ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
490 ScopedAllocatorOptimizer::OpNameSet ons;
491 ons.insert("Add");
492
493 GraphDef optimized_graph;
494 TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
495 NodeMap node_map(&optimized_graph);
496
497 // Check that all inputs to Abs ops have ScopedAllocator as a control
498 // dependency.
499 NodeDef* scoped_allocator_node =
500 ValidateSAControlInput(&optimized_graph, &node_map, "a");
501 VLOG(1) << scoped_allocator_node->DebugString();
502 EXPECT_TRUE(ValidateSAControlInput(&optimized_graph, &node_map, "b"));
503 EXPECT_TRUE(ValidateSAControlInput(&optimized_graph, &node_map, "s1"));
504
505 // Check that ScopedAllocator node has a single input, which is a control edge
506 // from c.
507 EXPECT_EQ(scoped_allocator_node->input_size(), 1);
508 EXPECT_EQ(scoped_allocator_node->input(0), "^c");
509 }
510
511 // Test that graphs with input and output control edges are rewired correctly by
512 // the optimizer.
TEST_F(ScopedAllocatorOptimizerTest,ControlEdgeRewire)513 TEST_F(ScopedAllocatorOptimizerTest, ControlEdgeRewire) {
514 GrapplerItem item;
515 BuildAbsGraphWithInputAndOutputControlEdges(&item.graph);
516 SetShapes(&item.graph);
517 LOG(INFO) << item.graph.DebugString();
518
519 ScopedAllocatorOptions opts;
520 opts.add_enable_op("Abs");
521 ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
522 ScopedAllocatorOptimizer::OpNameSet ons;
523 ons.insert("Const");
524
525 GraphDef optimized_graph;
526 TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
527 TF_ASSERT_OK(TopologicalSort(&optimized_graph));
528 NodeMap node_map(&optimized_graph);
529 LOG(INFO) << optimized_graph.DebugString();
530
531 // Check that ctl1 and ctl2 are now connected only to SAConcat.
532 NodeDef* ctl1 = nullptr;
533 GetNode(&node_map, "ctl1", &ctl1);
534 const auto& ctl1_outputs = node_map.GetOutputs("ctl1");
535 EXPECT_EQ(ctl1_outputs.size(), 1);
536 NodeDef* sa_concat = *ctl1_outputs.begin();
537 EXPECT_EQ(sa_concat->op(), "_ScopedAllocatorConcat");
538 NodeDef* ctl2 = nullptr;
539 GetNode(&node_map, "ctl2", &ctl2);
540 const auto& ctl2_outputs = node_map.GetOutputs("ctl2");
541 EXPECT_EQ(ctl2_outputs.size(), 1);
542 EXPECT_EQ(*ctl2_outputs.begin(), sa_concat);
543
544 // Check that SAConcat has only 2 input control edges.
545 EXPECT_EQ(NumControlInputs(&node_map, sa_concat->name()), 2);
546
547 // Check that fused node, which conceptually used to have control inputs from
548 // ctl1 and ctl2 respectively, no longer has any control inputs.
549 const auto& sa_concat_outputs = node_map.GetOutputs(sa_concat->name());
550 EXPECT_EQ(sa_concat_outputs.size(), 1);
551 NodeDef* fused_abs = *sa_concat_outputs.begin();
552 EXPECT_EQ(NumControlInputs(&node_map, fused_abs->name()), 0);
553
554 // Check that SASplit node has control edges to ctl3, ctl4; also check that
555 // those are the only control inputs on ctl3 and ctl4.
556 const auto& fused_abs_outputs = node_map.GetOutputs(fused_abs->name());
557 EXPECT_EQ(fused_abs_outputs.size(), 1);
558 NodeDef* sa_split = *fused_abs_outputs.begin();
559 EXPECT_EQ(NumControlOutputs(*sa_split, node_map), 2);
560 EXPECT_EQ(NumControlInputs(&node_map, "ctl3"), 1);
561 EXPECT_EQ(NumControlInputs(&node_map, "ctl4"), 1);
562 }
563
564 // Test that the optimization succeeds when any input is a Const op, and that it
565 // inserts Identity op between Const and Abs.
TEST_F(ScopedAllocatorOptimizerTest,ConstInput)566 TEST_F(ScopedAllocatorOptimizerTest, ConstInput) {
567 GrapplerItem item;
568 BuildConstGraph(&item.graph, false);
569 SetShapes(&item.graph);
570
571 ScopedAllocatorOptions opts;
572 opts.add_enable_op("Abs");
573 ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
574 ScopedAllocatorOptimizer::OpNameSet ons;
575 ons.insert("Abs");
576
577 GraphDef optimized_graph;
578 TF_ASSERT_OK(sao.Optimize(nullptr /*cluster*/, item, &optimized_graph));
579
580 // Examine the resulting graphdef.
581 const NodeDef* sa_node = nullptr;
582 for (const NodeDef& node : optimized_graph.node()) {
583 if (node.op() == "_ScopedAllocator") {
584 sa_node = &node;
585 break;
586 }
587 }
588 ASSERT_NE(sa_node, nullptr);
589 int num_identity_ops = 0;
590 NodeMap node_map(&optimized_graph);
591 for (NodeDef* sa_output : node_map.GetOutputs(sa_node->name())) {
592 EXPECT_FALSE(IsConstant(*sa_output));
593 if (IsIdentity(*sa_output)) {
594 ++num_identity_ops;
595 }
596 }
597 EXPECT_EQ(num_identity_ops, 2);
598 }
599 #endif // ENABLE_MKL
600
601 } // namespace
602 } // namespace grappler
603 } // namespace tensorflow
604