/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/function_optimizer.h"

#include "absl/algorithm/container.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace tensorflow {
namespace grappler {

namespace {
constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
}  // namespace

class FunctionOptimizerTest : public GrapplerTest {};

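// FunctionOptimizer inlines function calls into the main graph. The tests
// below cover its three major rewrites:
//   * inlining of direct function calls (InlineFunction_*),
//   * inlining of SymbolicGradient nodes (InlineSymbolicGradient*),
//   * inlining of multi-device "indirect" calls made through
//     PartitionedCall (InlineIndirectFunction*).
// Every test follows the same pattern:
//
//   FunctionOptimizer optimizer(RewriterConfig::ON,
//                               /*lower_control_flow=*/true);
//   GraphDef output;
//   TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));
//
// The second constructor argument controls whether functional control flow
// ops (e.g. `If`) may be lowered to `Switch`/`Merge` after inlining.
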
TEST_F(FunctionOptimizerTest, InlineFunction_SimpleFunction) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Build a graph to compute y = XTimesTwo(x)
  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

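  // The inliner materializes function arguments and return values as
  // `Identity` nodes named `Func/<call node>/input/_<i>` and
  // `Func/<call node>/output/_<i>`; the expected graph below relies on this
  // naming scheme.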
  const string arg0 = "Func/y/input/_0";
  const string ret0 = "Func/y/output/_1";

  const Tensor kTwo = test::AsScalar<int64_t>(2);
  GraphDef expected = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),
       NDef(arg0, "Identity", {"x"}, {{"T", DT_FLOAT}}),
       NDef("y/two", "Const", {}, {{"dtype", DT_INT64}, {"value", kTwo}}),
       NDef("y/scale", "Cast", {"y/two"},
            {{"DstT", DT_FLOAT}, {"SrcT", DT_INT64}}),
       NDef("y/y", "Mul", {arg0, "y/scale"}, {{"T", DT_FLOAT}}),
       NDef(ret0, "Identity", {"y/y"}, {{"T", DT_FLOAT}}),
       NDef("z", "Identity", {ret0}, {{"T", DT_FLOAT}})},
      {});
  for (NodeDef& node : *expected.mutable_node()) node.set_device(kDevice);

  CompareGraphs(expected, output);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FixedTypeFunction) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Create and instantiate a version of the XTimesTwo function that only
  // accepts floats as inputs.
  const Tensor kTwo = test::AsScalar<float>(2.0f);
  FunctionDef x_times_two = FunctionDefHelper::Define(
      // Name
      "XTimesTwo",
      // Args
      {"x: float"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
          // The "enter" node is used to verify that InlineFunction updates
          // the frame name accordingly.
          {{"enter"},
           "Enter",
           {"x"},
           {{"T", DT_FLOAT}, {"frame_name", "frame"}}},
          {{"y"}, "Mul", {"x", "two"}, {{"T", DT_FLOAT}}},
      });

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "XTimesTwo", {"x"}, {}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          x_times_two,
      });

  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Calls to XTimesTwo were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "XTimesTwo");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithOutputMapping) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  FunctionDef func = FunctionDefHelper::Create(
      // Name
      "Exp_func",
      // Args
      {"in: float"},
      // Return values
      {"out: float"},
      // Attr def
      {},
      // Nodes
      {{{"Linear_func"}, "Identity", {"in"}, {{"T", DT_FLOAT}}},
       {{"Exp"}, "Exp", {"Linear_func:output:0"}, {{"T", DT_FLOAT}}}},
      // Mapping
      {{"out", "Exp:y:0"}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "Exp_func", {"x"}, {}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          func,
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function call was removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "Exp_func");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithInputForwarding) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  FunctionDef func = FunctionDefHelper::Create(
      // Name
      "ForwardInputs",
      // Args
      {"in0: float", "in1: float", "arg2: float", "arg3: int32", "arg4: float"},
      // Return values
      {"out0: float", "arg2: float", "arg3: int32"},
      // Attr def
      {},
      // Nodes
      {},
      // Mapping
      {{"out0", "in0"}, {"arg2", "arg2"}, {"arg3", "arg3"}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x2", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x3", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       NDef("x4", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "ForwardInputs", {"x0", "x1", "x2", "x3", "x4"}, {}, kDevice),
       NDef("z0", "Identity", {"y:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z1", "Identity", {"y:1"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z2", "Identity", {"y:2"}, {{"T", DT_INT32}}, kDevice)},
      // FunctionLib
      {
          func,
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function call was removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "ForwardInputs");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  item.fetch = {"z0", "z1", "z2"};
  item.feed.emplace_back("x0", test::AsScalar<float>(3.14f));
  item.feed.emplace_back("x1", test::AsScalar<float>(2.7f));
  item.feed.emplace_back("x2", test::AsScalar<float>(1.0f));
  item.feed.emplace_back("x4", test::AsScalar<float>(-1.0f));
  item.feed.emplace_back("x3", test::AsScalar<int>(1234));
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
  test::ExpectTensorEqual<int>(tensors_expected[2], tensors[2]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithoutInput) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  const Tensor kTwo = test::AsScalar<int64_t>(2);
  FunctionDef func = FunctionDefHelper::Define(
      // Name
      "GenerateTwo",
      // Args
      {},
      // Return value
      {"o: T"},
      // Attr def
      {"T: {float, double}"},
      // Nodes
      {{{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
       {{"o"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("y", "GenerateTwo", {}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          func,
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function call was removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "GenerateTwo");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  item.fetch = {"z"};
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithNestedFunctionCall) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Define square via function library:
  //   MySquare(x) = MyMul(x, x)

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "output:z:0"}});

  FunctionDef square_func = FunctionDefHelper::Create(
      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "output:z:0"}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("outputs", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {mul_func, square_func});

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function calls were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "MySquare");
    EXPECT_NE(node.op(), "MyMul");
  }
  // And functions were removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  item.fetch = {"outputs"};
  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

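// SymbolicGradient(f) calls the gradient of the function named in its `f`
// attribute: it takes `f`'s inputs together with the gradients of `f`'s
// outputs, and returns the gradients with respect to `f`'s inputs. The tests
// below verify that the optimizer inlines these gradient nodes, and that the
// `_noinline` attribute disables the rewrite.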
TEST_F(FunctionOptimizerTest, InlineSymbolicGradient_TestFunc) {
  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  FunctionDef func = FunctionDefHelper::Define(
      "TestFunc", {"x:float", "y:float"}, {"l:float"}, {},
      {
          {{"z"}, "Add", {"x", "y"}, {{"T", DT_FLOAT}}},
          FunctionDefHelper::Const("zero", 0),
          FunctionDefHelper::Const("one", 1),
          {{"r"}, "Rank", {"z"}, {{"T", DT_FLOAT}}},
          {{"indices"}, "Range", {"zero", "r", "one"}},
          {{"l"}, "Sum", {"z", "indices"}, {{"T", DT_FLOAT}}},
      });

  auto x = ops::Const(scope, 1.0f);
  auto y = ops::Const(scope, 2.0f);
  auto dl = ops::Const(scope, 3.0f);

  NameAttrList fn;
  fn.set_name("TestFunc");
  (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, y, dl},
                                  {DT_FLOAT, DT_FLOAT}, fn);
  auto out1 = ops::Identity(scope.WithOpName("out1"), g0.output[0]);
  auto out2 = ops::Identity(scope.WithOpName("out2"), g0.output[1]);

  GrapplerItem item;
  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
  *item.graph.mutable_library()->add_function() = func;

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // SymbolicGradient calls were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "SymbolicGradient");
  }
  // And functions were removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  std::vector<Tensor> expected =
      EvaluateNodes(item.graph, {"out1", "out2"}, {});
  std::vector<Tensor> optimized = EvaluateNodes(output, {"out1", "out2"}, {});
  test::ExpectTensorEqual<float>(expected[0], optimized[0]);
  test::ExpectTensorEqual<float>(expected[1], optimized[1]);
}

TEST_F(FunctionOptimizerTest, InlineSymbolicGradient_IdentityFunc) {
  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  FunctionDef func = FunctionDefHelper::Create(
      // Name
      "Identity_func",
      // Args
      {"in: float"},
      // Return values
      {"out: float"},
      // Attr def
      {},
      // Nodes
      {{{"Identity"}, "Identity", {"in"}, {{"T", DT_FLOAT}}}},
      // Mapping
      {{"out", "Identity:output:0"}});

  auto x = ops::Const(scope, 1.0f, {3, 5, 7});
  auto z = ops::Const(scope, 3.0f, {3, 5, 7});

  NameAttrList fn;
  fn.set_name("Identity_func");
  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, z},
                                  {DT_FLOAT}, fn);
  auto out = ops::Identity(scope.WithOpName("out"), g0.output[0]);

  GrapplerItem item;
  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
  *item.graph.mutable_library()->add_function() = func;

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // SymbolicGradient calls were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "SymbolicGradient");
  }
  // And functions were removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out"}, {});
  std::vector<Tensor> optimized = EvaluateNodes(output, {"out"}, {});
  test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}

TEST_F(FunctionOptimizerTest, InlineSymbolicGradientNoInlineFunc) {
  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  FunctionDef func = FunctionDefHelper::Define(
      "TestFunc", {"x:float", "y:float"}, {"l:float"}, {},
      {
          {{"z"}, "Add", {"x", "y"}, {{"T", DT_FLOAT}}},
          FunctionDefHelper::Const("zero", 0),
          FunctionDefHelper::Const("one", 1),
          {{"r"}, "Rank", {"z"}, {{"T", DT_FLOAT}}},
          {{"indices"}, "Range", {"zero", "r", "one"}},
          {{"l"}, "Sum", {"z", "indices"}, {{"T", DT_FLOAT}}},
      });
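  // Mark the function as non-inlinable: the optimizer must leave functions
  // annotated with the `_noinline` attribute untouched.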
  (*func.mutable_attr())["_noinline"].set_b(true);

  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope, 1.0f);
  auto y = ops::Const(scope, 2.0f);
  auto dl = ops::Const(scope, 3.0f);

  NameAttrList fn;
  fn.set_name("TestFunc");
  (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, y, dl},
                                  {DT_FLOAT, DT_FLOAT}, fn);
  auto out1 = ops::Identity(scope.WithOpName("out1"), g0.output[0]);
  auto out2 = ops::Identity(scope.WithOpName("out2"), g0.output[1]);

  GrapplerItem item;
  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
  *item.graph.mutable_library()->add_function() = func;

  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  // The optimizer should succeed, but the graph should remain unchanged.
  TF_EXPECT_OK(status);
  CompareGraphs(item.graph, output);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionSimpleFunction) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

  // Build a graph to compute c = MyMul(a, b)
  GrapplerItem item;
  item.fetch = {"d"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("c", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, kDevice)},
      {mul_func} /* Function library */);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.feed.emplace_back("a", pi);
  item.feed.emplace_back("b", pi);

  const string input_x = "Func/c/input/_0";
  const string input_y = "Func/c/input/_1";
  const string output_z = "Func/c/output/_2";

  // If device set is empty, inlined function body must not be placed.
  {
    GraphDef optimized_graph;
    TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

    GraphDef expected = test::function::GDef(
        {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

         // Function body nodes copy only job/task/replica parts of device
         // assignment, and function input nodes must copy full device
         // assignment from input arguments. Optimized graph is not fully
         // placed.
         NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
         NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
         // NOTE(ezhulenev): Currently multi-device function inlining placer
         // strategy will override all empty devices with function call device.
         NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
         NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}),

         NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
        // Function library.
        {mul_func});

    CompareGraphs(expected, optimized_graph);

    GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
    auto tensors_expected = EvaluateFetchNodes(item);
    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors_expected.size(), 1);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }

  // If device set is not empty, inlined function body must be placed.
  {
    GraphDef optimized_graph;
    TF_EXPECT_OK(item.AddDevice(kDevice));
    TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

    GraphDef expected = test::function::GDef(
        {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

         NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
         NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
         NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
         NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, kDevice),

         NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
        // Function library.
        {mul_func});

    CompareGraphs(expected, optimized_graph);

    GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
    auto tensors_expected = EvaluateFetchNodes(item);
    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors_expected.size(), 1);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  const Tensor kOne = test::AsScalar<float>(1.0);
  const Tensor kTwo = test::AsScalar<float>(2.0);
  const TensorShape scalar = TensorShape({});

  // Compute `x*y` and add `1.0` to the variable.
  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T", "v: resource"}, {"z:T"}, {"T: {float, double}"},
      {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_FLOAT}}},
       {{"add"},
        "AssignAddVariableOp",
        {"v", "one:output:0"},
        {{"dtype", DT_FLOAT}}},
       {{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}},
      /* Control output to ensure that side effects will be executed. */
      {{"side_effects", "add"}});

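  // When inlining a function with side effects, the inliner adds a NoOp
  // `input_control_node` that anchors the call's incoming control edges, and
  // a NoOp `output_control_node` that depends on the side-effectful nodes,
  // so downstream consumers of the call's control output still execute them
  // (see the expected graph below).
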
  // Build a graph to compute:
  //   a = Placeholder
  //   b = Placeholder
  //   v = VarHandleOp(init = a)
  //   f1 = MyMul(a, b, v)
  //   f2 = MyMul(f1, f1, v)
  //   return [f2, v]
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"out_1", "out_2"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Initialize variable with one of the placeholders.
       NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}}),
       NDef("init_v", "AssignVariableOp", {"v", "a"}, {{"dtype", DT_FLOAT}},
            kDevice),

       // Call function first time.
       NDef("f1", "PartitionedCall", {"a", "b", "v", "^init_v"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_RESOURCE}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Call function second time.
       NDef("f2", "PartitionedCall", {"f1", "f1", "v", "^f1"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_RESOURCE}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Return the result of the multiplication and the current value of
       // the variable.
       NDef("out_1", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out_2", "ReadVariableOp", {"v", "^f1", "^f2"},
            {{"dtype", DT_FLOAT}}, kDevice)},

      // Function library.
      {mul_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Initialize variable with one of the placeholders.
       NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}},
            kDevice),
       NDef("init_v", "AssignVariableOp", {"v", "a"}, {{"dtype", DT_FLOAT}},
            kDevice),

       // Function body of the first function call inlined into the graph.
       NDef("Func/f1/input_control_node/_0", "NoOp", {"^init_v"}, {}, kDevice),

       NDef("Func/f1/input/_1", "Identity",  // input: 'x'
            {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/input/_2", "Identity",  // input: 'y'
            {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/input/_3", "Identity",  // input: 'v'
            {"v", "^Func/f1/input_control_node/_0"}, {{"T", DT_RESOURCE}},
            kDevice),

       NDef("f1/one", "Const", {"^Func/f1/input_control_node/_0"},
            {{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
       NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("f1/add", "AssignAddVariableOp", {"Func/f1/input/_3", "f1/one"},
            {{"dtype", DT_FLOAT}}, kDevice),

       NDef("Func/f1/output/_4", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/output_control_node/_5", "NoOp", {"^f1/add"}, {}, kDevice),

       // Function body of the second function call is also inlined into the
       // graph; its input nodes read from the output nodes of the first
       // function call.
       NDef("Func/f2/input_control_node/_6", "NoOp",
            {"^Func/f1/output_control_node/_5"}, {}, kDevice),

       NDef("Func/f2/input/_7", "Identity",  // input: 'x'
            {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/input/_8", "Identity",  // input: 'y'
            {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/input/_9", "Identity",  // input: 'v'
            {"v", "^Func/f2/input_control_node/_6"}, {{"T", DT_RESOURCE}},
            kDevice),

       NDef("f2/one", "Const", {"^Func/f2/input_control_node/_6"},
            {{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
       NDef("f2/add", "AssignAddVariableOp", {"Func/f2/input/_9", "f2/one"},
            {{"dtype", DT_FLOAT}}, kDevice),
       NDef("f2/mul", "Mul", {"Func/f2/input/_7", "Func/f2/input/_8"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("Func/f2/output/_10", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f2/output_control_node/_11", "NoOp", {"^f2/add"}, {},
            kDevice),

       // Return values read from inlined output nodes.
       NDef("out_1", "Identity", {"Func/f2/output/_10"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("out_2", "ReadVariableOp",
            {"v", "^Func/f1/output_control_node/_5",
             "^Func/f2/output_control_node/_11"},
            {{"dtype", DT_FLOAT}}, kDevice)},

      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);

  item.feed.emplace_back("a", kOne);
  item.feed.emplace_back("b", kTwo);

  auto tensors_expected = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors_expected.size(), 2);
  EXPECT_EQ(tensors_expected[0].flat<float>()(0), 4.0);  // mul
  EXPECT_EQ(tensors_expected[1].flat<float>()(0), 3.0);  // read variable

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithDevicePlacement) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});
  // Add a device placement spec to the function body node.
  (*mul_func.mutable_node_def())[0].set_device("/device:CPU:1");

  // We need fully defined device names to run the placer for the inlined
  // function.
  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";

  // Build a graph to compute c = MyMul(a, b)
  GrapplerItem item;
  item.fetch = {"d"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
       NDef("c", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            cpu0),
       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});
  ASSERT_TRUE(item.InferDevicesFromGraph().ok());

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  const string input_x = "Func/c/input/_0";
  const string input_y = "Func/c/input/_1";
  const string output_z = "Func/c/output/_2";

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),

       // The function must be inlined, the `mul` node placed on the requested
       // device, and the input `Identity` nodes colocated with their source
       // nodes.
       NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
       NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
       NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, cpu1),
       NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),

       NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);
}

TEST_F(FunctionOptimizerTest,
       InlineMultipleIndirectFunctionWithDevicePlacement) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});
  // Add a device placement spec to the function body node.
  (*mul_func.mutable_node_def())[0].set_device("/device:CPU:1");

  // We need fully defined device names to run the placer for the inlined
  // functions.
  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";

  // Build a graph to compute e = MyMul(a, MyMul(a, b))
  GrapplerItem item;
  item.fetch = {"e"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
       NDef("c", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            cpu0),
       NDef("d", "PartitionedCall", {"a", "c"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            cpu0),
       NDef("e", "Identity", {"d"}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});
  ASSERT_TRUE(item.InferDevicesFromGraph().ok());

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  const string input_c_x = "Func/c/input/_0";
  const string input_c_y = "Func/c/input/_1";
  const string output_c_z = "Func/c/output/_2";
  const string input_d_x = "Func/d/input/_3";
  const string input_d_y = "Func/d/input/_4";
  const string output_d_z = "Func/d/output/_5";

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),

       // The function must be inlined, the `mul` node placed on the requested
       // device, and the input/output `Identity` nodes colocated with their
       // source nodes.
       NDef(input_c_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
       NDef(input_c_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
       NDef("c/mul", "Mul", {input_c_x, input_c_y}, {{"T", DT_FLOAT}}, cpu1),
       NDef(output_c_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),

       // The same must hold for the second inlined function call.
       NDef(input_d_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
       NDef(input_d_y, "Identity", {output_c_z}, {{"T", DT_FLOAT}}, cpu1),
       NDef("d/mul", "Mul", {input_d_x, input_d_y}, {{"T", DT_FLOAT}}, cpu1),
       NDef(output_d_z, "Identity", {"d/mul"}, {{"T", DT_FLOAT}}, cpu1),

       NDef("e", "Identity", {output_d_z}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);
}

TEST_F(FunctionOptimizerTest,
       InlineIndirectFunctionWithControlDependencyAndNoSideEffects) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  const Tensor kOne = test::AsScalar<float>(1.0);
  const Tensor kTwo = test::AsScalar<float>(2.0);
  const TensorShape scalar = TensorShape({});

  // MyMul doesn't have any side-effectful nodes in the function body, but the
  // graph being optimized has a control dependency edge `f1->f2`.
  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

  // Build a graph to compute:
  //   a = Placeholder
  //   b = Placeholder
  //   f1 = MyMul(a, b)
  //   f2 = MyMul(f1, f1, ^f1)  <-- control dependency on inlined function!
  //   return f2
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"out"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       NDef("c", "NoOp", {}, {}, kDevice),

       // Call function first time.
       NDef("f1", "PartitionedCall", {"a", "b", "^c"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Call function second time.
       NDef("f2", "PartitionedCall", {"f1", "f1", "^f1"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Return result of f2.
       NDef("out", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice)},

      // Function library.
      {mul_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       NDef("c", "NoOp", {}, {}, kDevice),

       // Function body of the first function call inlined into the graph.
       NDef("Func/f1/input_control_node/_0", "NoOp", {"^c"}, {}, kDevice),

       NDef("Func/f1/input/_1", "Identity",  // input: 'x'
            {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/input/_2", "Identity",  // input: 'y'
            {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),

       NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("Func/f1/output/_3", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
            kDevice),
       // A control input from the `input_control_node` is added to ensure
       // correct frame execution.
       NDef("Func/f1/output_control_node/_4", "NoOp",
            {"^Func/f1/input_control_node/_0"}, {}, kDevice),

       // Function body of the second function call is also inlined into the
       // graph; its input nodes read directly from the output nodes of the
       // first function call, and the control dependency edge is removed.
       NDef("Func/f2/input_control_node/_5", "NoOp",
            {"^Func/f1/output_control_node/_4"}, {}, kDevice),

       NDef("Func/f2/input/_6", "Identity",
            {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/input/_7", "Identity",
            {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("f2/mul", "Mul", {"Func/f2/input/_6", "Func/f2/input/_7"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/output/_8", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
            kDevice),

       // Return directly from output node of f2.
       NDef("out", "Identity", {"Func/f2/output/_8"}, {{"T", DT_FLOAT}},
            kDevice)},

      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);

  item.feed.emplace_back("a", kOne);
  item.feed.emplace_back("b", kTwo);

  auto tensors_expected = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors_expected.size(), 1);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionDoNotInlineDeadOutputs) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  // Function output can be dead.
  FunctionDef dead_outputs = FunctionDefHelper::Create(
      "DeadOutputs", {"x:T", "cond:bool"}, {"z:T"}, {"T: {float, double}"},
      {
          {{"switch"}, "Switch", {"x", "cond"}, {{"T", "$T"}}},
          {{"if_false"}, "Identity", {"switch:output_false:0"}, {{"T", "$T"}}},
          {{"if_true"}, "Identity", {"switch:output_true:0"}, {{"T", "$T"}}},
      },
      /* Mapping between function returns and function node outputs. */
      {{"z", "if_false:output:0"}});

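  // `Switch` forwards its input to only one of its two outputs; the untaken
  // output is produced as a dead tensor. `DeadOutputs` returns
  // `switch:output_false`, which is dead whenever `cond` is true, and
  // inlining a function with potentially dead outputs into the main graph is
  // unsafe, so the optimizer must keep both calls below intact.
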
  // A simple proxy function that calls DeadOutputs from its function body.
  FunctionDef proxy_func = FunctionDefHelper::Create(
      "Proxy", {"x:T", "cond:bool"}, {"z:T"}, {"T: {float, double}"},
      {{{"dead"}, "DeadOutputs", {"x", "cond"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "dead:z:0"}});

  // Build a graph to compute:
  //   a: float
  //   b: bool
  //   fn0 = DeadOutputs(a, b)
  //   fn1 = Proxy(a, b)
  //   out0 = Identity(fn0)
  //   out1 = Identity(fn1)
  //   return [out0, out1]
  //
  GrapplerItem item;
  item.fetch = {"out0", "out1"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),

       NDef("fn0", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_BOOL}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("DeadOutputs", {{"T", DT_FLOAT}})}},
            kDevice),

       NDef("fn1", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_BOOL}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("Proxy", {{"T", DT_FLOAT}})}},
            kDevice),

       NDef("out0", "Identity", {"fn0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out1", "Identity", {"fn1"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {dead_outputs, proxy_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = item.graph;
  CompareGraphs(expected, optimized_graph);

  const Tensor one = test::AsScalar<float>(1.0);
  item.feed.emplace_back("a", one);
  item.feed.emplace_back("b", test::AsScalar<bool>(false));

  auto tensors = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorEqual<float>(tensors[0], one);
  test::ExpectTensorEqual<float>(tensors[1], one);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithMergedDeadTensors) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  // Function output can't be dead because it goes through the Merge node.
  FunctionDef no_dead_outputs = FunctionDefHelper::Create(
      "NoDeadOutputs", {"x:T", "cond:bool"}, {"z:T"}, {"T: {float, double}"},
      {
          {{"switch"}, "Switch", {"x", "cond"}, {{"T", "$T"}}},
          {{"if_false"}, "Identity", {"switch:output_false:0"}, {{"T", "$T"}}},
          {{"if_true"}, "Identity", {"switch:output_true:0"}, {{"T", "$T"}}},
          {{"merge"},
           "Merge",
           {"if_false:output:0", "if_true:output:0"},
           {{"T", "$T"}, {"N", 2}}},
      },
      /* Mapping between function returns and function node outputs. */
      {{"z", "merge:output:0"}});

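  // Because `Merge` forwards whichever of its inputs is alive, the function
  // output can never be a dead tensor, and it is safe to inline this
  // function into the main graph (in contrast to the previous test).
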
  // Build a graph to compute:
  //   a: float
  //   b: bool
  //   fn = NoDeadOutputs(a, b)
  //   out = Identity(fn)
  //   return out
  //
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"out"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),

       NDef("fn", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_BOOL}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("NoDeadOutputs", {{"T", DT_FLOAT}})}},
            kDevice),

       NDef("out", "Identity", {"fn"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {no_dead_outputs});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),

       // Function body of the function call inlined into the graph.
       NDef("Func/fn/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/fn/input/_1", "Identity", {"b"}, {{"T", DT_BOOL}}, kDevice),

       NDef("fn/switch", "Switch", {"Func/fn/input/_0", "Func/fn/input/_1"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("fn/if_false", "Identity", {"fn/switch"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("fn/if_true", "Identity", {"fn/switch:1"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("fn/merge", "Merge", {"fn/if_false", "fn/if_true"},
            {{"T", DT_FLOAT}, {"N", 2}}, kDevice),

       NDef("Func/fn/output/_2", "Identity", {"fn/merge"}, {{"T", DT_FLOAT}},
            kDevice),

       // Return directly from the inlined function output node.
       NDef("out", "Identity", {"Func/fn/output/_2"}, {{"T", DT_FLOAT}},
            kDevice)},

      // Function library.
      {no_dead_outputs});

  CompareGraphs(expected, optimized_graph);

  const Tensor one = test::AsScalar<float>(1.0);
  item.feed.emplace_back("a", one);
  item.feed.emplace_back("b", test::AsScalar<bool>(false));

  auto tensors_expected = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors_expected.size(), 1);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithNestedFunctionCall) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

  // `Square` implemented in terms of PartitionedCall to `MyMul`.
  FunctionDef square_func = FunctionDefHelper::Create(
      "MySquare", {"x:T"}, {"output:T"}, {"T: {float, double}"},
      {{{"square"},
        "PartitionedCall",
        {"x", "x"},
        {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
         {"Tout", DataTypeSlice{DT_FLOAT}},
         {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}}}},
      /* Mapping between function returns and function node outputs. */
      {{"output", "square:output:0"}});

  // Build a graph to compute:
  //   b = Square(a)
  //   c = Identity(b)
  //   return c
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"c"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "PartitionedCall", {"a"},
            {{"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MySquare", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("c", "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice)},
      /* Function library */
      {mul_func, square_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Inlined inputs of `b` node.
       NDef("Func/b/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),

       // Inlined inputs of `square` node inside inlined `MySquare` function.
       NDef("Func/b/square/input/_2", "Identity", {"Func/b/input/_0"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/b/square/input/_3", "Identity", {"Func/b/input/_0"},
            {{"T", DT_FLOAT}}, kDevice),

       // Inlined mul node from the `MyMul` function.
       NDef("b/square/mul", "Mul",
            {"Func/b/square/input/_2", "Func/b/square/input/_3"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("Func/b/square/output/_4", "Identity", {"b/square/mul"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/b/output/_1", "Identity", {"Func/b/square/output/_4"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("c", "Identity", {"Func/b/output/_1"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);

  Tensor three = test::AsScalar<float>(3.0f);
  item.feed.emplace_back("a", three);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors_expected = EvaluateFetchNodes(item);
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors_expected.size(), 1);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

GrapplerItem ConditionalAdd() {
  // Returns the conditional (is_add) ? a + b : a * b
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionDef add_func = FDH::Create(
      "MyAdd", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"add"}, "Add", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "add:z:0"}});

  FunctionDef mul_func = FDH::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

  // Compute: return cond ? a + b : a * b
  FunctionDef add_or_mul_func = FDH::Create(
      "AddOrMul", {"cond:bool", "x:float", "y:float"}, {"z:float"}, {},
      {
          {{"if_node"},
           "If",
           {"cond", "x", "y"},
           {
               {"Tcond", DT_BOOL},
               {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
               {"Tout", DataTypeSlice{DT_FLOAT}},
               {"then_branch", FDH::FunctionRef("MyAdd", {{"T", DT_FLOAT}})},
               {"else_branch", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})},
               {"_lower_using_switch_merge", true},
           }},
      },
      /* Mapping between function returns and function node outputs. */
      {{"z", "if_node:output:0"}}, {{"side_effect", "if_node"}});

  // Build a computation graph for:
  //   is_add: bool
  //   a: float
  //   b: float
  //   c = AddOrMul(is_add, a, b)  # is_add ? a + b : a * b
  //   d = Identity(c)
  //   return d
  GrapplerItem item;
  item.fetch = {"d"};
  item.graph = test::function::GDef(
      {NDef("is_add", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),
       NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       NDef("c", "PartitionedCall", {"is_add", "a", "b"},
            {{"Tin", DataTypeSlice{DT_BOOL, DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("AddOrMul")}},
            kDevice),

       NDef("d", "Identity", {"c", "^c"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {add_or_mul_func, add_func, mul_func});
  return item;
}

1303 TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithFunctionalControlFlow) {
1304   FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);
1305 
1306   // item.fetch['d'] == (is_add) ? a + b : a * b
1307   GrapplerItem item = ConditionalAdd();
1308   GraphDef optimized_graph;
1309   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
1310 
1311   const auto count_nodes_with_op = [&](const string& op) {
1312     return absl::c_count_if(optimized_graph.node(), [&](const NodeDef& node) {
1313       return node.op() == op;
1314     });
1315   };
1316 
1317   // All `PartitionedCall` nodes in the optimized graph must be inlined, and
1318   // the `If` node must be lowered to `Switch` and `Merge` nodes.
1319   EXPECT_EQ(count_nodes_with_op("PartitionedCall"), 0);
1320   EXPECT_EQ(count_nodes_with_op("If"), 0);
1321   EXPECT_EQ(count_nodes_with_op("Switch"), 3);
1322   EXPECT_EQ(count_nodes_with_op("Merge"), 2);
1323 
1324   GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
1325 
1326   Tensor one = test::AsScalar<float>(1.0);
1327   Tensor two = test::AsScalar<float>(2.0);
1328   Tensor three = test::AsScalar<float>(3.0);
1329 
1330   const auto feed_args = [&](bool is_add) {
1331     std::vector<std::pair<string, Tensor>> feed;
1332     feed.emplace_back("a", one);
1333     feed.emplace_back("b", two);
1334     feed.emplace_back("is_add", test::AsScalar<bool>(is_add));
1335     return feed;
1336   };
1337 
1338   {  // Check 'is_add == true': a + b
1339     item.feed = feed_args(true);
1340     optimized.feed = feed_args(true);
1341 
1342     auto tensors_expected = EvaluateFetchNodes(item);
1343     ASSERT_EQ(tensors_expected.size(), 1);
1344     test::ExpectTensorEqual<float>(tensors_expected[0], three);
1345 
1346     auto tensors = EvaluateFetchNodes(optimized);
1347     ASSERT_EQ(tensors.size(), tensors_expected.size());
1348     test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1349   }
1350 
1351   {  // Check 'is_add == false': a * b
1352     item.feed = feed_args(false);
1353     optimized.feed = feed_args(false);
1354 
1355     auto tensors_expected = EvaluateFetchNodes(item);
1356     ASSERT_EQ(tensors_expected.size(), 1);
1357     test::ExpectTensorEqual<float>(tensors_expected[0], two);
1358 
1359     auto tensors = EvaluateFetchNodes(optimized);
1360     ASSERT_EQ(tensors.size(), tensors_expected.size());
1361     test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1362   }
1363 }
1364 
1365 TEST_F(FunctionOptimizerTest, InlineIndirectFunctionDontLowerControlFlow) {
1366   FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE,
1367                               /*lower_control_flow=*/false);
1368 
1369   // item.fetch['d'] == (is_add) ? a + b : a * b
1370   GrapplerItem item = ConditionalAdd();
1371   GraphDef optimized_graph;
1372   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
1373 
1374   const auto count_nodes_with_op = [&](const string& op) {
1375     return absl::c_count_if(optimized_graph.node(), [&](const NodeDef& node) {
1376       return node.op() == op;
1377     });
1378   };
1379 
1380   // All `PartitionedCall` nodes in the optimized graph must be inlined, but
1381   // the `If` node must not be lowered to `Switch` and `Merge` nodes.
1382   EXPECT_EQ(count_nodes_with_op("PartitionedCall"), 0);
1383   EXPECT_EQ(count_nodes_with_op("If"), 1);
1384   EXPECT_EQ(count_nodes_with_op("Switch"), 0);
1385   EXPECT_EQ(count_nodes_with_op("Merge"), 0);
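  // The functional `If` node survives inlining and still executes through its
  // branch functions, which must remain in the function library.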
1386 
1387   GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
1388 
1389   Tensor one = test::AsScalar<float>(1.0);
1390   Tensor two = test::AsScalar<float>(2.0);
1391   Tensor three = test::AsScalar<float>(3.0);
1392 
1393   const auto feed_args = [&](bool is_add) {
1394     std::vector<std::pair<string, Tensor>> feed;
1395     feed.emplace_back("a", one);
1396     feed.emplace_back("b", two);
1397     feed.emplace_back("is_add", test::AsScalar<bool>(is_add));
1398     return feed;
1399   };
1400 
1401   {  // Check 'is_add == true': a + b
1402     item.feed = feed_args(true);
1403     optimized.feed = feed_args(true);
1404 
1405     auto tensors_expected = EvaluateFetchNodes(item);
1406     ASSERT_EQ(tensors_expected.size(), 1);
1407     test::ExpectTensorEqual<float>(tensors_expected[0], three);
1408 
1409     auto tensors = EvaluateFetchNodes(optimized);
1410     ASSERT_EQ(tensors.size(), tensors_expected.size());
1411     test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1412   }
1413 
1414   {  // Check 'is_add == false': a * b
1415     item.feed = feed_args(false);
1416     optimized.feed = feed_args(false);
1417 
1418     auto tensors_expected = EvaluateFetchNodes(item);
1419     ASSERT_EQ(tensors_expected.size(), 1);
1420     test::ExpectTensorEqual<float>(tensors_expected[0], two);
1421 
1422     auto tensors = EvaluateFetchNodes(optimized);
1423     ASSERT_EQ(tensors.size(), tensors_expected.size());
1424     test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1425   }
1426 }
1427 
1428 TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
1429   using test::function::NDef;
1430 
1431   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1432 
1433   // Mark XTimesTwo as noinline.
1434   FunctionDef x_times_two = test::function::XTimesTwo();
1435   (*x_times_two.mutable_attr())["_noinline"].set_b(true);
1436   std::vector<FunctionDef> function_library = {x_times_two};
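  // With inlining disabled, the optimizer is expected to fall back to
  // specializing the function for this particular call site.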
1437 
1438   // Build a graph to compute y = XTimesTwo(x).
1439   GrapplerItem item;
1440   item.id = "tf_graph";
1441   item.graph = test::function::GDef(
1442       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1443        NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, kDevice),
1444        NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
1445       function_library);
1446 
1447   GraphDef output;
1448   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1449 
1450   // Make sure that the specialized function was added to the library and the
1451   // original function was removed.
1452   EXPECT_EQ(1, output.library().function_size());
1453   EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph",
1454             output.library().function(0).signature().name());
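  // Specialized function names follow the pattern
  // `<function>_specialized_for_<node>_at_<item.id>`.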
1455 
1456   // And the 'y' node now calls the specialized function.
1457   int count = 0;
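  // `++count` in the conditions below is always nonzero; the idiom just
  // tallies how many matching nodes were visited.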
1458   for (const NodeDef& node : output.node()) {
1459     if (node.name() == "y" && ++count) {
1460       EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph", node.op());
1461     }
1462   }
1463   EXPECT_EQ(1, count);
1464 
1465   // And that graph evaluation yields the same result.
1466   Tensor pi = test::AsScalar<float>(3.14f);
1467   item.fetch = {"z"};
1468   item.feed.emplace_back("x", pi);
1469 
1470   auto tensors_expected = EvaluateFetchNodes(item);
1471   GrapplerItem optimized = item.WithGraph(std::move(output));
1472   auto tensors = EvaluateFetchNodes(optimized);
1473   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1474 }
1475 
1476 TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionXTimesTwo) {
1477   using test::function::NDef;
1478   using FDH = FunctionDefHelper;
1479 
1480   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1481 
1482   // Mark XTimesTwo as noinline.
1483   FunctionDef x_times_two = test::function::XTimesTwo();
1484   (*x_times_two.mutable_attr())["_noinline"].set_b(true);
1485   std::vector<FunctionDef> function_library = {x_times_two};
1486 
1487   // TensorFlow graph:
1488   //   y = PartitionedCall[f=XTimesTwo, Tin=[DT_FLOAT], Tout=[DT_FLOAT]](x)
1489   GrapplerItem item;
1490   item.id = "tf_graph";
1491   item.graph = test::function::GDef(
1492       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1493        NDef("y", "PartitionedCall", {"x"},
1494             {{"Tin", DataTypeSlice{DT_FLOAT}},
1495              {"Tout", DataTypeSlice{DT_FLOAT}},
1496              {"f", FDH::FunctionRef("XTimesTwo", {{"T", DT_FLOAT}})}},
1497             kDevice),
1498        NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
1499       function_library);
1500 
1501   GraphDef output;
1502   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1503 
1504   // Make sure that the specialized function was added to the library and the
1505   // original function was removed.
1506   EXPECT_EQ(1, output.library().function_size());
1507   EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph",
1508             output.library().function(0).signature().name());
1509 
1510   // And the 'y' node now calls the specialized function.
1511   int count = 0;
1512   for (const NodeDef& node : output.node()) {
1513     if (node.name() == "y" && ++count) {
1514       EXPECT_EQ("PartitionedCall", node.op());
1515       auto& func = AttrSlice(node).Find("f")->func();
1516       // Function calls into the specialized function.
1517       EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph", func.name());
1518       // And input/output types stay the same.
1519       auto& tin = AttrSlice(node).Find("Tin")->list();
1520       auto& tout = AttrSlice(node).Find("Tout")->list();
1521       ASSERT_EQ(1, tin.type_size());
1522       ASSERT_EQ(1, tout.type_size());
1523       EXPECT_EQ(DT_FLOAT, tin.type(0));
1524       EXPECT_EQ(DT_FLOAT, tout.type(0));
1525     }
1526   }
1527   EXPECT_EQ(1, count);
1528 
1529   // And that graph evaluation yields the same result.
1530   Tensor pi = test::AsScalar<float>(3.14f);
1531   item.fetch = {"z"};
1532   item.feed.emplace_back("x", pi);
1533 
1534   auto tensors_expected = EvaluateFetchNodes(item);
1535   GrapplerItem optimized = item.WithGraph(std::move(output));
1536   auto tensors = EvaluateFetchNodes(optimized);
1537   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1538 }
1539 
1540 TEST_F(FunctionOptimizerTest, SpecializeFunctionPushDownConstInput) {
1541   using test::function::NDef;
1542 
1543   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1544 
1545   FunctionDef mul_func = FunctionDefHelper::Create(
1546       "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
1547       {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
1548       /* Mapping between function returns and function node outputs. */
1549       {{"z", "output:z:0"}});
1550 
1551   // Mark MyMul as noinline.
1552   (*mul_func.mutable_attr())["_noinline"].set_b(true);
1553   std::vector<FunctionDef> function_library = {mul_func};
1554 
1555   // Build a graph to compute y = MyMul(x, 2.0).
1556   const Tensor kTwo = test::AsScalar<float>(2.0);
1557 
1558   GrapplerItem item;
1559   item.id = "tf_graph";
1560   item.graph = test::function::GDef(
1561       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1562        NDef("init", "NoOp", {}, {}, kDevice),
1563        NDef("two", "Const", {"^init", "^x"},
1564             {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
1565        NDef("y", "MyMul", {"x", "two"}, {{"T", DT_FLOAT}}, kDevice),
1566        NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
1567       function_library);
1568 
1569   GraphDef output;
1570   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1571 
1572   // Make sure that the specialized function was added to the library and the
1573   // original function was removed.
1574   ASSERT_EQ(1, output.library().function_size());
1575 
1576   const FunctionDef& specialized = output.library().function(0);
1577   EXPECT_EQ("MyMul_specialized_for_y_at_tf_graph",
1578             specialized.signature().name());
1579   EXPECT_EQ(1, specialized.signature().input_arg_size());
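  // The const 'two' was pushed down into the function body, leaving 'x' as the
  // only input; the const's '^init' control dependency is hoisted onto the
  // caller node, as verified below.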
1580 
1581   // And the 'y' node inherits control dependencies of the pushed-down const.
1582   int count = 0;
1583   for (const NodeDef& node : output.node()) {
1584     if (node.name() == "y" && ++count) {
1585       ASSERT_EQ(2, node.input_size());
1586       EXPECT_EQ("x", node.input(0));
1587       EXPECT_EQ("^init", node.input(1));
1588     }
1589   }
1590   EXPECT_EQ(1, count);
1591 
1592   // And that graph evaluation yields the same result.
1593   Tensor pi = test::AsScalar<float>(3.14f);
1594   item.fetch = {"z"};
1595   item.feed.emplace_back("x", pi);
1596 
1597   auto tensors_expected = EvaluateFetchNodes(item);
1598   GrapplerItem optimized = item.WithGraph(std::move(output));
1599   auto tensors = EvaluateFetchNodes(optimized);
1600   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1601 }
1602 
1603 TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionPushDownConstInput) {
1604   using test::function::NDef;
1605   using FDH = FunctionDefHelper;
1606 
1607   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1608 
1609   FunctionDef mul_func = FunctionDefHelper::Create(
1610       "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
1611       {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
1612       /* Mapping between function returns and function node outputs. */
1613       {{"z", "output:z:0"}});
1614 
1615   // Mark MyMul as noinline.
1616   (*mul_func.mutable_attr())["_noinline"].set_b(true);
1617   std::vector<FunctionDef> function_library = {mul_func};
1618 
1619   const Tensor kTwo = test::AsScalar<float>(2.0);
1620 
1621   // TensorFlow graph:
1622   //   y = PartitionedCall[Tin=[DT_FLOAT], Tout=[DT_FLOAT], f=MyMul](x, two)
1623   GrapplerItem item;
1624   item.id = "tf_graph";
1625   item.graph = test::function::GDef(
1626       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1627        NDef("init", "NoOp", {}, {}, kDevice),
1628        NDef("two", "Const", {"^init", "^x"},
1629             {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
1630        NDef("y", "PartitionedCall", {"x", "two"},
1631             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
1632              {"Tout", DataTypeSlice{DT_FLOAT}},
1633              {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
1634             kDevice),
1635        NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
1636       function_library);
1637 
1638   GraphDef output;
1639   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1640 
1641   // Make sure that the specialized function was added to the library and the
1642   // original function was removed.
1643   ASSERT_EQ(1, output.library().function_size());
1644 
1645   const FunctionDef& specialized = output.library().function(0);
1646   EXPECT_EQ("MyMul_specialized_for_y_at_tf_graph",
1647             specialized.signature().name());
1648   EXPECT_EQ(1, specialized.signature().input_arg_size());
1649 
1650   // And the 'y' node inherits control dependencies of the pushed-down const.
1651   int count = 0;
1652   for (const NodeDef& node : output.node()) {
1653     if (node.name() == "y" && ++count) {
1654       EXPECT_EQ("PartitionedCall", node.op());
1655       ASSERT_EQ(2, node.input_size());
1656       EXPECT_EQ("x", node.input(0));
1657       EXPECT_EQ("^init", node.input(1));
1658       // Function calls into the specialized function.
1659       auto& func = AttrSlice(node).Find("f")->func();
1660       EXPECT_EQ("MyMul_specialized_for_y_at_tf_graph", func.name());
1661       // And input/output type lists were updated.
1662       auto& tin = AttrSlice(node).Find("Tin")->list();
1663       auto& tout = AttrSlice(node).Find("Tout")->list();
1664       ASSERT_EQ(1, tin.type_size());
1665       ASSERT_EQ(1, tout.type_size());
1666       EXPECT_EQ(DT_FLOAT, tin.type(0));
1667       EXPECT_EQ(DT_FLOAT, tout.type(0));
1668     }
1669   }
1670   ASSERT_EQ(1, count);
1671 
1672   // And that graph evaluation yields the same result.
1673   Tensor pi = test::AsScalar<float>(3.14f);
1674   item.fetch = {"z"};
1675   item.feed.emplace_back("x", pi);
1676 
1677   auto tensors_expected = EvaluateFetchNodes(item);
1678   GrapplerItem optimized = item.WithGraph(std::move(output));
1679   auto tensors = EvaluateFetchNodes(optimized);
1680   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1681 }
1682 
1683 TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) {
1684   using test::function::NDef;
1685 
1686   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1687 
1688   // Mark MyMul as noinline.
1689   FunctionDef mul_func = FunctionDefHelper::Create(
1690       "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, int32}"},
1691       {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
1692       /* Mapping between function returns and function node outputs. */
1693       {{"z", "output:z:0"}});
1694   (*mul_func.mutable_attr())["_noinline"].set_b(true);
1695   std::vector<FunctionDef> function_library = {mul_func};
1696 
1697   const Tensor kTwo = test::AsScalar<float>(2.0);
1698   const Tensor kThree = test::AsScalar<float>(3.0);
1699 
1700   GrapplerItem item;
1701   item.id = "tf_graph";
1702   item.graph = test::function::GDef(
1703       {NDef("init", "NoOp", {}, {}, kDevice),
1704 
1705        // Float placeholders.
1706        NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1707        NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1708 
1709        // Int32 placeholders.
1710        NDef("xi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
1711        NDef("yi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
1712 
1713        // Consts. Control inputs have to be attached to specialized function
            // calls.
1714        NDef("two", "Const", {"^init", "^xf"},
1715             {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
1716        NDef("three", "Const", {"^init", "^xf"},
1717             {{"dtype", DT_FLOAT}, {"value", kThree}}, kDevice),
1718 
1719        // Specialization #1: DT_FLOAT type parameter.
1720        NDef("mul_1", "MyMul", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
1721        NDef("mul_2", "MyMul", {"yf", "xf"}, {{"T", DT_FLOAT}}, kDevice),
1722 
1723        // Specialization #2: DT_INT32 type parameter.
1724        NDef("mul_3", "MyMul", {"xi", "yi"}, {{"T", DT_INT32}}, kDevice),
1725 
1726        // Specialization #3: DT_FLOAT type parameter + const input kTwo.
1727        NDef("mul_4", "MyMul", {"xf", "two"}, {{"T", DT_FLOAT}}, kDevice),
1728        NDef("mul_5", "MyMul", {"yf", "two"}, {{"T", DT_FLOAT}}, kDevice),
1729 
1730        // Specialization #4: DT_FLOAT type parameter + const input kThree.
1731        NDef("mul_6", "MyMul", {"three", "xf"}, {{"T", DT_FLOAT}}, kDevice)},
1732       function_library);
1733 
1734   // Specify fetch nodes before optimization to prevent pruning unused function
1735   // outputs.
1736   item.fetch = {"mul_1", "mul_2", "mul_3", "mul_4", "mul_5", "mul_6"};
1737 
1738   GraphDef output;
1739   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1740 
1741   // Make sure that MyMul was specialized once per unique context.
1742   EXPECT_EQ(4, output.library().function_size());
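  // The four unique contexts: (T=DT_FLOAT), (T=DT_INT32),
  // (T=DT_FLOAT, const input `two`), and (T=DT_FLOAT, const input `three`).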
1743 
1744   // And graph nodes must call the specialized functions.
1745   int count = 0;
1746   for (const NodeDef& node : output.node()) {
1747     if (node.name() == "mul_1" && ++count) {
1748       EXPECT_EQ("MyMul_specialized_for_mul_1_at_tf_graph", node.op());
1749       ASSERT_EQ(2, node.input_size());
1750       EXPECT_EQ("xf", node.input(0));
1751       EXPECT_EQ("yf", node.input(1));
1752 
1753     } else if (node.name() == "mul_2" && ++count) {
1754       EXPECT_EQ("MyMul_specialized_for_mul_1_at_tf_graph", node.op());
1755       ASSERT_EQ(2, node.input_size());
1756       EXPECT_EQ("yf", node.input(0));
1757       EXPECT_EQ("xf", node.input(1));
1758 
1759     } else if (node.name() == "mul_3" && ++count) {
1760       EXPECT_EQ("MyMul_specialized_for_mul_3_at_tf_graph", node.op());
1761       ASSERT_EQ(2, node.input_size());
1762       EXPECT_EQ("xi", node.input(0));
1763       EXPECT_EQ("yi", node.input(1));
1764 
1765     } else if (node.name() == "mul_4" && ++count) {
1766       EXPECT_EQ("MyMul_specialized_for_mul_4_at_tf_graph", node.op());
1767       ASSERT_EQ(2, node.input_size());
1768       EXPECT_EQ("xf", node.input(0));
1769       EXPECT_EQ("^init", node.input(1));
1770 
1771     } else if (node.name() == "mul_5" && ++count) {
1772       EXPECT_EQ("MyMul_specialized_for_mul_4_at_tf_graph", node.op());
1773       ASSERT_EQ(3, node.input_size());
1774       EXPECT_EQ("yf", node.input(0));
1775       gtl::FlatSet<string> expected_ctrl = {"^init", "^xf"};
1776       gtl::FlatSet<string> actual_ctrl = {node.input(1), node.input(2)};
1777       EXPECT_EQ(expected_ctrl, actual_ctrl);
1778 
1779     } else if (node.name() == "mul_6" && ++count) {
1780       EXPECT_EQ("MyMul_specialized_for_mul_6_at_tf_graph", node.op());
1781       ASSERT_EQ(2, node.input_size());
1782       EXPECT_EQ("xf", node.input(0));
1783       EXPECT_EQ("^init", node.input(1));
1784     }
1785   }
1786   EXPECT_EQ(6, count);
1787 
1788   // And that graph evaluation yields the same result.
1789   Tensor pi = test::AsScalar<float>(3.14f);
1790   Tensor four = test::AsScalar<int32>(4);
1791   item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}};
1792 
1793   auto tensors_expected = EvaluateFetchNodes(item);
1794   GrapplerItem optimized = item.WithGraph(std::move(output));
1795   auto tensors = EvaluateFetchNodes(optimized);
1796 
1797   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
1798   test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
1799   test::ExpectTensorEqual<int32>(tensors_expected[2], tensors[2]);
1800   test::ExpectTensorEqual<float>(tensors_expected[3], tensors[3]);
1801   test::ExpectTensorEqual<float>(tensors_expected[4], tensors[4]);
1802   test::ExpectTensorEqual<float>(tensors_expected[5], tensors[5]);
1803 }
1804 
1805 TEST_F(FunctionOptimizerTest, SpecializeFunctionForUsedOutputTensors) {
1806   using test::function::NDef;
1807 
1808   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1809 
1810   // MyFunc computes x*y three times and has three output values.
1811   FunctionDef my_func = FunctionDefHelper::Create(
1812       "MyFunc", {"x:T", "y:T"}, {"z1:T", "z2:T", "z3:T"}, {"T: {float, int32}"},
1813       {{{"output1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
1814        {{"output2"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
1815        {{"output3"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
1816       /* Mapping between function returns and function node outputs. */
1817       {{"z1", "output1:z:0"}, {"z2", "output2:z:0"}, {"z3", "output3:z:0"}});
1818   (*my_func.mutable_attr())["_noinline"].set_b(true);
1819   std::vector<FunctionDef> function_library = {my_func};
1820 
1821   GrapplerItem item;
1822   item.id = "tf_graph";
1823   item.graph = test::function::GDef(
1824       {NDef("init", "NoOp", {}, {}, kDevice),
1825 
1826        // Float placeholders.
1827        NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1828        NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1829 
1830        // Specialization #1: DT_FLOAT type parameter. All outputs used.
1831        NDef("fn1", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
1832        NDef("use_fn1_0", "Identity", {"fn1:0"}, {{"T", DT_FLOAT}}, kDevice),
1833        NDef("use_fn1_1", "Identity", {"fn1:1"}, {{"T", DT_FLOAT}}, kDevice),
1834        NDef("use_fn1_2", "Identity", {"fn1:2"}, {{"T", DT_FLOAT}}, kDevice),
1835 
1836        // Specialization #2: DT_FLOAT type parameter. Only first output used.
1837        NDef("fn2", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
1838        NDef("use_fn2_0", "Identity", {"fn2:0"}, {{"T", DT_FLOAT}}, kDevice),
1839 
1840        // Specialization #3: DT_FLOAT type parameter. Only second output used.
1841        NDef("fn3", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
1842        NDef("use_fn3_1", "Identity", {"fn3:1"}, {{"T", DT_FLOAT}}, kDevice),
1843 
1844        // Specialization #4: DT_FLOAT type parameter. Only last output used.
1845        NDef("fn4", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
1846        NDef("use_fn4_2", "Identity", {"fn4:2"}, {{"T", DT_FLOAT}}, kDevice),
1847 
1848        // Specialization #5: DT_FLOAT type parameter. First and last outputs.
1849        NDef("fn5", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
1850        NDef("use_fn5_0", "Identity", {"fn5:0"}, {{"T", DT_FLOAT}}, kDevice),
1851        NDef("use_fn5_2", "Identity", {"fn5:2"}, {{"T", DT_FLOAT}}, kDevice),
1852 
1853        // Specialization #6: DT_FLOAT type parameter. Outputs not used.
1854        // Check that the function optimizer does not fail. In practice such a
1855        // node should be pruned from the graph before the optimizer runs.
1856        NDef("fn6", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice)},
1857       function_library);
1858 
1859   GraphDef output;
1860   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1861 
1862   // Make sure that MyFunc was specialized once per unique context.
1863   EXPECT_EQ(6, output.library().function_size());
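  // All six call sites use T=DT_FLOAT, but each has a distinct set of used
  // outputs, so each one gets its own specialization.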
1864 
1865   // And graph nodes must call the specialized functions.
1866   int found = 0;
1867   for (const NodeDef& node : output.node()) {
1868     // All function caller nodes must be specialized.
1869     if (node.name() == "fn1" && ++found) {
1870       EXPECT_EQ("MyFunc_specialized_for_fn1_at_tf_graph", node.op());
1871     } else if (node.name() == "fn2" && ++found) {
1872       EXPECT_EQ("MyFunc_specialized_for_fn2_at_tf_graph", node.op());
1873     } else if (node.name() == "fn3" && ++found) {
1874       EXPECT_EQ("MyFunc_specialized_for_fn3_at_tf_graph", node.op());
1875     } else if (node.name() == "fn4" && ++found) {
1876       EXPECT_EQ("MyFunc_specialized_for_fn4_at_tf_graph", node.op());
1877     } else if (node.name() == "fn5" && ++found) {
1878       EXPECT_EQ("MyFunc_specialized_for_fn5_at_tf_graph", node.op());
1879     } else if (node.name() == "fn6" && ++found) {
1880       EXPECT_EQ("MyFunc_specialized_for_fn6_at_tf_graph", node.op());
1881     }
1882     // And all consumers of specialized function nodes must be mapped to new
1883     // output ports.
1884     if (node.name() == "use_fn3_1" && ++found) {
1885       EXPECT_EQ("fn3", node.input(0));
1886     } else if (node.name() == "use_fn4_2" && ++found) {
1887       EXPECT_EQ("fn4", node.input(0));
1888     } else if (node.name() == "use_fn5_0" && ++found) {
1889       EXPECT_EQ("fn5", node.input(0));
1890     } else if (node.name() == "use_fn5_2" && ++found) {
1891       EXPECT_EQ("fn5:1", node.input(0));
1892     }
1893   }
1894   EXPECT_EQ(10, found);
1895 
1896   // And that graph evaluation yields the same result.
1897   Tensor pi = test::AsScalar<float>(3.14f);
1898   item.fetch = {"use_fn1_0", "use_fn1_1", "use_fn1_2", "use_fn2_0",
1899                 "use_fn3_1", "use_fn4_2", "use_fn5_0", "use_fn5_2"};
1900   item.feed = {{"xf", pi}, {"yf", pi}};
1901 
1902   auto tensors_expected = EvaluateFetchNodes(item);
1903   GrapplerItem optimized = item.WithGraph(std::move(output));
1904   auto tensors = EvaluateFetchNodes(optimized);
1905 
1906   ASSERT_EQ(tensors_expected.size(), tensors.size());
1907   for (int i = 0; i < item.fetch.size(); ++i) {
1908     test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
1909   }
1910 }
1911 
1912 TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionForUsedOutputTensors) {
1913   using test::function::NDef;
1914   using FDH = FunctionDefHelper;
1915 
1916   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1917 
1918   // MyFunc computes x*y three times and has three output values.
1919   FunctionDef my_func = FunctionDefHelper::Create(
1920       "MyFunc", {"x:T", "y:T"}, {"z1:T", "z2:T", "z3:T"}, {"T: {float, int32}"},
1921       {{{"output1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
1922        {{"output2"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
1923        {{"output3"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
1924       /* Mapping between function returns and function node outputs. */
1925       {{"z1", "output1:z:0"}, {"z2", "output2:z:0"}, {"z3", "output3:z:0"}});
1926   (*my_func.mutable_attr())["_noinline"].set_b(true);
1927   std::vector<FunctionDef> function_library = {my_func};
1928 
1929   GrapplerItem item;
1930   item.id = "tf_graph";
1931   item.graph = test::function::GDef(
1932       {NDef("init", "NoOp", {}, {}, kDevice),
1933 
1934        // Float placeholders.
1935        NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1936        NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1937 
1938        // Specialization #1: DT_FLOAT type parameter. All outputs used.
1939        NDef("fn1", "PartitionedCall", {"xf", "yf"},
1940             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
1941              {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
1942              {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
1943             kDevice),
1944        NDef("use_fn1_0", "Identity", {"fn1:0"}, {{"T", DT_FLOAT}}, kDevice),
1945        NDef("use_fn1_1", "Identity", {"fn1:1"}, {{"T", DT_FLOAT}}, kDevice),
1946        NDef("use_fn1_2", "Identity", {"fn1:2"}, {{"T", DT_FLOAT}}, kDevice),
1947 
1948        // Specialization #2: DT_FLOAT type parameter. Only first output used.
1949        NDef("fn2", "PartitionedCall", {"xf", "yf"},
1950             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
1951              {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
1952              {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
1953             kDevice),
1954        NDef("use_fn2_0", "Identity", {"fn2:0"}, {{"T", DT_FLOAT}}, kDevice),
1955 
1956        // Specialization #3: DT_FLOAT type parameter. Only second output used.
1957        NDef("fn3", "PartitionedCall", {"xf", "yf"},
1958             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
1959              {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
1960              {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
1961             kDevice),
1962        NDef("use_fn3_1", "Identity", {"fn3:1"}, {{"T", DT_FLOAT}}, kDevice),
1963 
1964        // Specialization #4: DT_FLOAT type parameter. Only last output used.
1965        NDef("fn4", "PartitionedCall", {"xf", "yf"},
1966             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
1967              {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
1968              {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
1969             kDevice),
1970        NDef("use_fn4_2", "Identity", {"fn4:2"}, {{"T", DT_FLOAT}}, kDevice),
1971 
1972        // Specialization #5: DT_FLOAT type parameter. First and last outputs.
1973        NDef("fn5", "PartitionedCall", {"xf", "yf"},
1974             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
1975              {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
1976              {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
1977             kDevice),
1978        NDef("use_fn5_0", "Identity", {"fn5:0"}, {{"T", DT_FLOAT}}, kDevice),
1979        NDef("use_fn5_2", "Identity", {"fn5:2"}, {{"T", DT_FLOAT}}, kDevice),
1980 
1981        // Specialization #6: DT_FLOAT type parameter. Outputs not used.
1982        // Check that the function optimizer does not fail. In practice such a
1983        // node should be pruned from the graph before the optimizer runs.
1984        NDef("fn6", "PartitionedCall", {"xf", "yf"},
1985             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
1986              {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
1987              {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
1988             kDevice)},
1989       function_library);
1990 
1991   GraphDef output;
1992   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1993 
1994   // Make sure that MyFunc was specialized once per unique context.
1995   EXPECT_EQ(6, output.library().function_size());
1996 
1997   // And graph nodes must call the specialized functions.
1998   int found = 0;
1999   for (const NodeDef& node : output.node()) {
2000     // All function caller nodes must be specialized.
2001     if (node.name() == "fn1" && ++found) {
2002       auto& func = AttrSlice(node).Find("f")->func();
2003       auto& tout = AttrSlice(node).Find("Tout")->list();
2004       EXPECT_EQ("PartitionedCall", node.op());
2005       EXPECT_EQ("MyFunc_specialized_for_fn1_at_tf_graph", func.name());
2006       ASSERT_EQ(3, tout.type_size());
2007 
2008     } else if (node.name() == "fn2" && ++found) {
2009       auto& func = AttrSlice(node).Find("f")->func();
2010       auto& tout = AttrSlice(node).Find("Tout")->list();
2011       EXPECT_EQ("PartitionedCall", node.op());
2012       EXPECT_EQ("MyFunc_specialized_for_fn2_at_tf_graph", func.name());
2013       ASSERT_EQ(1, tout.type_size());
2014 
2015     } else if (node.name() == "fn3" && ++found) {
2016       auto& func = AttrSlice(node).Find("f")->func();
2017       auto& tout = AttrSlice(node).Find("Tout")->list();
2018       EXPECT_EQ("PartitionedCall", node.op());
2019       EXPECT_EQ("MyFunc_specialized_for_fn3_at_tf_graph", func.name());
2020       ASSERT_EQ(1, tout.type_size());
2021 
2022     } else if (node.name() == "fn4" && ++found) {
2023       auto& func = AttrSlice(node).Find("f")->func();
2024       auto& tout = AttrSlice(node).Find("Tout")->list();
2025       EXPECT_EQ("PartitionedCall", node.op());
2026       EXPECT_EQ("MyFunc_specialized_for_fn4_at_tf_graph", func.name());
2027       ASSERT_EQ(1, tout.type_size());
2028 
2029     } else if (node.name() == "fn5" && ++found) {
2030       auto& func = AttrSlice(node).Find("f")->func();
2031       auto& tout = AttrSlice(node).Find("Tout")->list();
2032       EXPECT_EQ("PartitionedCall", node.op());
2033       EXPECT_EQ("MyFunc_specialized_for_fn5_at_tf_graph", func.name());
2034       ASSERT_EQ(2, tout.type_size());
2035 
2036     } else if (node.name() == "fn6" && ++found) {
2037       auto& func = AttrSlice(node).Find("f")->func();
2038       auto& tout = AttrSlice(node).Find("Tout")->list();
2039       EXPECT_EQ("PartitionedCall", node.op());
2040       EXPECT_EQ("MyFunc_specialized_for_fn6_at_tf_graph", func.name());
2041       ASSERT_EQ(0, tout.type_size());
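      // With no consumers, the specialized signature has zero outputs, but the
      // call node itself is kept; removing dead calls is presumably left to
      // other optimizers (see the comment on fn6 above).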
2042     }
2043     // And all consumers of specialized function nodes must be mapped to new
2044     // output ports.
2045     if (node.name() == "use_fn3_1" && ++found) {
2046       EXPECT_EQ("fn3", node.input(0));
2047     } else if (node.name() == "use_fn4_2" && ++found) {
2048       EXPECT_EQ("fn4", node.input(0));
2049     } else if (node.name() == "use_fn5_0" && ++found) {
2050       EXPECT_EQ("fn5", node.input(0));
2051     } else if (node.name() == "use_fn5_2" && ++found) {
2052       EXPECT_EQ("fn5:1", node.input(0));
2053     }
2054   }
2055   EXPECT_EQ(10, found);
2056 
2057   // And that graph evaluation yields the same result.
2058   Tensor pi = test::AsScalar<float>(3.14f);
2059   item.fetch = {"use_fn1_0", "use_fn1_1", "use_fn1_2", "use_fn2_0",
2060                 "use_fn3_1", "use_fn4_2", "use_fn5_0", "use_fn5_2"};
2061   item.feed = {{"xf", pi}, {"yf", pi}};
2062 
2063   auto tensors_expected = EvaluateFetchNodes(item);
2064   GrapplerItem optimized = item.WithGraph(std::move(output));
2065   auto tensors = EvaluateFetchNodes(optimized);
2066 
2067   ASSERT_EQ(tensors_expected.size(), tensors.size());
2068   for (int i = 0; i < item.fetch.size(); ++i) {
2069     test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
2070   }
2071 }
2072 
2073 TEST_F(FunctionOptimizerTest, PruningUselessLibraryFunctions) {
2074   using test::function::NDef;
2075   FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
2076   auto func = test::function::XTimesTwo();
2077   (*func.mutable_attr())["_noinline"].set_b(true);
2078   GrapplerItem item;
2079   item.id = "test_graph";
2080   item.graph = test::function::GDef(
2081       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, "/device:CPU:0"),
2082        NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, "/device:CPU:0"),
2083        NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, "/device:CPU:0")},
2084       // FunctionLib
2085       {
2086           func,
2087           test::function::XTimesTwoInt32(),
2088           test::function::XTimes16(),
2089       });
2090   GraphDef output;
2091   Status status = optimizer.Optimize(nullptr, item, &output);
2092   TF_EXPECT_OK(status);
2093 
2094   ASSERT_EQ(output.library().function().size(), 1);
2095   EXPECT_EQ(output.library().function(0).signature().name(),
2096             "XTimesTwo_specialized_for_y_at_test_graph");
2097 }
2098 
2099 }  // namespace grappler
2100 }  // namespace tensorflow
2101