/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/function_optimizer.h"

#include "absl/algorithm/container.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace tensorflow {
namespace grappler {

namespace {
constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
}  // namespace

class FunctionOptimizerTest : public GrapplerTest {};

TEST_F(FunctionOptimizerTest, InlineFunction_SimpleFunction) {
  using test::function::NDef;

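  // The second constructor argument is `lower_control_flow`; compare with the
  // InlineIndirectFunctionDontLowerControlFlow test below, which disables it.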
  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Build a graph to compute y = XTimesTwo(x)
  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

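  // Inlined function inputs and outputs are forwarded through Identity nodes
  // with auto-generated "Func/<node>/input|output/_<i>" names.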
  const string arg0 = "Func/y/input/_0";
  const string ret0 = "Func/y/output/_1";

  const Tensor kTwo = test::AsScalar<int64_t>(2);
  GraphDef expected = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),
       NDef(arg0, "Identity", {"x"}, {{"T", DT_FLOAT}}),
       NDef("y/two", "Const", {}, {{"dtype", DT_INT64}, {"value", kTwo}}),
       NDef("y/scale", "Cast", {"y/two"},
            {{"DstT", DT_FLOAT}, {"SrcT", DT_INT64}}),
       NDef("y/y", "Mul", {arg0, "y/scale"}, {{"T", DT_FLOAT}}),
       NDef(ret0, "Identity", {"y/y"}, {{"T", DT_FLOAT}}),
       NDef("z", "Identity", {ret0}, {{"T", DT_FLOAT}})},
      {});
  for (NodeDef& node : *expected.mutable_node()) node.set_device(kDevice);

  CompareGraphs(expected, output);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FixedTypeFunction) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Create and instantiate a version of the XTimesTwo function that only
  // accepts floats as inputs.
  const Tensor kTwo = test::AsScalar<float>(2.0f);
  FunctionDef x_times_two = FunctionDefHelper::Define(
      // Name
      "XTimesTwo",
      // Args
      {"x: float"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
          // "enter" node is used to verify that InlineFunction would update
          // the frame name accordingly.
          {{"enter"},
           "Enter",
           {"x"},
           {{"T", DT_FLOAT}, {"frame_name", "frame"}}},
          {{"y"}, "Mul", {"x", "two"}, {{"T", DT_FLOAT}}},
      });

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "XTimesTwo", {"x"}, {}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          x_times_two,
      });

  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Calls to XTimesTwo were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "XTimesTwo");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithOutputMapping) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  FunctionDef func = FunctionDefHelper::Create(
      // Name
      "Exp_func",
      // Args
      {"in: float"},
      // Return values
      {"out: float"},
      // Attr def
      {},
      // Nodes
      {{{"Linear_func"}, "Identity", {"in"}, {{"T", DT_FLOAT}}},
       {{"Exp"}, "Exp", {"Linear_func:output:0"}, {{"T", DT_FLOAT}}}},
      // Mapping
      {{"out", "Exp:y:0"}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "Exp_func", {"x"}, {}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          func,
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function call was removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "Exp_func");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithInputForwarding) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  FunctionDef func = FunctionDefHelper::Create(
      // Name
      "ForwardInputs",
      // Args
      {"in0: float", "in1: float", "arg2: float", "arg3: int32", "arg4: float"},
      // Return values
      {"out0: float", "arg2: float", "arg3: int32"},
      // Attr def
      {},
      // Nodes
      {},
      // Mapping
      {{"out0", "in0"}, {"arg2", "arg2"}, {"arg3", "arg3"}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x2", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x3", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       NDef("x4", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "ForwardInputs", {"x0", "x1", "x2", "x3", "x4"}, {}, kDevice),
       NDef("z0", "Identity", {"y:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z1", "Identity", {"y:1"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z2", "Identity", {"y:2"}, {{"T", DT_INT32}}, kDevice)},
      // FunctionLib
      {
          func,
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function call was removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "ForwardInputs");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  item.fetch = {"z0", "z1", "z2"};
  item.feed.emplace_back("x0", test::AsScalar<float>(3.14f));
  item.feed.emplace_back("x1", test::AsScalar<float>(2.7f));
  item.feed.emplace_back("x2", test::AsScalar<float>(1.0f));
  item.feed.emplace_back("x4", test::AsScalar<float>(-1.0f));
  item.feed.emplace_back("x3", test::AsScalar<int>(1234));
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
  test::ExpectTensorEqual<int>(tensors_expected[2], tensors[2]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithoutInput) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  const Tensor kTwo = test::AsScalar<int64_t>(2);
  FunctionDef func = FunctionDefHelper::Define(
      // Name
      "GenerateTwo",
      // Args
      {},
      // Return value
      {"o: T"},
      // Attr def
      {"T: {float, double}"},
      // Nodes
      {{{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
       {{"o"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("y", "GenerateTwo", {}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {
          func,
      });

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function call was removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "GenerateTwo");
  }
  // And the function itself was removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  item.fetch = {"z"};
  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithNestedFunctionCall) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Define square via function library:
  //   MySquare(x) = MyMul(x, x)

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "output:z:0"}});

  FunctionDef square_func = FunctionDefHelper::Create(
      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "output:z:0"}});

  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("outputs", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)},
      // FunctionLib
      {mul_func, square_func});

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Function calls were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "MySquare");
    EXPECT_NE(node.op(), "MyMul");
  }
  // And functions were removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  item.fetch = {"outputs"};
  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineSymbolicGradient_TestFunc) {
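  // Check that a SymbolicGradient call to TestFunc is inlined into the main
  // graph and produces the same gradients as the original function call.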
  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  FunctionDef func = FunctionDefHelper::Define(
      "TestFunc", {"x:float", "y:float"}, {"l:float"}, {},
      {
          {{"z"}, "Add", {"x", "y"}, {{"T", DT_FLOAT}}},
          FunctionDefHelper::Const("zero", 0),
          FunctionDefHelper::Const("one", 1),
          {{"r"}, "Rank", {"z"}, {{"T", DT_FLOAT}}},
          {{"indices"}, "Range", {"zero", "r", "one"}},
          {{"l"}, "Sum", {"z", "indices"}, {{"T", DT_FLOAT}}},
      });

  auto x = ops::Const(scope, 1.0f);
  auto y = ops::Const(scope, 2.0f);
  auto dl = ops::Const(scope, 3.0f);

  NameAttrList fn;
  fn.set_name("TestFunc");
  (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, y, dl},
                                  {DT_FLOAT, DT_FLOAT}, fn);
  auto out1 = ops::Identity(scope.WithOpName("out1"), g0.output[0]);
  auto out2 = ops::Identity(scope.WithOpName("out2"), g0.output[1]);

  GrapplerItem item;
  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
  *item.graph.mutable_library()->add_function() = func;

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // SymbolicGradient calls were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "SymbolicGradient");
  }
  // And functions were removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  std::vector<Tensor> expected =
      EvaluateNodes(item.graph, {"out1", "out2"}, {});
  std::vector<Tensor> optimized = EvaluateNodes(output, {"out1", "out2"}, {});
  test::ExpectTensorEqual<float>(expected[0], optimized[0]);
  test::ExpectTensorEqual<float>(expected[1], optimized[1]);
}

TEST_F(FunctionOptimizerTest, InlineSymbolicGradient_IdentityFunc) {
  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  FunctionDef func = FunctionDefHelper::Create(
      // Name
      "Identity_func",
      // Args
      {"in: float"},
      // Return values
      {"out: float"},
      // Attr def
      {},
      // Nodes
      {{{"Identity"}, "Identity", {"in"}, {{"T", DT_FLOAT}}}},
      // Mapping
      {{"out", "Identity:output:0"}});

  auto x = ops::Const(scope, 1.0f, {3, 5, 7});
  auto z = ops::Const(scope, 3.0f, {3, 5, 7});

  NameAttrList fn;
  fn.set_name("Identity_func");
  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, z},
                                  {DT_FLOAT}, fn);
  auto out = ops::Identity(scope.WithOpName("out"), g0.output[0]);

  GrapplerItem item;
  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
  *item.graph.mutable_library()->add_function() = func;

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // SymbolicGradient calls were removed from the graph.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.op(), "SymbolicGradient");
  }
  // And functions were removed from the library.
  EXPECT_EQ(output.library().function_size(), 0);

  std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out"}, {});
  std::vector<Tensor> optimized = EvaluateNodes(output, {"out"}, {});
  test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}

TEST_F(FunctionOptimizerTest, InlineSymbolicGradientNoInlineFunc) {
  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  FunctionDef func = FunctionDefHelper::Define(
      "TestFunc", {"x:float", "y:float"}, {"l:float"}, {},
      {
          {{"z"}, "Add", {"x", "y"}, {{"T", DT_FLOAT}}},
          FunctionDefHelper::Const("zero", 0),
          FunctionDefHelper::Const("one", 1),
          {{"r"}, "Rank", {"z"}, {{"T", DT_FLOAT}}},
          {{"indices"}, "Range", {"zero", "r", "one"}},
          {{"l"}, "Sum", {"z", "indices"}, {{"T", DT_FLOAT}}},
      });
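  // Mark the function as non-inlinable; the optimizer must leave both the
  // SymbolicGradient node and the function library untouched.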
  (*func.mutable_attr())["_noinline"].set_b(true);

  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope, 1.0f);
  auto y = ops::Const(scope, 2.0f);
  auto dl = ops::Const(scope, 3.0f);

  NameAttrList fn;
  fn.set_name("TestFunc");
  (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, y, dl},
                                  {DT_FLOAT, DT_FLOAT}, fn);
  auto out1 = ops::Identity(scope.WithOpName("out1"), g0.output[0]);
  auto out2 = ops::Identity(scope.WithOpName("out2"), g0.output[1]);

  GrapplerItem item;
  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
  *item.graph.mutable_library()->add_function() = func;

  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  // The optimizer must succeed, but the graph should remain unchanged.
  TF_EXPECT_OK(status);
  CompareGraphs(item.graph, output);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionSimpleFunction) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

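  // An "indirect" function call: the function is invoked through a
  // PartitionedCall node rather than a direct function call node.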
  // Build a graph to compute c = MyMul(a, b)
  GrapplerItem item;
  item.fetch = {"d"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("c", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, kDevice)},
      {mul_func} /* Function library */);

  Tensor pi = test::AsScalar<float>(3.14f);
  item.feed.emplace_back("a", pi);
  item.feed.emplace_back("b", pi);

  const string input_x = "Func/c/input/_0";
  const string input_y = "Func/c/input/_1";
  const string output_z = "Func/c/output/_2";

  // If device set is empty, inlined function body must not be placed.
  {
    GraphDef optimized_graph;
    TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

    GraphDef expected = test::function::GDef(
        {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

         // Function body nodes copy only job/task/replica parts of device
         // assignment, and function input nodes must copy full device
         // assignment from input arguments. Optimized graph is not fully
         // placed.
         NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
         NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
         // NOTE(ezhulenev): Currently multi-device function inlining placer
         // strategy will override all empty devices with function call device.
         NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
         NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}),

         NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
        // Function library.
        {mul_func});

    CompareGraphs(expected, optimized_graph);

    GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
    auto tensors_expected = EvaluateFetchNodes(item);
    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors_expected.size(), 1);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }

  // If device set is not empty, inlined function body must be placed.
  {
    GraphDef optimized_graph;
    TF_EXPECT_OK(item.AddDevice(kDevice));
    TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

    GraphDef expected = test::function::GDef(
        {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
         NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

         NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
         NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
         NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
         NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, kDevice),

         NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
        // Function library.
        {mul_func});

    CompareGraphs(expected, optimized_graph);

    GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
    auto tensors_expected = EvaluateFetchNodes(item);
    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors_expected.size(), 1);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::ON, true);

  const Tensor kOne = test::AsScalar<float>(1.0);
  const Tensor kTwo = test::AsScalar<float>(2.0);
  const TensorShape scalar = TensorShape({});

  // Compute `x*y` and add `1.0` to the variable.
  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T", "v: resource"}, {"z:T"}, {"T: {float, double}"},
      {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_FLOAT}}},
       {{"add"},
        "AssignAddVariableOp",
        {"v", "one:output:0"},
        {{"dtype", DT_FLOAT}}},
       {{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}},
      /* Control output to ensure that side effects will be executed. */
      {{"side_effects", "add"}});

  // Build a graph to compute:
  //   a = Placeholder
  //   b = Placeholder
  //   v = VarHandleOp(init = a)
  //   f1 = MyMul(a, b, v)
  //   f2 = MyMul(f1, f1, v)
  //   return [f2, v]
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"out_1", "out_2"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Initialize variable with one of the placeholders.
       NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}}),
       NDef("init_v", "AssignVariableOp", {"v", "a"}, {{"dtype", DT_FLOAT}},
            kDevice),

       // Call function first time.
       NDef("f1", "PartitionedCall", {"a", "b", "v", "^init_v"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_RESOURCE}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Call function second time.
       NDef("f2", "PartitionedCall", {"f1", "f1", "v", "^f1"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_RESOURCE}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Return the result of the multiplication and the current value of
       // the variable.
       NDef("out_1", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out_2", "ReadVariableOp", {"v", "^f1", "^f2"},
            {{"dtype", DT_FLOAT}}, kDevice)},

      // Function library.
      {mul_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Initialize variable with one of the placeholders.
       NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}},
            kDevice),
       NDef("init_v", "AssignVariableOp", {"v", "a"}, {{"dtype", DT_FLOAT}},
            kDevice),

       // Function body of the first function call inlined into the graph.
       NDef("Func/f1/input_control_node/_0", "NoOp", {"^init_v"}, {}, kDevice),

       NDef("Func/f1/input/_1", "Identity",  // input: 'x'
            {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/input/_2", "Identity",  // input: 'y'
            {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/input/_3", "Identity",  // input: 'v'
            {"v", "^Func/f1/input_control_node/_0"}, {{"T", DT_RESOURCE}},
            kDevice),

       NDef("f1/one", "Const", {"^Func/f1/input_control_node/_0"},
            {{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
       NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("f1/add", "AssignAddVariableOp", {"Func/f1/input/_3", "f1/one"},
            {{"dtype", DT_FLOAT}}, kDevice),

       NDef("Func/f1/output/_4", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/output_control_node/_5", "NoOp", {"^f1/add"}, {}, kDevice),

       // Function body of the second function call is also inlined into the
       // graph, and its input nodes read from the output nodes of the first
       // function call.
       NDef("Func/f2/input_control_node/_6", "NoOp",
            {"^Func/f1/output_control_node/_5"}, {}, kDevice),

       NDef("Func/f2/input/_7", "Identity",  // input: 'x'
            {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/input/_8", "Identity",  // input: 'y'
            {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/input/_9", "Identity",  // input: 'v'
            {"v", "^Func/f2/input_control_node/_6"}, {{"T", DT_RESOURCE}},
            kDevice),

       NDef("f2/one", "Const", {"^Func/f2/input_control_node/_6"},
            {{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
       NDef("f2/add", "AssignAddVariableOp", {"Func/f2/input/_9", "f2/one"},
            {{"dtype", DT_FLOAT}}, kDevice),
       NDef("f2/mul", "Mul", {"Func/f2/input/_7", "Func/f2/input/_8"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("Func/f2/output/_10", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f2/output_control_node/_11", "NoOp", {"^f2/add"}, {},
            kDevice),

       // Return values read from inlined output nodes.
       NDef("out_1", "Identity", {"Func/f2/output/_10"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("out_2", "ReadVariableOp",
            {"v", "^Func/f1/output_control_node/_5",
             "^Func/f2/output_control_node/_11"},
            {{"dtype", DT_FLOAT}}, kDevice)},

      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);

  item.feed.emplace_back("a", kOne);
  item.feed.emplace_back("b", kTwo);

  auto tensors_expected = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors_expected.size(), 2);
  EXPECT_EQ(tensors_expected[0].flat<float>()(0), 4.0);  // mul
  EXPECT_EQ(tensors_expected[1].flat<float>()(0), 3.0);  // read variable

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithDevicePlacement) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});
  // Add device placement spec to the function body node.
  (*mul_func.mutable_node_def())[0].set_device("/device:CPU:1");

  // We need fully defined device names to run the placer for inlined function.
  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";

  // Build a graph to compute c = MyMul(a, b)
  GrapplerItem item;
  item.fetch = {"d"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
       NDef("c", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            cpu0),
       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});
  ASSERT_TRUE(item.InferDevicesFromGraph().ok());

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  const string input_x = "Func/c/input/_0";
  const string input_y = "Func/c/input/_1";
  const string output_z = "Func/c/output/_2";

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),

       // Function must be inlined and `mul` node placed on a requested device,
       // and input `Identity` nodes must be colocated with their source nodes.
       NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
       NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
       NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, cpu1),
       NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),

       NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);
}

TEST_F(FunctionOptimizerTest,
       InlineMultipleIndirectFunctionWithDevicePlacement) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});
  // Add device placement spec to the function body node.
  (*mul_func.mutable_node_def())[0].set_device("/device:CPU:1");

  // We need fully defined device names to run the placer for inlined function.
  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";

  // Build a graph to compute d = MyMul(a, MyMul(a, b))
  GrapplerItem item;
  item.fetch = {"e"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
       NDef("c", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            cpu0),
       NDef("d", "PartitionedCall", {"a", "c"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            cpu0),
       NDef("e", "Identity", {"d"}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});
  ASSERT_TRUE(item.InferDevicesFromGraph().ok());

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  const string input_c_x = "Func/c/input/_0";
  const string input_c_y = "Func/c/input/_1";
  const string output_c_z = "Func/c/output/_2";
  const string input_d_x = "Func/d/input/_3";
  const string input_d_y = "Func/d/input/_4";
  const string output_d_z = "Func/d/output/_5";

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),

       // Both functions must be inlined, their `mul` nodes placed on the
       // requested device, and input/output `Identity` nodes colocated with
       // their source nodes.
       NDef(input_c_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
       NDef(input_c_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
       NDef("c/mul", "Mul", {input_c_x, input_c_y}, {{"T", DT_FLOAT}}, cpu1),
       NDef(output_c_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),

       NDef(input_d_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
       NDef(input_d_y, "Identity", {output_c_z}, {{"T", DT_FLOAT}}, cpu1),
       NDef("d/mul", "Mul", {input_d_x, input_d_y}, {{"T", DT_FLOAT}}, cpu1),
       NDef(output_d_z, "Identity", {"d/mul"}, {{"T", DT_FLOAT}}, cpu1),

       NDef("e", "Identity", {output_d_z}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);
}

TEST_F(FunctionOptimizerTest,
       InlineIndirectFunctionWithControlDependencyAndNoSideEffects) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  const Tensor kOne = test::AsScalar<float>(1.0);
  const Tensor kTwo = test::AsScalar<float>(2.0);
  const TensorShape scalar = TensorShape({});

  // MyMul doesn't have any side-effectful nodes in the function body, but the
  // optimized graph has a control dependency edge `f1->f2`.
  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

  // Build a graph to compute:
  //   a = Placeholder
  //   b = Placeholder
  //   f1 = MyMul(a, b)
  //   f2 = MyMul(f1, f1, ^f1)  <-- control dependency on inlined function!
  //   return f2
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"out"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       NDef("c", "NoOp", {}, {}, kDevice),

       // Call function first time.
       NDef("f1", "PartitionedCall", {"a", "b", "^c"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Call function second time.
       NDef("f2", "PartitionedCall", {"f1", "f1", "^f1"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),

       // Return result of f2.
       NDef("out", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice)},

      // Function library.
      {mul_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       NDef("c", "NoOp", {}, {}, kDevice),

       // Function body of the first function call inlined into the graph.
       NDef("Func/f1/input_control_node/_0", "NoOp", {"^c"}, {}, kDevice),

       NDef("Func/f1/input/_1", "Identity",  // input: 'x'
            {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("Func/f1/input/_2", "Identity",  // input: 'y'
            {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
            kDevice),

       NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("Func/f1/output/_3", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
            kDevice),
       // Control input from `input_control_node` node is added to ensure
       // correct frame execution.
       NDef("Func/f1/output_control_node/_4", "NoOp",
            {"^Func/f1/input_control_node/_0"}, {}, kDevice),

       // Function body of the second function call is also inlined into the
       // graph; its input nodes read directly from the output nodes of the
       // first function call, and the direct control dependency edge is
       // removed.
       NDef("Func/f2/input_control_node/_5", "NoOp",
            {"^Func/f1/output_control_node/_4"}, {}, kDevice),

       NDef("Func/f2/input/_6", "Identity",
            {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/input/_7", "Identity",
            {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("f2/mul", "Mul", {"Func/f2/input/_6", "Func/f2/input/_7"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/f2/output/_8", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
            kDevice),

       // Return directly from output node of f2.
       NDef("out", "Identity", {"Func/f2/output/_8"}, {{"T", DT_FLOAT}},
            kDevice)},

      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);

  item.feed.emplace_back("a", kOne);
  item.feed.emplace_back("b", kTwo);

  auto tensors_expected = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors_expected.size(), 1);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionDoNotInlineDeadOutputs) {
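  // A function whose output may be a dead tensor (produced by only one branch
  // of a Switch) must not be inlined; the graph should be left unchanged.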
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  // Function output can be dead.
  FunctionDef dead_outputs = FunctionDefHelper::Create(
      "DeadOutputs", {"x:T", "cond:bool"}, {"z:T"}, {"T: {float, double}"},
      {
          {{"switch"}, "Switch", {"x", "cond"}, {{"T", "$T"}}},
          {{"if_false"}, "Identity", {"switch:output_false:0"}, {{"T", "$T"}}},
          {{"if_true"}, "Identity", {"switch:output_true:0"}, {{"T", "$T"}}},
      },
      /* Mapping between function returns and function node outputs. */
      {{"z", "if_false:output:0"}});

  // A simple proxy function that calls DeadOutputs from its function body.
  FunctionDef proxy_func = FunctionDefHelper::Create(
      "Proxy", {"x:T", "cond:bool"}, {"z:T"}, {"T: {float, double}"},
      {{{"dead"}, "DeadOutputs", {"x", "cond"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "dead:z:0"}});

  // Build a graph to compute:
  //   a: float
  //   b: bool
  //   fn0 = DeadOutputs(a, b)
  //   fn1 = Proxy(a, b)
  //   out0 = Identity(fn0)
  //   out1 = Identity(fn1)
  //   return [out0, out1]
  GrapplerItem item;
  item.fetch = {"out0", "out1"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),

       NDef("fn0", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_BOOL}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("DeadOutputs", {{"T", DT_FLOAT}})}},
            kDevice),

       NDef("fn1", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_BOOL}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("Proxy", {{"T", DT_FLOAT}})}},
            kDevice),

       NDef("out0", "Identity", {"fn0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out1", "Identity", {"fn1"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {dead_outputs, proxy_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = item.graph;
  CompareGraphs(expected, optimized_graph);

  const Tensor one = test::AsScalar<float>(1.0);
  item.feed.emplace_back("a", one);
  item.feed.emplace_back("b", test::AsScalar<bool>(false));

  auto tensors = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorEqual<float>(tensors[0], one);
  test::ExpectTensorEqual<float>(tensors[1], one);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithMergedDeadTensors) {
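  // Unlike DeadOutputs above, every output of this function goes through a
  // Merge node and therefore can never be dead, so inlining is safe.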
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  // Function output can't be dead because it goes through the Merge node.
  FunctionDef no_dead_outputs = FunctionDefHelper::Create(
      "NoDeadOutputs", {"x:T", "cond:bool"}, {"z:T"}, {"T: {float, double}"},
      {
          {{"switch"}, "Switch", {"x", "cond"}, {{"T", "$T"}}},
          {{"if_false"}, "Identity", {"switch:output_false:0"}, {{"T", "$T"}}},
          {{"if_true"}, "Identity", {"switch:output_true:0"}, {{"T", "$T"}}},
          {{"merge"},
           "Merge",
           {"if_false:output:0", "if_true:output:0"},
           {{"T", "$T"}, {"N", 2}}},
      },
      /* Mapping between function returns and function node outputs. */
      {{"z", "merge:output:0"}});

  // Build a graph to compute:
  //   a: float
  //   b: bool
  //   fn = NoDeadOutputs(a, b)
  //   out = Identity(fn)
  //   return out
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"out"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),

       NDef("fn", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_BOOL}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("NoDeadOutputs", {{"T", DT_FLOAT}})}},
            kDevice),

       NDef("out", "Identity", {"fn"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {no_dead_outputs});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),

       // Function body of the function call inlined into the graph.
       NDef("Func/fn/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/fn/input/_1", "Identity", {"b"}, {{"T", DT_BOOL}}, kDevice),

       NDef("fn/switch", "Switch", {"Func/fn/input/_0", "Func/fn/input/_1"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("fn/if_false", "Identity", {"fn/switch"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("fn/if_true", "Identity", {"fn/switch:1"}, {{"T", DT_FLOAT}},
            kDevice),
       NDef("fn/merge", "Merge", {"fn/if_false", "fn/if_true"},
            {{"T", DT_FLOAT}, {"N", 2}}, kDevice),

       NDef("Func/fn/output/_2", "Identity", {"fn/merge"}, {{"T", DT_FLOAT}},
            kDevice),

       // Return directly from inlined function output node.
       NDef("out", "Identity", {"Func/fn/output/_2"}, {{"T", DT_FLOAT}},
            kDevice)},

      // Function library.
      {no_dead_outputs});

  CompareGraphs(expected, optimized_graph);

  const Tensor one = test::AsScalar<float>(1.0);
  item.feed.emplace_back("a", one);
  item.feed.emplace_back("b", test::AsScalar<bool>(false));

  auto tensors_expected = EvaluateFetchNodes(item);
  ASSERT_EQ(tensors_expected.size(), 1);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithNestedFunctionCall) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

  // `MySquare` is implemented in terms of a PartitionedCall to `MyMul`.
  FunctionDef square_func = FunctionDefHelper::Create(
      "MySquare", {"x:T"}, {"output:T"}, {"T: {float, double}"},
      {{{"square"},
        "PartitionedCall",
        {"x", "x"},
        {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
         {"Tout", DataTypeSlice{DT_FLOAT}},
         {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}}}},
      /* Mapping between function returns and function node outputs. */
      {{"output", "square:output:0"}});

  // Build a graph to compute:
  //   b = MySquare(a)
  //   c = Identity(b)
  //   return c
  GrapplerItem item;
  TF_EXPECT_OK(item.AddDevice(kDevice));  // device for placing inlined function
  item.fetch = {"c"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "PartitionedCall", {"a"},
            {{"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MySquare", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("c", "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice)},
      /* Function library */
      {mul_func, square_func});

  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  GraphDef expected = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Inlined inputs of `b` node.
       NDef("Func/b/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),

       // Inlined inputs of `square` node inside inlined `MySquare` function.
       NDef("Func/b/square/input/_2", "Identity", {"Func/b/input/_0"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/b/square/input/_3", "Identity", {"Func/b/input/_0"},
            {{"T", DT_FLOAT}}, kDevice),

       // Inlined mul node from the `MyMul` function.
       NDef("b/square/mul", "Mul",
            {"Func/b/square/input/_2", "Func/b/square/input/_3"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("Func/b/square/output/_4", "Identity", {"b/square/mul"},
            {{"T", DT_FLOAT}}, kDevice),
       NDef("Func/b/output/_1", "Identity", {"Func/b/square/output/_4"},
            {{"T", DT_FLOAT}}, kDevice),

       NDef("c", "Identity", {"Func/b/output/_1"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {mul_func});

  CompareGraphs(expected, optimized_graph);

  Tensor three = test::AsScalar<float>(3.0f);
  item.feed.emplace_back("a", three);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
  auto tensors_expected = EvaluateFetchNodes(item);
  auto tensors = EvaluateFetchNodes(optimized);
  ASSERT_EQ(tensors_expected.size(), 1);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

GrapplerItem ConditionalAdd() {
  // Returns a graph that computes: (is_add) ? a + b : a * b.
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionDef add_func = FDH::Create(
      "MyAdd", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"add"}, "Add", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "add:z:0"}});

  FunctionDef mul_func = FDH::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});

  // Compute: return cond ? a + b : a * b
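  // The `_lower_using_switch_merge` attribute below asks the optimizer to
  // lower the `If` node to Switch/Merge nodes when control flow lowering is
  // enabled (see the two tests that follow).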
  FunctionDef add_or_mul_func = FDH::Create(
      "AddOrMul", {"cond:bool", "x:float", "y:float"}, {"z:float"}, {},
      {
          {{"if_node"},
           "If",
           {"cond", "x", "y"},
           {
               {"Tcond", DT_BOOL},
               {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
               {"Tout", DataTypeSlice{DT_FLOAT}},
               {"then_branch", FDH::FunctionRef("MyAdd", {{"T", DT_FLOAT}})},
               {"else_branch", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})},
               {"_lower_using_switch_merge", true},
           }},
      },
      /* Mapping between function returns and function node outputs. */
      {{"z", "if_node:output:0"}}, {{"side_effect", "if_node"}});

  // Build a computation graph for:
  //   is_add: bool
  //   a: float
  //   b: float
  //   c = AddOrMul(is_add, a, b)  # is_add ? a + b : a * b
  //   d = Identity(c)
  //   return d
  GrapplerItem item;
  item.fetch = {"d"};
  item.graph = test::function::GDef(
      {NDef("is_add", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),
       NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       NDef("c", "PartitionedCall", {"is_add", "a", "b"},
            {{"Tin", DataTypeSlice{DT_BOOL, DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("AddOrMul")}},
            kDevice),

       NDef("d", "Identity", {"c", "^c"}, {{"T", DT_FLOAT}}, kDevice)},
      // Function library.
      {add_or_mul_func, add_func, mul_func});
  return item;
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithFunctionalControlFlow) {
  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE, true);

  // item.fetch['d'] == (is_add) ? a + b : a * b
  GrapplerItem item = ConditionalAdd();
  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  const auto count_nodes_with_op = [&](const string& op) {
    return absl::c_count_if(optimized_graph.node(), [&](const NodeDef& node) {
      return node.op() == op;
    });
  };

  // All `PartitionedCall` nodes in the optimized graph must be inlined, and
  // `If` node must be lowered to `Switch` and `Merge` nodes.
  EXPECT_EQ(count_nodes_with_op("PartitionedCall"), 0);
  EXPECT_EQ(count_nodes_with_op("If"), 0);
  EXPECT_EQ(count_nodes_with_op("Switch"), 3);
  EXPECT_EQ(count_nodes_with_op("Merge"), 2);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));

  Tensor one = test::AsScalar<float>(1.0);
  Tensor two = test::AsScalar<float>(2.0);
  Tensor three = test::AsScalar<float>(3.0);

  const auto feed_args = [&](bool is_add) {
    std::vector<std::pair<string, Tensor>> feed;
    feed.emplace_back("a", one);
    feed.emplace_back("b", two);
    feed.emplace_back("is_add", test::AsScalar<bool>(is_add));
    return feed;
  };

  {  // Check 'is_add == true': a + b
    item.feed = feed_args(true);
    optimized.feed = feed_args(true);

    auto tensors_expected = EvaluateFetchNodes(item);
    ASSERT_EQ(tensors_expected.size(), 1);
    test::ExpectTensorEqual<float>(tensors_expected[0], three);

    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }

  {  // Check 'is_add == false': a * b
    item.feed = feed_args(false);
    optimized.feed = feed_args(false);

    auto tensors_expected = EvaluateFetchNodes(item);
    ASSERT_EQ(tensors_expected.size(), 1);
    test::ExpectTensorEqual<float>(tensors_expected[0], two);

    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }
}

TEST_F(FunctionOptimizerTest, InlineIndirectFunctionDontLowerControlFlow) {
  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE,
                              /*lower_control_flow=*/false);

  // item.fetch['d'] == (is_add) ? a + b : a * b
  GrapplerItem item = ConditionalAdd();
  GraphDef optimized_graph;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));

  const auto count_nodes_with_op = [&](const string& op) {
    return absl::c_count_if(optimized_graph.node(), [&](const NodeDef& node) {
      return node.op() == op;
    });
  };

  // All `PartitionedCall` nodes in the optimized graph must be inlined, but
  // the `If` node must not be lowered, because control flow lowering is
  // disabled for this optimizer.
  EXPECT_EQ(count_nodes_with_op("PartitionedCall"), 0);
  EXPECT_EQ(count_nodes_with_op("If"), 1);
  EXPECT_EQ(count_nodes_with_op("Switch"), 0);
  EXPECT_EQ(count_nodes_with_op("Merge"), 0);

  GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));

  Tensor one = test::AsScalar<float>(1.0);
  Tensor two = test::AsScalar<float>(2.0);
  Tensor three = test::AsScalar<float>(3.0);

  const auto feed_args = [&](bool is_add) {
    std::vector<std::pair<string, Tensor>> feed;
    feed.emplace_back("a", one);
    feed.emplace_back("b", two);
    feed.emplace_back("is_add", test::AsScalar<bool>(is_add));
    return feed;
  };

  {  // Check 'is_add == true': a + b
    item.feed = feed_args(true);
    optimized.feed = feed_args(true);

    auto tensors_expected = EvaluateFetchNodes(item);
    ASSERT_EQ(tensors_expected.size(), 1);
    test::ExpectTensorEqual<float>(tensors_expected[0], three);

    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }

  {  // Check 'is_add == false': a * b
    item.feed = feed_args(false);
    optimized.feed = feed_args(false);

    auto tensors_expected = EvaluateFetchNodes(item);
    ASSERT_EQ(tensors_expected.size(), 1);
    test::ExpectTensorEqual<float>(tensors_expected[0], two);

    auto tensors = EvaluateFetchNodes(optimized);
    ASSERT_EQ(tensors.size(), tensors_expected.size());
    test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  }
}

TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
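  // Functions marked with the `_noinline` attribute are not inlined. Instead
  // the optimizer adds a copy of the function specialized for this call site
  // to the library and rewrites the caller to use it.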
1429 using test::function::NDef;
1430
1431 FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
1432
1433 // Mark XTimesTwo as noinline.
1434 FunctionDef x_times_two = test::function::XTimesTwo();
1435 (*x_times_two.mutable_attr())["_noinline"].set_b(true);
1436 std::vector<FunctionDef> function_library = {x_times_two};
1437
1438 // Build a graph to compute y = XTimesTwo(x).
1439 GrapplerItem item;
1440 item.id = "tf_graph";
1441 item.graph = test::function::GDef(
1442 {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
1443 NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, kDevice),
1444 NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
1445 function_library);
1446
1447 GraphDef output;
1448 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
1449
  // Make sure that the specialized function was added to the library and the
  // original function was removed.
  EXPECT_EQ(1, output.library().function_size());
  EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph",
            output.library().function(0).signature().name());
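  // Specialized functions follow the naming scheme
  // `<function>_specialized_for_<node>_at_<item id>`, which is why `item.id`
  // is set to "tf_graph" above.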

  // And the 'y' node calls the specialized function.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "y" && ++count) {
      EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph", node.op());
    }
  }
  EXPECT_EQ(1, count);

  // And that graph evaluation yields the same result.
  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);

  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionXTimesTwo) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Mark XTimesTwo as noinline.
  FunctionDef x_times_two = test::function::XTimesTwo();
  (*x_times_two.mutable_attr())["_noinline"].set_b(true);
  std::vector<FunctionDef> function_library = {x_times_two};

  // TensorFlow graph:
  //   y = PartitionedCall[f=XTimesTwo, Tin=[DT_FLOAT], Tout=[DT_FLOAT]](x)
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "PartitionedCall", {"x"},
            {{"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("XTimesTwo", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      function_library);

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Make sure that the specialized function was added to the library and the
  // original function was removed.
  EXPECT_EQ(1, output.library().function_size());
  EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph",
            output.library().function(0).signature().name());

  // And the 'y' node calls the specialized function.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "y" && ++count) {
      EXPECT_EQ("PartitionedCall", node.op());
      auto& func = AttrSlice(node).Find("f")->func();
      // The 'f' attribute now references the specialized function.
      EXPECT_EQ("XTimesTwo_specialized_for_y_at_tf_graph", func.name());
      // And input/output types stay the same.
      auto& tin = AttrSlice(node).Find("Tin")->list();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      ASSERT_EQ(1, tin.type_size());
      ASSERT_EQ(1, tout.type_size());
      EXPECT_EQ(DT_FLOAT, tin.type(0));
      EXPECT_EQ(DT_FLOAT, tout.type(0));
    }
  }
  EXPECT_EQ(1, count);

  // And that graph evaluation yields the same result.
  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);

  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, SpecializeFunctionPushDownConstInput) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "output:z:0"}});
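  // In other words, MyMul(x, y) simply returns Mul(x, y); the ret mapping
  // binds the function return 'z' to output 'z:0' of the internal 'output'
  // node.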

  // Mark MyMul as noinline.
  (*mul_func.mutable_attr())["_noinline"].set_b(true);
  std::vector<FunctionDef> function_library = {mul_func};

  // Build a graph to compute y = MyMul(x, 2.0).
  const Tensor kTwo = test::AsScalar<float>(2.0);

  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("init", "NoOp", {}, {}, kDevice),
       NDef("two", "Const", {"^init", "^x"},
            {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
       NDef("y", "MyMul", {"x", "two"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      function_library);

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Make sure that the specialized function was added to the library and the
  // original function was removed.
  ASSERT_EQ(1, output.library().function_size());

  const FunctionDef& specialized = output.library().function(0);
  EXPECT_EQ("MyMul_specialized_for_y_at_tf_graph",
            specialized.signature().name());
  EXPECT_EQ(1, specialized.signature().input_arg_size());

  // And the 'y' node carries the control dependencies of the pushed-down
  // const node ('^x' is redundant given the data input 'x').
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "y" && ++count) {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^init", node.input(1));
    }
  }
  EXPECT_EQ(1, count);

  // And that graph evaluation yields the same result.
  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);

  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionPushDownConstInput) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "output:z:0"}});

  // Mark MyMul as noinline.
  (*mul_func.mutable_attr())["_noinline"].set_b(true);
  std::vector<FunctionDef> function_library = {mul_func};

  const Tensor kTwo = test::AsScalar<float>(2.0);

  // TensorFlow graph:
  //   y = PartitionedCall[Tin=[DT_FLOAT], Tout=[DT_FLOAT], f=MyMul](x, two)
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("init", "NoOp", {}, {}, kDevice),
       NDef("two", "Const", {"^init", "^x"},
            {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
       NDef("y", "PartitionedCall", {"x", "two"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
      function_library);

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Make sure that the specialized function was added to the library and the
  // original function was removed.
  ASSERT_EQ(1, output.library().function_size());

  const FunctionDef& specialized = output.library().function(0);
  EXPECT_EQ("MyMul_specialized_for_y_at_tf_graph",
            specialized.signature().name());
  EXPECT_EQ(1, specialized.signature().input_arg_size());

  // And the 'y' node carries the control dependencies of the pushed-down
  // const node ('^x' is redundant given the data input 'x').
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "y" && ++count) {
      EXPECT_EQ("PartitionedCall", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^init", node.input(1));
      // The 'f' attribute now references the specialized function.
      auto& func = AttrSlice(node).Find("f")->func();
      EXPECT_EQ("MyMul_specialized_for_y_at_tf_graph", func.name());
      // And the input/output type lists were updated.
      auto& tin = AttrSlice(node).Find("Tin")->list();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      ASSERT_EQ(1, tin.type_size());
      ASSERT_EQ(1, tout.type_size());
      EXPECT_EQ(DT_FLOAT, tin.type(0));
      EXPECT_EQ(DT_FLOAT, tout.type(0));
    }
  }
  ASSERT_EQ(1, count);

  // And that graph evaluation yields the same result.
  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"z"};
  item.feed.emplace_back("x", pi);

  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // Mark MyMul as noinline.
  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, int32}"},
      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "output:z:0"}});
  (*mul_func.mutable_attr())["_noinline"].set_b(true);
  std::vector<FunctionDef> function_library = {mul_func};

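  // A specialization "context" here is the combination of instantiation
  // attributes (the type parameter T) and statically known const inputs;
  // calls that share a context must reuse one specialized function.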
  const Tensor kTwo = test::AsScalar<float>(2.0);
  const Tensor kThree = test::AsScalar<float>(3.0);

  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("init", "NoOp", {}, {}, kDevice),

       // Float placeholders.
       NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Int32 placeholders.
       NDef("xi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       NDef("yi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),

       // Consts. Control inputs have to be attached to the specialized
       // function calls.
       NDef("two", "Const", {"^init", "^xf"},
            {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
       NDef("three", "Const", {"^init", "^xf"},
            {{"dtype", DT_FLOAT}, {"value", kThree}}, kDevice),

       // Specialization #1: DT_FLOAT type parameter.
       NDef("mul_1", "MyMul", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("mul_2", "MyMul", {"yf", "xf"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #2: DT_INT32 type parameter.
       NDef("mul_3", "MyMul", {"xi", "yi"}, {{"T", DT_INT32}}, kDevice),

       // Specialization #3: DT_FLOAT type parameter + const input kTwo.
       NDef("mul_4", "MyMul", {"xf", "two"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("mul_5", "MyMul", {"yf", "two"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #4: DT_FLOAT type parameter + const input kThree.
       NDef("mul_6", "MyMul", {"three", "xf"}, {{"T", DT_FLOAT}}, kDevice)},
      function_library);

  // Specify fetch nodes before optimization to prevent pruning unused
  // function outputs.
  item.fetch = {"mul_1", "mul_2", "mul_3", "mul_4", "mul_5", "mul_6"};

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Make sure that MyMul was specialized once per unique context.
  EXPECT_EQ(4, output.library().function_size());
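  // The four expected specializations: T=DT_FLOAT (mul_1, mul_2), T=DT_INT32
  // (mul_3), T=DT_FLOAT with const 'two' (mul_4, mul_5), and T=DT_FLOAT with
  // const 'three' (mul_6).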

  // And the graph nodes now call the specialized functions.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "mul_1" && ++count) {
      EXPECT_EQ("MyMul_specialized_for_mul_1_at_tf_graph", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("yf", node.input(1));

    } else if (node.name() == "mul_2" && ++count) {
      EXPECT_EQ("MyMul_specialized_for_mul_1_at_tf_graph", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("yf", node.input(0));
      EXPECT_EQ("xf", node.input(1));

    } else if (node.name() == "mul_3" && ++count) {
      EXPECT_EQ("MyMul_specialized_for_mul_3_at_tf_graph", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("xi", node.input(0));
      EXPECT_EQ("yi", node.input(1));

    } else if (node.name() == "mul_4" && ++count) {
      EXPECT_EQ("MyMul_specialized_for_mul_4_at_tf_graph", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("^init", node.input(1));

    } else if (node.name() == "mul_5" && ++count) {
      EXPECT_EQ("MyMul_specialized_for_mul_4_at_tf_graph", node.op());
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("yf", node.input(0));
      gtl::FlatSet<string> expected_ctrl = {"^init", "^xf"};
      gtl::FlatSet<string> actual_ctrl = {node.input(1), node.input(2)};
      EXPECT_EQ(expected_ctrl, actual_ctrl);

    } else if (node.name() == "mul_6" && ++count) {
      EXPECT_EQ("MyMul_specialized_for_mul_6_at_tf_graph", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("^init", node.input(1));
    }
  }
  EXPECT_EQ(6, count);

  // And that graph evaluation yields the same result.
  Tensor pi = test::AsScalar<float>(3.14f);
  Tensor four = test::AsScalar<int32>(4);
  item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}};

  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
  test::ExpectTensorEqual<int32>(tensors_expected[2], tensors[2]);
  test::ExpectTensorEqual<float>(tensors_expected[3], tensors[3]);
  test::ExpectTensorEqual<float>(tensors_expected[4], tensors[4]);
  test::ExpectTensorEqual<float>(tensors_expected[5], tensors[5]);
}

TEST_F(FunctionOptimizerTest, SpecializeFunctionForUsedOutputTensors) {
  using test::function::NDef;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // MyFunc computes x*y three times and has three output values.
  FunctionDef my_func = FunctionDefHelper::Create(
      "MyFunc", {"x:T", "y:T"}, {"z1:T", "z2:T", "z3:T"}, {"T: {float, int32}"},
      {{{"output1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output2"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output3"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z1", "output1:z:0"}, {"z2", "output2:z:0"}, {"z3", "output3:z:0"}});
  (*my_func.mutable_attr())["_noinline"].set_b(true);
  std::vector<FunctionDef> function_library = {my_func};

  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("init", "NoOp", {}, {}, kDevice),

       // Float placeholders.
       NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Specialization #1: DT_FLOAT type parameter. All outputs used.
       NDef("fn1", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn1_0", "Identity", {"fn1:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn1_1", "Identity", {"fn1:1"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn1_2", "Identity", {"fn1:2"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #2: DT_FLOAT type parameter. Only first output used.
       NDef("fn2", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn2_0", "Identity", {"fn2:0"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #3: DT_FLOAT type parameter. Only second output used.
       NDef("fn3", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn3_1", "Identity", {"fn3:1"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #4: DT_FLOAT type parameter. Only last output used.
       NDef("fn4", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn4_2", "Identity", {"fn4:2"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #5: DT_FLOAT type parameter. First and last outputs.
       NDef("fn5", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn5_0", "Identity", {"fn5:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn5_2", "Identity", {"fn5:2"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #6: DT_FLOAT type parameter. Outputs not used.
       // Check that the function optimizer does not fail; in practice this
       // node should be pruned from the graph before it reaches the function
       // optimizer.
       NDef("fn6", "MyFunc", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice)},
      function_library);

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Make sure that MyFunc was specialized once per unique context.
  EXPECT_EQ(6, output.library().function_size());

  // And the graph nodes now call the specialized functions.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    // All function caller nodes must be specialized.
    if (node.name() == "fn1" && ++found) {
      EXPECT_EQ("MyFunc_specialized_for_fn1_at_tf_graph", node.op());
    } else if (node.name() == "fn2" && ++found) {
      EXPECT_EQ("MyFunc_specialized_for_fn2_at_tf_graph", node.op());
    } else if (node.name() == "fn3" && ++found) {
      EXPECT_EQ("MyFunc_specialized_for_fn3_at_tf_graph", node.op());
    } else if (node.name() == "fn4" && ++found) {
      EXPECT_EQ("MyFunc_specialized_for_fn4_at_tf_graph", node.op());
    } else if (node.name() == "fn5" && ++found) {
      EXPECT_EQ("MyFunc_specialized_for_fn5_at_tf_graph", node.op());
    } else if (node.name() == "fn6" && ++found) {
      EXPECT_EQ("MyFunc_specialized_for_fn6_at_tf_graph", node.op());
    }
    // And all consumers of specialized function nodes must be mapped to new
    // output ports.
    if (node.name() == "use_fn3_1" && ++found) {
      EXPECT_EQ("fn3", node.input(0));
    } else if (node.name() == "use_fn4_2" && ++found) {
      EXPECT_EQ("fn4", node.input(0));
    } else if (node.name() == "use_fn5_0" && ++found) {
      EXPECT_EQ("fn5", node.input(0));
    } else if (node.name() == "use_fn5_2" && ++found) {
      EXPECT_EQ("fn5:1", node.input(0));
    }
  }
  EXPECT_EQ(10, found);
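  // N.B. unused outputs are removed from the specialized signature, so the
  // remaining outputs are renumbered: fn3 keeps only its second output (now
  // port 0, hence 'fn3'), and fn5 keeps outputs 0 and 2 (now ports 0 and 1,
  // hence 'fn5' and 'fn5:1').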

  // And that graph evaluation yields the same result.
  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"use_fn1_0", "use_fn1_1", "use_fn1_2", "use_fn2_0",
                "use_fn3_1", "use_fn4_2", "use_fn5_0", "use_fn5_2"};
  item.feed = {{"xf", pi}, {"yf", pi}};

  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
  }
}

TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionForUsedOutputTensors) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);

  // MyFunc computes x*y three times and has three output values.
  FunctionDef my_func = FunctionDefHelper::Create(
      "MyFunc", {"x:T", "y:T"}, {"z1:T", "z2:T", "z3:T"}, {"T: {float, int32}"},
      {{{"output1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output2"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output3"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z1", "output1:z:0"}, {"z2", "output2:z:0"}, {"z3", "output3:z:0"}});
  (*my_func.mutable_attr())["_noinline"].set_b(true);
  std::vector<FunctionDef> function_library = {my_func};

  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("init", "NoOp", {}, {}, kDevice),

       // Float placeholders.
       NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),

       // Specialization #1: DT_FLOAT type parameter. All outputs used.
       NDef("fn1", "PartitionedCall", {"xf", "yf"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
             {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("use_fn1_0", "Identity", {"fn1:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn1_1", "Identity", {"fn1:1"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn1_2", "Identity", {"fn1:2"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #2: DT_FLOAT type parameter. Only first output used.
       NDef("fn2", "PartitionedCall", {"xf", "yf"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
             {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("use_fn2_0", "Identity", {"fn2:0"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #3: DT_FLOAT type parameter. Only second output used.
       NDef("fn3", "PartitionedCall", {"xf", "yf"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
             {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("use_fn3_1", "Identity", {"fn3:1"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #4: DT_FLOAT type parameter. Only last output used.
       NDef("fn4", "PartitionedCall", {"xf", "yf"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
             {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("use_fn4_2", "Identity", {"fn4:2"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #5: DT_FLOAT type parameter. First and last outputs.
       NDef("fn5", "PartitionedCall", {"xf", "yf"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
             {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
            kDevice),
       NDef("use_fn5_0", "Identity", {"fn5:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("use_fn5_2", "Identity", {"fn5:2"}, {{"T", DT_FLOAT}}, kDevice),

       // Specialization #6: DT_FLOAT type parameter. Outputs not used.
       // Check that the function optimizer does not fail; in practice this
       // node should be pruned from the graph before it reaches the function
       // optimizer.
       NDef("fn6", "PartitionedCall", {"xf", "yf"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT}},
             {"f", FDH::FunctionRef("MyFunc", {{"T", DT_FLOAT}})}},
            kDevice)},
      function_library);

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Make sure that MyFunc was specialized once per unique context.
  EXPECT_EQ(6, output.library().function_size());

  // And the graph nodes now call the specialized functions.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    // All function caller nodes must be specialized.
    if (node.name() == "fn1" && ++found) {
      auto& func = AttrSlice(node).Find("f")->func();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      EXPECT_EQ("PartitionedCall", node.op());
      EXPECT_EQ("MyFunc_specialized_for_fn1_at_tf_graph", func.name());
      ASSERT_EQ(3, tout.type_size());

    } else if (node.name() == "fn2" && ++found) {
      auto& func = AttrSlice(node).Find("f")->func();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      EXPECT_EQ("PartitionedCall", node.op());
      EXPECT_EQ("MyFunc_specialized_for_fn2_at_tf_graph", func.name());
      ASSERT_EQ(1, tout.type_size());

    } else if (node.name() == "fn3" && ++found) {
      auto& func = AttrSlice(node).Find("f")->func();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      EXPECT_EQ("PartitionedCall", node.op());
      EXPECT_EQ("MyFunc_specialized_for_fn3_at_tf_graph", func.name());
      ASSERT_EQ(1, tout.type_size());

    } else if (node.name() == "fn4" && ++found) {
      auto& func = AttrSlice(node).Find("f")->func();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      EXPECT_EQ("PartitionedCall", node.op());
      EXPECT_EQ("MyFunc_specialized_for_fn4_at_tf_graph", func.name());
      ASSERT_EQ(1, tout.type_size());

    } else if (node.name() == "fn5" && ++found) {
      auto& func = AttrSlice(node).Find("f")->func();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      EXPECT_EQ("PartitionedCall", node.op());
      EXPECT_EQ("MyFunc_specialized_for_fn5_at_tf_graph", func.name());
      ASSERT_EQ(2, tout.type_size());

    } else if (node.name() == "fn6" && ++found) {
      auto& func = AttrSlice(node).Find("f")->func();
      auto& tout = AttrSlice(node).Find("Tout")->list();
      EXPECT_EQ("PartitionedCall", node.op());
      EXPECT_EQ("MyFunc_specialized_for_fn6_at_tf_graph", func.name());
      ASSERT_EQ(0, tout.type_size());
    }
    // And all consumers of specialized function nodes must be mapped to new
    // output ports.
    if (node.name() == "use_fn3_1" && ++found) {
      EXPECT_EQ("fn3", node.input(0));
    } else if (node.name() == "use_fn4_2" && ++found) {
      EXPECT_EQ("fn4", node.input(0));
    } else if (node.name() == "use_fn5_0" && ++found) {
      EXPECT_EQ("fn5", node.input(0));
    } else if (node.name() == "use_fn5_2" && ++found) {
      EXPECT_EQ("fn5:1", node.input(0));
    }
  }
  EXPECT_EQ(10, found);
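  // As in the direct-call test above, the 'Tout' lists shrink to match the
  // outputs that survive specialization, and consumer input ports are
  // renumbered accordingly.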

  // And that graph evaluation yields the same result.
  Tensor pi = test::AsScalar<float>(3.14f);
  item.fetch = {"use_fn1_0", "use_fn1_1", "use_fn1_2", "use_fn2_0",
                "use_fn3_1", "use_fn4_2", "use_fn5_0", "use_fn5_2"};
  item.feed = {{"xf", pi}, {"yf", pi}};

  auto tensors_expected = EvaluateFetchNodes(item);
  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
  }
}

TEST_F(FunctionOptimizerTest, PruningUselessLibraryFunctions) {
  using test::function::NDef;
  FunctionOptimizer optimizer(RewriterConfig::DEFAULT, true);
  auto func = test::function::XTimesTwo();
  (*func.mutable_attr())["_noinline"].set_b(true);
  GrapplerItem item;
  item.id = "test_graph";
  item.graph = test::function::GDef(
      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, "/device:CPU:0"),
       NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, "/device:CPU:0"),
       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, "/device:CPU:0")},
      // FunctionLib
      {
          func,
          test::function::XTimesTwoInt32(),
          test::function::XTimes16(),
      });
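  // XTimesTwoInt32 and XTimes16 are never referenced from the graph, so the
  // optimizer should prune them from the library; XTimesTwo itself is
  // replaced by its specialization.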
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  ASSERT_EQ(output.library().function().size(), 1);
  EXPECT_EQ(output.library().function(0).signature().name(),
            "XTimesTwo_specialized_for_y_at_test_graph");
}

}  // namespace grappler
}  // namespace tensorflow