/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
16 
17 #include <unordered_set>
18 
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/framework/tensor_testutil.h"
25 #include "tensorflow/core/graph/testlib.h"
26 #include "tensorflow/core/grappler/grappler_item.h"
27 #include "tensorflow/core/grappler/op_types.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/protobuf/config.pb.h"
34 #include "tensorflow/core/public/session.h"
35 #include "tensorflow/core/public/session_options.h"
36 
37 namespace tensorflow {
38 namespace grappler {
39 namespace {
40 
41 class ScopedAllocatorOptimizerTest : public ::testing::Test {
42  public:
CreateSession(const GraphDef & graph,const ConfigProto & config)43   std::unique_ptr<Session> CreateSession(const GraphDef& graph,
44                                          const ConfigProto& config) {
45     SessionOptions options;
46     options.config = config;
47     (*options.config.mutable_device_count())["CPU"] = 2;
48     Session* session = NewSession(options);
49     TF_CHECK_OK(session->Create(graph));
50     return std::unique_ptr<Session>(session);
51   }
52 
EvaluateNodes(const GraphDef & graph,const std::vector<string> & fetch)53   std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
54                                     const std::vector<string>& fetch) {
55     SessionOptions options;
56     std::unique_ptr<Session> session(NewSession(options));
57     TF_CHECK_OK(session->Create(graph));
58     RunOptions run_options;
59     std::vector<Tensor> output_tensors;
60     TF_CHECK_OK(
61         session->Run(run_options, {}, fetch, fetch, &output_tensors, nullptr));
62     TF_CHECK_OK(session->Close());
63     return output_tensors;
64   }
65 
66   // Constructs the following graph.
67   // (Flow is top to bottom, like nature intends.)
68   //
69   // The intended optimization is to have s1 and s2 allocate from
70   // a new ScopedAllocator, then replace a1 and a2 with a3 that
71   // reads from the backing buffer.
72   /*
73         a    b    c
74          \  / \  /
75           s1   s2
76           |    |
77          (i1) (i2)  if forward is true
78           |    |
79           a1   a2
80           |    |
81           r1   r2
82   */
BuildAbsGraph(GraphDef * graph_def,bool forward)83   void BuildAbsGraph(GraphDef* graph_def, bool forward) {
84     Scope s = Scope::NewRootScope();
85     s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
86 
87     Output a =
88         ops::Const<float>(s.WithOpName("a"), {1.0, 0.0, 0.0, -1.0}, {2, 2});
89     Output b =
90         ops::Const<float>(s.WithOpName("b"), {1.0, -2.0, 3.0, 4.0}, {2, 2});
91     Output c =
92         ops::Const<float>(s.WithOpName("c"), {-5.0, -2.0, 0.0, -2.0}, {2, 2});
93     Output s1 = ops::Add(s.WithOpName("s1"), a, b);
94     Output s2 = ops::Add(s.WithOpName("s2"), b, c);
95     Output int1, int2;
96     if (forward) {
97       int1 = ops::Identity(s.WithOpName("i1"), s1);
98       int2 = ops::Identity(s.WithOpName("i2"), s2);
99     } else {
100       int1 = s1;
101       int2 = s2;
102     }
103     Output a1 = ops::Abs(s.WithOpName("a1"), int1);
104     Output a2 = ops::Abs(s.WithOpName("a2"), int2);
105     Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
106     Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
107     TF_CHECK_OK(s.ToGraphDef(graph_def));
108   }
109 
110   // Constructs the following graph.
111   // (Flow is top to bottom, like nature intends.)
112   //
113   // a, b, and c are placeholders.  s is an Add op.  a1, a2, and a3 are Abs ops.
114   // r1, r2, and r3 are Reshape ops.
115   //
116   // After this graph undergoes SA optimization, we expect a, b, and s to be
117   // allocated from a new ScopedAllocator.  There will be control edges from the
118   // ScopedAllocator node to a, b, and s, to ensure that we allocate the
119   // backing tensor before we need it.  There will also be a control edge from c
120   // to ScopedAllocator node, so that we delay allocation as much as possible.
121   // There should be no edge from b to ScopedAllocator node, because that would
122   // imply a cycle in the graph.
123   /*
124       a      b     c
125       |     / \   /
126       |    /   \ /
127       |    |    s1
128       |    |    |
129       a1   a2   a3
130       |    |    |
131       r1   r2   r3
132   */
BuildAbsGraphWithInputDependencies(GraphDef * graph_def)133   void BuildAbsGraphWithInputDependencies(GraphDef* graph_def) {
134     Scope s = Scope::NewRootScope();
135     s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
136 
137     Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
138                                 ops::Placeholder::Shape({2, 2}));
139     Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
140                                 ops::Placeholder::Shape({2, 2}));
141     Output c = ops::Placeholder(s.WithOpName("c"), DT_FLOAT,
142                                 ops::Placeholder::Shape({2, 2}));
143     Output s1 = ops::Add(s.WithOpName("s1"), b, c);
144     Output a1 = ops::Abs(s.WithOpName("a1"), a);
145     Output a2 = ops::Abs(s.WithOpName("a2"), b);
146     Output a3 = ops::Abs(s.WithOpName("a3"), s1);
147     Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
148     Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
149     Output r3 = ops::Reshape(s.WithOpName("r3"), a3, {4, 1});
150     TF_CHECK_OK(s.ToGraphDef(graph_def));
151   }
152 
153   // Constructs the following graph.
154   //
155   // a and b are data inputs.  ctl1 and ctl2 are control inputs.  a1 and a2 are
156   // Abs ops.  o1 and o2 are data outputs.  a1 -> ctl3 and a2 -> ctl4 are
157   // control edges.
158   //
159   // After the optimizer runs, we expect the ctl1 and ctl2 to be connected to
160   // the SAConcat node, and ctl3 and ctl4 to be connected to SASplit node.
161   /*
162      a  ctl1   b  ctl2
163       \  /      \  /
164        a1        a2
165       /  \      /  \
166      o1  ctl3  o2   ctl4
167   */
BuildAbsGraphWithInputAndOutputControlEdges(GraphDef * graph_def)168   void BuildAbsGraphWithInputAndOutputControlEdges(GraphDef* graph_def) {
169     Scope s = Scope::NewRootScope();
170     s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
171 
172     Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
173                                 ops::Placeholder::Shape({2, 2}));
174     Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
175                                 ops::Placeholder::Shape({2, 2}));
176     Output ctl1 = ops::Placeholder(s.WithOpName("ctl1"), DT_FLOAT,
177                                    ops::Placeholder::Shape({2, 2}));
178     Output ctl2 = ops::Placeholder(s.WithOpName("ctl2"), DT_FLOAT,
179                                    ops::Placeholder::Shape({2, 2}));
180     Output a1 = ops::Abs(s.WithOpName("a1").WithControlDependencies({ctl1}), a);
181     Output a2 = ops::Abs(s.WithOpName("a2").WithControlDependencies({ctl2}), b);
182     Output o1 = ops::Reshape(s.WithOpName("o1"), a1, {1, 4});
183     Output o2 = ops::Reshape(s.WithOpName("o2"), a2, {4, 1});
184     Output ctl3 =
185         ops::Const<float>(s.WithOpName("ctl3").WithControlDependencies({a1}),
186                           {0.0, 0.0, 0.0, 0.0}, {2, 2});
187     Output ctl4 =
188         ops::Const<float>(s.WithOpName("ctl4").WithControlDependencies({a2}),
189                           {0.0, 0.0, 0.0, 0.0}, {2, 2});
190     TF_CHECK_OK(s.ToGraphDef(graph_def));
191   }
192 
193   // Constructs the following graph.
194   //
195   // We have 2 different name scopes in this graph.  s3, a3, a4, r3, and r4 are
196   // all under "sub" scope.  All other nodes are in the root scope.
197   //
198   // The intention is to test that ScopedAllocatorOptimizer works well with a
199   // graph that has multiple name scopes.  In particular, it should work when a
200   // node (in this case s2) is an input to two nodes in different name scopes
201   // (a2 and sub/a3) which may be scope allocated.
202   /*
203         a    b    c         a    b
204          \  / \  /           \  /
205           s1   s2------      sub/s3
206           |    |      |        |
207           a1   a2   sub/a4   sub/a3
208           |    |      |        |
209           r1   r2   sub/r4   sub/r3
210   */
BuildGraphWithMultipleScopes(GraphDef * graph_def)211   void BuildGraphWithMultipleScopes(GraphDef* graph_def) {
212     Scope root_scope = Scope::NewRootScope();
213     root_scope =
214         root_scope.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
215 
216     Output a = ops::Const<float>(root_scope.WithOpName("a"),
217                                  {1.0, 0.0, 0.0, -1.0}, {2, 2});
218     Output b = ops::Const<float>(root_scope.WithOpName("b"),
219                                  {1.0, -2.0, 3.0, 4.0}, {2, 2});
220     Output c = ops::Const<float>(root_scope.WithOpName("c"),
221                                  {-5.0, -2.0, 0.0, -2.0}, {2, 2});
222 
223     // Root scope ops.
224     Output s1 = ops::Add(root_scope.WithOpName("s1"), a, b);
225     Output s2 = ops::Add(root_scope.WithOpName("s2"), b, c);
226     Output a1 = ops::Abs(root_scope.WithOpName("a1"), s1);
227     Output a2 = ops::Abs(root_scope.WithOpName("a2"), s2);
228     Output r1 = ops::Reshape(root_scope.WithOpName("r1"), a1, {1, 4});
229     Output r2 = ops::Reshape(root_scope.WithOpName("r2"), a2, {4, 1});
230 
231     // Sub scope ops.
232     Scope sub_scope = root_scope.NewSubScope("sub");
233     Output s3 = ops::Add(sub_scope.WithOpName("s3"), a, b);
234     Output a3 = ops::Abs(sub_scope.WithOpName("a3"), s3);
235     Output a4 = ops::Abs(sub_scope.WithOpName("a4"), s2);
236     Output r3 = ops::Reshape(sub_scope.WithOpName("r3"), a3, {1, 4});
237     Output r4 = ops::Reshape(sub_scope.WithOpName("r4"), a4, {4, 1});
238 
239     TF_CHECK_OK(root_scope.ToGraphDef(graph_def));
240   }
241 
242   // Constructs the following graph.
243   //
244   // c1 and c2 are Const ops.  a1 and a2 are Abs ops.
245   // We expect the optimizer to succeed and insert Identity between ci and ai.
246   // This will ensure that we will still be able use ScopedAllocator with Const
247   // inputs.
248   /*
249           c1   c2
250           |    |
251           a1   a2
252           |    |
253           r1   r2
254   */
BuildConstGraph(GraphDef * graph_def,bool forward)255   void BuildConstGraph(GraphDef* graph_def, bool forward) {
256     Scope s = Scope::NewRootScope();
257     s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
258 
259     Output c1 =
260         ops::Const<float>(s.WithOpName("c1"), {1.0, 0.0, 0.0, -1.0}, {2, 2});
261     Output c2 =
262         ops::Const<float>(s.WithOpName("c2"), {1.0, -2.0, 3.0, 4.0}, {2, 2});
263     Output a1 = ops::Abs(s.WithOpName("a1"), c1);
264     Output a2 = ops::Abs(s.WithOpName("a2"), c2);
265     Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
266     Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
267     TF_CHECK_OK(s.ToGraphDef(graph_def));
268   }
269 
SetShapes(GraphDef * graph_def)270   void SetShapes(GraphDef* graph_def) {
271     TensorShapeProto shape_proto;
272     shape_proto.add_dim()->set_size(2);
273     shape_proto.add_dim()->set_size(2);
274 
275     for (NodeDef& n : *graph_def->mutable_node()) {
276       if (n.op() == "Add" || n.op() == "Abs") {
277         AddNodeAttr("_output_shapes", {shape_proto}, &n);
278       }
279     }
280   }
281 
282   // Invokes ScopedAllocatorOptimizer on `graph_def`, then executes it and
283   // returns the outputs specified by `output_names` in `outputs`.
ExecuteGraph(const GraphDef & graph_def,const std::vector<string> & output_names,std::vector<Tensor> * outputs)284   void ExecuteGraph(const GraphDef& graph_def,
285                     const std::vector<string>& output_names,
286                     std::vector<Tensor>* outputs) {
287     // Turn off all optimization except the ScopedAllocatorOptimizer
288     // to avoid anything that would alter the expected graph input/output,
289     // e.g. by constant folding away all calculations.
290     ConfigProto config;
291     GraphOptions* gopt = config.mutable_graph_options();
292     OptimizerOptions* opts = gopt->mutable_optimizer_options();
293     opts->set_do_common_subexpression_elimination(false);
294     opts->set_do_constant_folding(false);
295     opts->set_do_function_inlining(false);
296     opts->set_opt_level(OptimizerOptions::L0);
297     RewriterConfig* rwcfg = gopt->mutable_rewrite_options();
298     rwcfg->clear_optimizers();
299     (*rwcfg->add_optimizers()) = "scoped_allocator";
300     rwcfg->mutable_scoped_allocator_opts()->add_enable_op("Abs");
301     std::unique_ptr<Session> session(CreateSession(graph_def, config));
302 
303     std::vector<std::pair<string, Tensor>> inputs;
304     std::vector<string> target_nodes = {};
305     Status s = session->Run(inputs, output_names, target_nodes, outputs);
306     TF_ASSERT_OK(s);
307     ASSERT_EQ(outputs->size(), output_names.size());
308   }
309 
310   // Validates that outputs match expected.
ValidateValues(const std::vector<Tensor> & outputs,const std::vector<std::vector<float>> & expected)311   void ValidateValues(const std::vector<Tensor>& outputs,
312                       const std::vector<std::vector<float>>& expected) {
313     for (int i = 0; i < expected.size(); ++i) {
314       EXPECT_EQ(expected[i].size(), outputs[i].NumElements());
315       for (int j = 0; j < expected[i].size(); ++j) {
316         EXPECT_EQ(expected[i][j], outputs[i].flat<float>()(j));
317       }
318     }
319   }
320 
GetNode(NodeMap * node_map,const string & node_name,NodeDef ** node_def)321   void GetNode(NodeMap* node_map, const string& node_name, NodeDef** node_def) {
322     *node_def = node_map->GetNode(node_name);
323     ASSERT_TRUE(*node_def);
324   }
325 
326   // Validate that a node has a single control input from scoped allocator node.
327   // Return the scoped allocator node.
ValidateSAControlInput(GraphDef * graph,NodeMap * node_map,const string & node_name)328   NodeDef* ValidateSAControlInput(GraphDef* graph, NodeMap* node_map,
329                                   const string& node_name) {
330     NodeDef* node = nullptr;
331     GetNode(node_map, node_name, &node);
332     int num_control_inputs = 0;
333     string control_input_name;
334     for (const auto& input : node->input()) {
335       if (IsControlInput(input)) {
336         ++num_control_inputs;
337         control_input_name = input;
338       }
339     }
340     EXPECT_EQ(num_control_inputs, 1);
341     NodeDef* control_input_node = nullptr;
342     GetNode(node_map, control_input_name, &control_input_node);
343     EXPECT_EQ(control_input_node->op(), "_ScopedAllocator");
344     return control_input_node;
345   }
346 
NumControlInputs(NodeMap * node_map,const string & node_name)347   int NumControlInputs(NodeMap* node_map, const string& node_name) {
348     NodeDef* node = nullptr;
349     GetNode(node_map, node_name, &node);
350     int num_control_inputs = 0;
351     for (const auto& input : node->input()) {
352       if (IsControlInput(input)) {
353         ++num_control_inputs;
354       }
355     }
356     return num_control_inputs;
357   }
358 };
359 #ifndef ENABLE_MKL
360 
TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
  // Tests that Rewrite of program with parallel unary Ops is done as
  // anticipated.
  GrapplerItem item;
  BuildAbsGraph(&item.graph, false);
  SetShapes(&item.graph);

  ScopedAllocatorOptions opts;
  opts.add_enable_op("Abs");
  ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);

  GraphDef optimized_graph;
  TF_ASSERT_OK(sao.Optimize(nullptr /*cluster*/, item, &optimized_graph));

  // Examine the resulting graph def.
  NodeMap node_map(&optimized_graph);
  NodeDef* nd = nullptr;
  GetNode(&node_map, "scoped_allocator_1_1", &nd);
  {
    // The new ScopedAllocator should feed the concat node plus the two
    // producers s1 and s2 (via control edges).
    auto& nd_set = node_map.GetOutputs(nd->name());
    ASSERT_EQ(3, nd_set.size());
    std::unordered_set<string> expected = {"scoped_allocator_concat_1_1", "s1",
                                           "s2"};
    for (auto it : nd_set) {
      ASSERT_NE(expected.find(it->name()), expected.end())
          << "Failed to find " << it->name();
    }
  }
  {
    // The concat feeds only the fused Abs.
    auto& nd_set = node_map.GetOutputs("scoped_allocator_concat_1_1");
    ASSERT_EQ(1, nd_set.size());
    for (auto it : nd_set) {
      ASSERT_EQ("scoped_allocator_1_1_Abs", it->name());
    }
  }
  {
    // The fused Abs feeds only the split.
    auto& nd_set = node_map.GetOutputs("scoped_allocator_1_1_Abs");
    ASSERT_EQ(1, nd_set.size());
    for (auto it : nd_set) {
      ASSERT_EQ("scoped_allocator_split_1_1", it->name());
    }
  }
  {
    // The split feeds the original consumers r1 and r2.
    auto& nd_set = node_map.GetOutputs("scoped_allocator_split_1_1");
    ASSERT_EQ(2, nd_set.size());
    std::unordered_set<string> name_set;
    for (auto it : nd_set) {
      name_set.insert(it->name());
    }
    ASSERT_TRUE(name_set.find("r1") != name_set.end());
    ASSERT_TRUE(name_set.find("r2") != name_set.end());
  }
}
416 
TEST_F(ScopedAllocatorOptimizerTest, UnaryExecute) {
  // Same graph as UnaryRewriteOnly, but run end-to-end through a Session
  // that applies only the scoped_allocator rewrite, then check the values.
  GraphDef graph;
  BuildAbsGraph(&graph, /*forward=*/false);
  SetShapes(&graph);
  std::vector<Tensor> fetched;
  ExecuteGraph(graph, /*output_names=*/{"r1:0", "r2:0"}, &fetched);
  // |a + b| == {2, 2, 3, 3}
  // |b + c| == {4, 4, 3, 2}
  ValidateValues(fetched, /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}});
}
430 
TEST_F(ScopedAllocatorOptimizerTest, MultipleScopes) {
  // Executes the two-name-scope graph and verifies that both the root-scope
  // and sub-scope branches still compute the expected values after rewrite.
  GraphDef graph;
  BuildGraphWithMultipleScopes(&graph);
  SetShapes(&graph);
  std::vector<Tensor> fetched;
  ExecuteGraph(graph,
               /*output_names=*/{"r1:0", "r2:0", "sub/r3:0", "sub/r4:0"},
               &fetched);
  ValidateValues(
      fetched,
      /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}, {2, 2, 3, 3}, {4, 4, 3, 2}});
}
443 
444 // Tests static ScopedAllocatorOptimizer::ExtendNodeAttr.
445 // Maybe this should be moved elsewhere?
TEST_F(ScopedAllocatorOptimizerTest,Extend)446 TEST_F(ScopedAllocatorOptimizerTest, Extend) {
447   NodeDef nd;
448   ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {0, 2}, &nd);
449   ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {6, 7}, &nd);
450   ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {2, 3}, &nd);
451   VLOG(0) << "nd: " << nd.DebugString();
452   std::vector<int> scoped_allocator_attrs;
453   AttrSlice slice(nd);
454   Status sa_status =
455       GetNodeAttr(slice, "_scoped_allocator", &scoped_allocator_attrs);
456   for (int i : scoped_allocator_attrs) {
457     VLOG(0) << "extracted: " << i;
458   }
459   NodeDef nd2;
460   AddNodeAttr("_scoped_allocator", {0, 2}, &nd2);
461   AddNodeAttr("_scoped_allocator", {6, 7}, &nd2);
462   AddNodeAttr("_scoped_allocator", {2, 3}, &nd2);
463   VLOG(0) << "nd2: " << nd2.DebugString();
464 }
465 
TEST_F(ScopedAllocatorOptimizerTest, ForwardInputToOutput) {
  // Test that kernels that forward the input to output using `set_output` work
  // well with scoped allocator optimization.
  GraphDef graph;
  BuildAbsGraph(&graph, /*forward=*/true);
  SetShapes(&graph);
  std::vector<Tensor> fetched;
  ExecuteGraph(graph, /*output_names=*/{"r1:0", "r2:0"}, &fetched);
  // |a + b| == {2, 2, 3, 3}
  // |b + c| == {4, 4, 3, 2}
  ValidateValues(fetched, /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}});
}
478 
479 // Test that graphs with a dependency upstream from the inputs, such as the one
480 // produced by `BuildAbsGraphWithInputDependencies`, are handled well by this
481 // optimizer.  In particular, the optimizer should not create cycles.
TEST_F(ScopedAllocatorOptimizerTest,InputDependencies)482 TEST_F(ScopedAllocatorOptimizerTest, InputDependencies) {
483   GrapplerItem item;
484   BuildAbsGraphWithInputDependencies(&item.graph);
485   SetShapes(&item.graph);
486 
487   ScopedAllocatorOptions opts;
488   opts.add_enable_op("Abs");
489   ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
490   ScopedAllocatorOptimizer::OpNameSet ons;
491   ons.insert("Add");
492 
493   GraphDef optimized_graph;
494   TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
495   NodeMap node_map(&optimized_graph);
496 
497   // Check that all inputs to Abs ops have ScopedAllocator as a control
498   // dependency.
499   NodeDef* scoped_allocator_node =
500       ValidateSAControlInput(&optimized_graph, &node_map, "a");
501   VLOG(1) << scoped_allocator_node->DebugString();
502   EXPECT_TRUE(ValidateSAControlInput(&optimized_graph, &node_map, "b"));
503   EXPECT_TRUE(ValidateSAControlInput(&optimized_graph, &node_map, "s1"));
504 
505   // Check that ScopedAllocator node has a single input, which is a control edge
506   // from c.
507   EXPECT_EQ(scoped_allocator_node->input_size(), 1);
508   EXPECT_EQ(scoped_allocator_node->input(0), "^c");
509 }
510 
511 // Test that graphs with input and output control edges are rewired correctly by
512 // the optimizer.
TEST_F(ScopedAllocatorOptimizerTest,ControlEdgeRewire)513 TEST_F(ScopedAllocatorOptimizerTest, ControlEdgeRewire) {
514   GrapplerItem item;
515   BuildAbsGraphWithInputAndOutputControlEdges(&item.graph);
516   SetShapes(&item.graph);
517   LOG(INFO) << item.graph.DebugString();
518 
519   ScopedAllocatorOptions opts;
520   opts.add_enable_op("Abs");
521   ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
522   ScopedAllocatorOptimizer::OpNameSet ons;
523   ons.insert("Const");
524 
525   GraphDef optimized_graph;
526   TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
527   TF_ASSERT_OK(TopologicalSort(&optimized_graph));
528   NodeMap node_map(&optimized_graph);
529   LOG(INFO) << optimized_graph.DebugString();
530 
531   // Check that ctl1 and ctl2 are now connected only to SAConcat.
532   NodeDef* ctl1 = nullptr;
533   GetNode(&node_map, "ctl1", &ctl1);
534   const auto& ctl1_outputs = node_map.GetOutputs("ctl1");
535   EXPECT_EQ(ctl1_outputs.size(), 1);
536   NodeDef* sa_concat = *ctl1_outputs.begin();
537   EXPECT_EQ(sa_concat->op(), "_ScopedAllocatorConcat");
538   NodeDef* ctl2 = nullptr;
539   GetNode(&node_map, "ctl2", &ctl2);
540   const auto& ctl2_outputs = node_map.GetOutputs("ctl2");
541   EXPECT_EQ(ctl2_outputs.size(), 1);
542   EXPECT_EQ(*ctl2_outputs.begin(), sa_concat);
543 
544   // Check that SAConcat has only 2 input control edges.
545   EXPECT_EQ(NumControlInputs(&node_map, sa_concat->name()), 2);
546 
547   // Check that fused node, which conceptually used to have control inputs from
548   // ctl1 and ctl2 respectively, no longer has any control inputs.
549   const auto& sa_concat_outputs = node_map.GetOutputs(sa_concat->name());
550   EXPECT_EQ(sa_concat_outputs.size(), 1);
551   NodeDef* fused_abs = *sa_concat_outputs.begin();
552   EXPECT_EQ(NumControlInputs(&node_map, fused_abs->name()), 0);
553 
554   // Check that SASplit node has control edges to ctl3, ctl4; also check that
555   // those are the only control inputs on ctl3 and ctl4.
556   const auto& fused_abs_outputs = node_map.GetOutputs(fused_abs->name());
557   EXPECT_EQ(fused_abs_outputs.size(), 1);
558   NodeDef* sa_split = *fused_abs_outputs.begin();
559   EXPECT_EQ(NumControlOutputs(*sa_split, node_map), 2);
560   EXPECT_EQ(NumControlInputs(&node_map, "ctl3"), 1);
561   EXPECT_EQ(NumControlInputs(&node_map, "ctl4"), 1);
562 }
563 
564 // Test that the optimization succeeds when any input is a Const op, and that it
565 // inserts Identity op between Const and Abs.
TEST_F(ScopedAllocatorOptimizerTest,ConstInput)566 TEST_F(ScopedAllocatorOptimizerTest, ConstInput) {
567   GrapplerItem item;
568   BuildConstGraph(&item.graph, false);
569   SetShapes(&item.graph);
570 
571   ScopedAllocatorOptions opts;
572   opts.add_enable_op("Abs");
573   ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
574   ScopedAllocatorOptimizer::OpNameSet ons;
575   ons.insert("Abs");
576 
577   GraphDef optimized_graph;
578   TF_ASSERT_OK(sao.Optimize(nullptr /*cluster*/, item, &optimized_graph));
579 
580   // Examine the resulting graphdef.
581   const NodeDef* sa_node = nullptr;
582   for (const NodeDef& node : optimized_graph.node()) {
583     if (node.op() == "_ScopedAllocator") {
584       sa_node = &node;
585       break;
586     }
587   }
588   ASSERT_NE(sa_node, nullptr);
589   int num_identity_ops = 0;
590   NodeMap node_map(&optimized_graph);
591   for (NodeDef* sa_output : node_map.GetOutputs(sa_node->name())) {
592     EXPECT_FALSE(IsConstant(*sa_output));
593     if (IsIdentity(*sa_output)) {
594       ++num_identity_ops;
595     }
596   }
597   EXPECT_EQ(num_identity_ops, 2);
598 }
599 #endif  // ENABLE_MKL
600 
601 }  // namespace
602 }  // namespace grappler
603 }  // namespace tensorflow
604