/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"

#include <atomic>

#include "absl/strings/match.h"
#include "absl/strings/substitute.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kDevice[] = "/device:CPU:0";

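// Trivial custom optimizer that records, in a static flag, whether its
// Optimize() method was invoked by the meta optimizer.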
class TestOptimizer : public CustomGraphOptimizer {
 public:
  static void SetOptimized(const bool flag_value) { optimized_ = flag_value; }
  static bool IsOptimized() { return optimized_; }

  TestOptimizer() {}
  string name() const override { return "test_optimizer"; }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
                  nullptr) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    optimized_ = true;
    *optimized_graph = item.graph;
    return OkStatus();
  }

 private:
  static bool optimized_;
};

bool TestOptimizer::optimized_;

REGISTER_GRAPH_OPTIMIZER(TestOptimizer);

class TestGraphOptimizer : public TestOptimizer {
 public:
  string name() const override { return "test_graph_optimizer"; }
};

REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);

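// A variant of TestOptimizer whose Init() requires a custom optimizer config
// to be passed in by the meta optimizer.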
class TestOptimizerWithParams : public TestOptimizer {
 public:
  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    CHECK(config != nullptr);
    return OkStatus();
  }
};

REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);

// Record various properties of the GrapplerItems passed for optimization.
class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
 public:
  static void SetOptimizationOptions(
      gtl::FlatMap<string, GrapplerItem::OptimizationOptions>*
          optimization_options) {
    optimization_options_ = optimization_options;
  }
  static void ResetOptimizationOptions() { optimization_options_ = nullptr; }

  GrapplerItemPropertiesAccumulator() {}
  string name() const override {
    return "grappler_item_properties_accumulator";
  }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    *optimized_graph = item.graph;
    if (optimization_options_) {
      optimization_options_->insert({item.id, item.optimization_options()});
    }
    return OkStatus();
  }

 private:
  static gtl::FlatMap<string, GrapplerItem::OptimizationOptions>*
      optimization_options_;
};

gtl::FlatMap<string, GrapplerItem::OptimizationOptions>*
    GrapplerItemPropertiesAccumulator::optimization_options_;

REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);

class MetaOptimizerTest : public GrapplerTest {};

TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("TestOptimizer");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("TestOptimizerWithParams");
  auto* custom_config = rewriter_config.add_custom_optimizers();
  custom_config->set_name("TestOptimizerWithParams");
  (*custom_config->mutable_parameter_map())["foo"] = AttrValue();

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  TestGraphOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("TestOptimizer");
  auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
  customGraphOptimizer->set_name("TestGraphOptimizer");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
  EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunsPluginOptimizer) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"/device:GPU:0"});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.set_min_graph_nodes(-1);

  const auto creator = []() { return new TestOptimizer; };
  ConfigList config_list;
  config_list.disable_model_pruning = true;
  PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(creator, "GPU",
                                                             config_list);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
}

TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
  customGraphOptimizer->set_name("TestGraphOptimizer");
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
  using test::function::NDef;

  // Enable only function optimization.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_function_optimization(RewriterConfig::ON);
  rewriter_config.add_optimizers("function");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);

  // Define function library:
  //
  //   MyMul(x, y) = x * y
  //  *MySquare(x) = MyMul(x, x)
  //  *MyQuadratic(x) = MySquare(MySquare(x))
  //
  //  * - marked as noinline

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  FunctionDef square_func = FunctionDefHelper::Create(
      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z", "my_mul:z:0"}});
  (*square_func.mutable_attr())["_noinline"].set_b(true);

  FunctionDef quadratic_func = FunctionDefHelper::Create(
      "MyQuadratic", {"x:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"square"}, "MySquare", {"x"}, {{"T", "$T"}}},
       {{"quadratic"}, "MySquare", {"square:z"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z", "quadratic:z:0"}});
  (*quadratic_func.mutable_attr())["_noinline"].set_b(true);

  // Tensorflow graph:
  //
  //   a = tf.Placeholder(tf.float);
  //   b = tf.Placeholder(tf.int32);
  //
  //   square = MySquare(a);        // a^2
  //   quadratic = MyQuadratic(b);  // b^4
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       // Calls into function library
       NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("quadratic", "MyQuadratic", {"b"}, {{"T", DT_INT32}}, kDevice),
       // Forward outputs
       NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out_q", "Identity", {"quadratic:0"}, {{"T", DT_INT32}}, kDevice)},
      /*funcs=*/
      {mul_func, square_func, quadratic_func});

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
                                           output.library());

  // Specialized and optimized functions should be added to the graph.
  EXPECT_EQ(3, optimized_flib.num_functions());

  // Get a specialized function name.
  const auto specialized_name = [](const string& fn, const string& node,
                                   const string& id) {
    return absl::Substitute("$0_specialized_for_$1_at_$2", fn, node, id);
  };
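  // E.g. specialized_name("MySquare", "square", "tf_graph") evaluates to
  // "MySquare_specialized_for_square_at_tf_graph".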

  // MyQuadratic should be specialized once:
  //   0. 'quadratic' node in the main graph
  const string optimized_0 =
      specialized_name("MyQuadratic", "quadratic", "tf_graph");

  // MySquare should be specialized and optimized for 3 instantiations:
  //   1.  'square' node in the main graph
  //   2.  'square' node in the MyQuadratic specialization
  //   3*. 'quadratic' node in the MyQuadratic specialization
  //       has identical instantiation context to #2

  const string optimized_1 = specialized_name("MySquare", "square", "tf_graph");
  const string optimized_2 =
      specialized_name("MySquare", "square", optimized_0);

  const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
  const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
  const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);

  ASSERT_NE(optimized_func_0, nullptr);
  ASSERT_NE(optimized_func_1, nullptr);
  ASSERT_NE(optimized_func_2, nullptr);

  // Graph should call optimized function.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "square" && ++count) {
      EXPECT_EQ(optimized_1, node.op());
    } else if (node.name() == "quadratic" && ++count) {
      EXPECT_EQ(optimized_0, node.op());
    }
  }
  EXPECT_EQ(2, count);

  // Specialized MySquare should call specialized functions.
  count = 0;
  for (const NodeDef& node : optimized_func_0->node_def()) {
    if (node.name() == "square" && ++count) {
      EXPECT_EQ(optimized_2, node.op());
    } else if (node.name() == "quadratic" && ++count) {
      EXPECT_EQ(optimized_2, node.op());
    }
  }
  EXPECT_EQ(2, count);

  const std::vector<const FunctionDef*> optimized_funcs = {optimized_func_1,
                                                           optimized_func_2};

  // MyMul should be inlined into all optimized versions of MySquare.
  for (const FunctionDef* optimized_func : optimized_funcs) {
    count = 0;
    for (const NodeDef& node : optimized_func->node_def()) {
      if (node.name() == "Func/my_mul/input/_0" && ++count) {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ(1, node.input_size());
        EXPECT_EQ("x", node.input(0));
      } else if (node.name() == "Func/my_mul/input/_1" && ++count) {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ(1, node.input_size());
        EXPECT_EQ("x", node.input(0));
      } else if (node.name() == "my_mul/mul" && ++count) {
        EXPECT_EQ("Mul", node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("Func/my_mul/input/_0:output:0", node.input(0));
        EXPECT_EQ("Func/my_mul/input/_1:output:0", node.input(1));
      }
      EXPECT_TRUE(node.device().empty());
    }
    EXPECT_EQ(3, count);
    ASSERT_EQ(1, optimized_func->ret().size());
    EXPECT_EQ("Func/my_mul/output/_2:output:0", optimized_func->ret().at("z"));
  }

  item.fetch = {"out_s", "out_q"};
  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  item.feed.emplace_back("b", test::AsScalar<int>(4));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryPruneUnusedOutputs) {
  using test::function::NDef;

  ConfigProto config_proto;
  MetaOptimizer optimizer(nullptr, config_proto);

  // MyMul computes x*y three times and has three output values.
  FunctionDef my_mul = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z0:T", "z1:T", "z2:T"}, {"T: {float, int32}"},
      {{{"output0"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output2"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z0", "output0:z:0"}, {"z1", "output1:z:0"}, {"z2", "output2:z:0"}});

  // Call MyMul and forward all three outputs.
  FunctionDef my_fwd = FunctionDefHelper::Create(
      "Fwd", {"x:T", "y:T"}, {"z0:T", "z1:T", "z2:T"}, {"T: {float, int32}"},
      {{{"output"}, "MyMul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z0", "output:z0:0"}, {"z1", "output:z1:0"}, {"z2", "output:z2:0"}});

  // Mark both functions as `_noinline` to trigger specialization.
  (*my_mul.mutable_attr())["_noinline"].set_b(true);
  (*my_fwd.mutable_attr())["_noinline"].set_b(true);
  /*funcs=*/
  std::vector<FunctionDef> function_library = {my_mul, my_fwd};

  // Tensorflow graph:
  //   a = Placeholder[T=float]
  //   b = Placeholder[T=float]
  //   fwd = Fwd(a, b)
  //
  // Fetch fwd:2 via Identity node.
  GrapplerItem item;
  item.id = "tf_graph";
  item.fetch = {"ret"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("fwd", "Fwd", {"a", "b"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("ret", "Identity", {"fwd:2"}, {{"T", DT_FLOAT}}, kDevice)},
      function_library);

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
                                           output.library());

  // Specialized functions should be added to the graph.
  EXPECT_EQ(2, optimized_flib.num_functions());

  // Expected names of the specialized functions.
  const string specialized_my_fwd = "Fwd_specialized_for_fwd_at_tf_graph";
  const string specialized_my_mul =
      absl::StrCat("MyMul_specialized_for_output_at_", specialized_my_fwd);
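  // I.e. "MyMul_specialized_for_output_at_Fwd_specialized_for_fwd_at_tf_graph".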

  // Specialized MyMul should have just one output argument.
  FunctionDef expected_my_mul = FunctionDefHelper::Create(
      specialized_my_mul, {"x:float", "y:float"}, {"z2:float"}, {},
      {{{"output2"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z2", "output2:z:0"}});

  // Specialized Fwd should also have just one output argument.
  FunctionDef expected_my_fwd = FunctionDefHelper::Create(
      specialized_my_fwd, {"x:float", "y:float"}, {"z2:float"}, {},
      {{{"output"}, specialized_my_mul, {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z2", "output:z2:0"}});

  const FunctionDef* my_mul_spec = optimized_flib.Find(specialized_my_mul);
  const FunctionDef* my_fwd_spec = optimized_flib.Find(specialized_my_fwd);

  ASSERT_NE(my_mul_spec, nullptr);
  ASSERT_NE(my_fwd_spec, nullptr);

  CompareFunctions(expected_my_mul, *my_mul_spec);
  CompareFunctions(expected_my_fwd, *my_fwd_spec);

  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  item.feed.emplace_back("b", test::AsScalar<float>(4.0f));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryPruneFunctionBody) {
  using test::function::NDef;

  // Enable function optimization and pruning.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_function_optimization(RewriterConfig::ON);
  rewriter_config.add_optimizers("function");
  rewriter_config.add_optimizers("pruning");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);

  // MyFunc defines two Mul nodes inside function body and two corresponding
  // function outputs.
  FunctionDef my_func = FunctionDefHelper::Create(
      "MyFunc", {"x:T", "y:T"}, {"z1:T", "z2:T"}, {"T: {float, double}"},
      {{{"mul1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"mul2"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z1", "mul1:z:0"}, {"z2", "mul2:z:0"}});
  (*my_func.mutable_attr())["_noinline"].set_b(true);

  // Tensorflow graph:
  //
  //   a = tf.Placeholder(tf.float);
  //   b = tf.Placeholder(tf.float);
  //
  //   fn1 = MyFunc(a, b);
  //   fn2 = MyFunc(a, b);
  //
  // Fetch: fn1:0 and fn2:1 via Identity nodes.
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("fn1", "MyFunc", {"a", "b"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("fn2", "MyFunc", {"a", "b"}, {{"T", DT_FLOAT}}, kDevice),
       // Read outputs of function call nodes
       NDef("out_fn1", "Identity", {"fn1:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out_fn2", "Identity", {"fn2:1"}, {{"T", DT_FLOAT}}, kDevice)},
      /*funcs=*/
      {my_func});

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
                                           output.library());

  // Specialized and optimized functions should be added to the graph.
  EXPECT_EQ(2, optimized_flib.num_functions());

  // Expected names of the specialized and optimized functions.
  const string optimized_fn1 = "MyFunc_specialized_for_fn1_at_tf_graph";
  const string optimized_fn2 = "MyFunc_specialized_for_fn2_at_tf_graph";

  const FunctionDef* optimized_func_fn1 = optimized_flib.Find(optimized_fn1);
  const FunctionDef* optimized_func_fn2 = optimized_flib.Find(optimized_fn2);

  ASSERT_NE(optimized_func_fn1, nullptr);
  ASSERT_NE(optimized_func_fn2, nullptr);

  // Graph should call optimized function.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "fn1" && ++count) {
      EXPECT_EQ(optimized_fn1, node.op());
    } else if (node.name() == "fn2" && ++count) {
      EXPECT_EQ(optimized_fn2, node.op());
    }
  }
  EXPECT_EQ(2, count);

  // Specialized MyFuncs should have just one Mul node and single output arg.

  // 1. Specialized for fn1:0.
  ASSERT_EQ(1, optimized_func_fn1->node_def_size());
  EXPECT_EQ(1, optimized_func_fn1->signature().output_arg_size());
  EXPECT_EQ("z1", optimized_func_fn1->signature().output_arg(0).name());
  EXPECT_EQ("mul1", optimized_func_fn1->node_def(0).name());

  // 2. Specialized for fn2:1.
  ASSERT_EQ(1, optimized_func_fn2->node_def_size());
  EXPECT_EQ(1, optimized_func_fn2->signature().output_arg_size());
  EXPECT_EQ("z2", optimized_func_fn2->signature().output_arg(0).name());
  EXPECT_EQ("mul2", optimized_func_fn2->node_def(0).name());

  // Verify that output tensors are equal.
  item.fetch = {"out_fn1", "out_fn2"};
  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  item.feed.emplace_back("b", test::AsScalar<float>(3.123f));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  // We will record what type of optimizations meta optimizer allows for each
  // GrapplerItem (main graph and graphs for each function).
  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Just record properties of optimized Grappler items.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);

  // Define simple function library with two identical mul functions.
  FunctionDef mul_func_1 = FunctionDefHelper::Create(
      "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  FunctionDef mul_func_2 = FunctionDefHelper::Create(
      "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   x0 = tf.Placeholder(tf.float);
  //   x1 = tf.Placeholder(tf.float);
  //   dy = tf.Placeholder(tf.float);
  //
  //   mul_1 = MyMul1(x0, x1);
  //   mul_2 = MyMul2(x0, x1);
  //   dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
  GrapplerItem item;
  item.id = "main";
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
       NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
       // Symbolic gradient of a MyMul2
       NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
            {{"f", FDH::FunctionRef("MyMul2", {})},
             {"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
            kDevice)},
      /*funcs=*/
      {mul_func_1, mul_func_2});
  item.fetch = {"mul_1", "mul_2", "dx"};

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Our custom optimizer must be called for the main graph and for the two
  // functions.
  ASSERT_EQ(optimization_options.size(), 3);

  auto optimization_options_main =
      gtl::FindOrNull(optimization_options, "main");
  ASSERT_NE(optimization_options_main, nullptr);
  EXPECT_TRUE(optimization_options_main->allow_non_differentiable_rewrites);

  auto optimization_options_my_mul_1 =
      gtl::FindOrNull(optimization_options, "MyMul1");
  ASSERT_NE(optimization_options_my_mul_1, nullptr);
  EXPECT_TRUE(optimization_options_my_mul_1->allow_non_differentiable_rewrites);

  auto optimization_options_my_mul_2 =
      gtl::FindOrNull(optimization_options, "MyMul2");
  ASSERT_NE(optimization_options_my_mul_2, nullptr);
  EXPECT_FALSE(
      optimization_options_my_mul_2->allow_non_differentiable_rewrites);
}

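// An optimizer that sleeps for one second on every call and then, unless the
// meta optimizer deadline has already expired, adds a single node to the
// graph.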
class SleepingOptimizer : public CustomGraphOptimizer {
 public:
  SleepingOptimizer() {}
  string name() const override { return "test_optimizer"; }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    *optimized_graph = item.graph;
    Env::Default()->SleepForMicroseconds(1000000);
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    optimized_graph->add_node();
    return OkStatus();
  }
};

REGISTER_GRAPH_OPTIMIZER(SleepingOptimizer);

TEST_F(MetaOptimizerTest, OptimizerTimesOut) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config;
  RewriterConfig& rewriter_config =
      *config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("SleepingOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_timeout_ms(500);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);

  GraphDef output;
  GraphDef original = item.graph;
  const Status status =
      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
  EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
  // Make sure the graph was reverted to the original regardless of when the
  // optimizer timed out.
  CompareGraphs(original, output);
}

TEST_F(MetaOptimizerTest, MetaOptimizerTimesOut) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config;
  RewriterConfig& rewriter_config =
      *config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("SleepingOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_timeout_ms(1500);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);

  GraphDef output;
  const int original_node_size = item.graph.node_size();
  const Status status =
      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
  EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
  // The meta optimizer should manage to finish one iteration.
  EXPECT_EQ(original_node_size + 1, output.node_size());
}

TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config;
  RewriterConfig& rewriter_config =
      *config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("SleepingOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_timeout_ms(2500);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  GraphDef output;
  const int original_node_size = item.graph.node_size();
  const Status status =
      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
  TF_EXPECT_OK(status);
  // The meta optimizer should manage to finish two iterations.
  EXPECT_EQ(original_node_size + 2, output.node_size());
}

TEST_F(MetaOptimizerTest, RunPostOptimizationVerifiersOnValidGraph) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& post_optimization_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_post_optimization_verifier_config();
  post_optimization_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
}

TEST_F(MetaOptimizerTest, RunInterOptimizerVerifiersOnValidGraph) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& inter_optimizer_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_inter_optimizer_verifier_config();
  inter_optimizer_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
}

TEST_F(MetaOptimizerTest, RunPostOptimizationVerifiersOnInvalidGraph) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define simple function library with two identical mul functions.
  FunctionDef mul_func_1 =
      FunctionDefHelper::Create("MyMul1", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  FunctionDef mul_func_2 =
      FunctionDefHelper::Create("MyMul2", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   x0 = tf.Placeholder(tf.float);
  //   x1 = tf.Placeholder(tf.float);
  //   dy = tf.Placeholder(tf.float);
  //
  //   mul_1 = MyMul1(x0, x1);
  //   mul_2 = MyMul2(x0, x1);
  //   dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
  GrapplerItem item;
  item.id = "main";
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
       NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
       // Symbolic gradient of a MyMul2
       NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
            {{"f", FDH::FunctionRef("MyMul2", {})},
             {"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
            kDevice)},
      /*funcs=*/
      {mul_func_1, mul_func_2});
  item.fetch = {"mul_1", "mul_2", "dx"};

  GraphDef output;

  // Call Optimize with post optimization verifiers.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
  rewriter_config.set_min_graph_nodes(-1);
  auto& post_optimization_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_post_optimization_verifier_config();
  post_optimization_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer_with_post_verifiers(nullptr, config_proto);
  Status status =
      optimizer_with_post_verifiers.Optimize(nullptr, item, &output);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "NodeDef expected inputs 'float' do not match 3 inputs specified"));
}

TEST_F(MetaOptimizerTest, RunInterOptimizerVerifiersOnInvalidGraph) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define simple function library with two identical mul functions.
  FunctionDef mul_func_1 =
      FunctionDefHelper::Create("MyMul1", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  FunctionDef mul_func_2 =
      FunctionDefHelper::Create("MyMul2", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   x0 = tf.Placeholder(tf.float);
  //   x1 = tf.Placeholder(tf.float);
  //   dy = tf.Placeholder(tf.float);
  //
  //   mul_1 = MyMul1(x0, x1);
  //   mul_2 = MyMul2(x0, x1);
  //   dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
  GrapplerItem item;
  item.id = "main";
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
       NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
       // Symbolic gradient of a MyMul2
       NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
            {{"f", FDH::FunctionRef("MyMul2", {})},
             {"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
            kDevice)},
      /*funcs=*/
      {mul_func_1, mul_func_2});
  item.fetch = {"mul_1", "mul_2", "dx"};

  GraphDef output;

  // Call Optimize with inter optimizer verifiers.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
  rewriter_config.set_min_graph_nodes(-1);
  auto& inter_optimizer_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_inter_optimizer_verifier_config();
  inter_optimizer_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer_with_inter_verifiers(nullptr, config_proto);
  Status status =
      optimizer_with_inter_verifiers.Optimize(nullptr, item, &output);
  EXPECT_EQ(status.code(), errors::Code::INVALID_ARGUMENT);
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "NodeDef expected inputs 'float' do not match 3 inputs specified"));
}

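// The meta optimizer should store the all-zeros constant with an empty value
// list and the all-ones constant as a single splat element, while preserving
// the HostConst op type.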
TEST_F(MetaOptimizerTest, CompressConstants) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Tensor zeros_t(DT_FLOAT, TensorShape({64}));
  Tensor ones_t(DT_FLOAT, TensorShape({64}));
  for (int i = 0; i < 64; ++i) {
    zeros_t.flat<float>()(i) = 0.0f;
    ones_t.flat<float>()(i) = 1.0f;
  }
  Output zeros = ops::Const(scope.WithOpName("zeros"), zeros_t);
  Output host_ones = ops::Const(scope.WithOpName("host_ones"), ones_t);
  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  ASSERT_EQ(item.graph.node(1).name(), "host_ones");
  // There is no C++ API for HostConst, so we manually change the node type
  // here.
  item.graph.mutable_node(1)->set_op("HostConst");
  item.fetch = {"zeros", "host_ones"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.set_min_graph_nodes(-1);
  MetaOptimizer optimizer(/*cpu_device=*/nullptr, config_proto);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));

  bool found_zeros = false;
  bool found_host_ones = false;
  ASSERT_EQ(output.node_size(), 2);
  for (const auto& node : output.node()) {
    if (node.name() == "zeros") {
      found_zeros = true;
      EXPECT_EQ(node.op(), "Const");
      const TensorProto& zeroes_t = node.attr().at("value").tensor();
      EXPECT_EQ(zeroes_t.float_val_size(), 0);
    } else if (node.name() == "host_ones") {
      found_host_ones = true;
      EXPECT_EQ(node.op(), "HostConst");
      const TensorProto& ones_t = node.attr().at("value").tensor();
      EXPECT_EQ(ones_t.float_val_size(), 1);
      EXPECT_EQ(ones_t.float_val(0), 1.0f);
    }
  }

  EXPECT_TRUE(found_zeros);
  EXPECT_TRUE(found_host_ones);

  auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  for (int i = 0; i < 2; ++i) {
    test::ExpectTensorEqual<float>(tensors[i], tensors_expected[i]);
  }
}

TEST_F(MetaOptimizerTest, TestTFGRemoveDeadArguments) {
  using test::function::NDef;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define a simple function library with one branch function.
  // def branch_func(x, y):
  //   z = tf.Mul(x, x)
  //   return z
  FunctionDef case_func = FunctionDefHelper::Create(
      "branch_func", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   idx = tf.Placeholder(tf.int32);
  //   x = tf.Placeholder(tf.float);
  //   y = tf.Placeholder(tf.float);
  //
  //   case = tf.Case(idx, x, y, branches=[branch_func])
  GrapplerItem item;
  item.id = "main";

  AttrValue branches;
  branches.mutable_list()->add_func()->set_name("branch_func");
  AttrValue output_shapes;
  output_shapes.mutable_list()->add_shape();
  item.graph = test::function::GDef(
      {NDef("idx", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("case", "Case", {"idx", "x", "y"},
            {{"branches", std::move(branches)},
             {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"output_shapes", std::move(output_shapes)}},
            kDevice)},
      /*funcs=*/
      {case_func});
  item.fetch = {"case"};

  GraphDef output;
  ConfigProto config_proto;

  MetaOptimizer optimizer(nullptr, config_proto);
  Status status = optimizer.Optimize(nullptr, item, &output);
  EXPECT_TRUE(status.ok());
  EXPECT_EQ(output.library().function_size(), 1);
  // One of the arguments was removed.
  auto& func = output.library().function(0);
  EXPECT_EQ(func.signature().input_arg_size(), 1);
  EXPECT_EQ(func.signature().input_arg(0).name(), "x_tfg_result_0");
}

TEST_F(MetaOptimizerTest, TestTFGControlFlowSink) {
  using test::function::NDef;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define a branch function.
  // def branch_func(x, y):
  //   z = tf.Mul(x, y)
  //   return z
  FunctionDef case_func = FunctionDefHelper::Create(
      "branch_func", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  // Define a function with a control-flow op.
  // def Foo(idx, a, b):
  //   x_foo = Add(a, b)
  //   y_foo = Mul(a, b)
  //   case = Case(idx, x_foo, y_foo, branches=[branch_func])
  //   return case
  AttrValue branches;
  branches.mutable_list()->add_func()->set_name("branch_func");
  AttrValue output_shapes;
  output_shapes.mutable_list()->add_shape();
  FunctionDef foo_func = FunctionDefHelper::Create(
      "Foo", {"idx:int32", "a:float", "b:float"}, {"c:float"}, {},
      {{{"add"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
       {{"mul"}, "Mul", {"a", "b"}, {{"T", DT_FLOAT}}},
       {{"case"},
        "Case",
        {"idx", "add:z:0", "mul:z:0"},
        {{"branches", std::move(branches)},
         {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
         {"Tout", DataTypeSlice{DT_FLOAT}},
         {"output_shapes", std::move(output_shapes)}}}},
      /*ret_def=*/
      {{"c", "case:output:0"}});
  (*foo_func.mutable_attr())["_noinline"].set_b(true);

  // Tensorflow graph:
  //
  //   idx = tf.Placeholder(tf.int32);
  //   a = tf.Placeholder(tf.float);
  //   b = tf.Placeholder(tf.float);
  //
  //   foo_val = Foo(idx, a, b)
  GrapplerItem item;
  item.id = "main";

  item.graph = test::function::GDef(
      {NDef("idx", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("foo", "Foo", {"idx", "a", "b"}, {}, kDevice)},
      /*funcs=*/
      {case_func, foo_func});
  item.fetch = {"foo"};

  GraphDef output;
  ConfigProto config_proto;

  MetaOptimizer optimizer(nullptr, config_proto);
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(output.library().function_size(), 2);

  const FunctionDef* optimized_foo_func = nullptr;
  const FunctionDef* specialized_branch_func = nullptr;
  for (const FunctionDef& func : output.library().function()) {
    if (func.signature().name() == "Foo")
      optimized_foo_func = &func;
    else if (absl::StartsWith(func.signature().name(), "branch_func"))
      specialized_branch_func = &func;
  }
  ASSERT_TRUE(optimized_foo_func);
  EXPECT_EQ(optimized_foo_func->node_def_size(), 1);
  ASSERT_TRUE(specialized_branch_func);
  EXPECT_EQ(specialized_branch_func->node_def_size(), 3);
}

// Tests for checking expected behavior when skipping tf.data functions in
// meta optimizer.

// Custom optimizer which counts the number of calls of its method `Optimize`
// across all class instances.
class TfDataTestOptimizer : public CustomGraphOptimizer {
 public:
  static void InitCount() { count_ = 0; }
  static int GetCount() { return count_; }

  TfDataTestOptimizer() = default;
  ~TfDataTestOptimizer() override = default;
  TfDataTestOptimizer(const TfDataTestOptimizer&) = delete;
  TfDataTestOptimizer& operator=(const TfDataTestOptimizer& other) = delete;

  std::string name() const override { return "tf_data_test_optimizer"; }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    ++count_;
    *optimized_graph = item.graph;
    return OkStatus();
  }

 private:
  static std::atomic<int> count_;
};

std::atomic<int> TfDataTestOptimizer::count_;

REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer);

// Type for specifying how the inner function is nested inside the outer
// function.
enum class FuncNestingType {
  CallFromNode = 0,
  CallFromAttr = 1,
  CallFromList = 2
};

// Test fixture for parametrized testing.
class TfDataTestFixture
    : public ::testing::TestWithParam<std::tuple<bool, bool, FuncNestingType>> {
 protected:
  void SetUp() override {
    is_inner_func_tf_data_ = std::get<0>(GetParam());
    is_outer_func_tf_data_ = std::get<1>(GetParam());
    func_nesting_type_ = std::get<2>(GetParam());
  }
  // Controls which of the functions is flagged as tf.data function.
  bool is_inner_func_tf_data_ = false;
  bool is_outer_func_tf_data_ = false;
  // Controls how the inner function is nested inside the outer function.
  FuncNestingType func_nesting_type_ = FuncNestingType::CallFromNode;
};

// Helper functions for setting up the call of `inner_func` inside of
// `outer_func`.

void SetUpCallFromNode(FunctionDef& outer_func) {
  // Call `inner_func` from a node in `outer_func`.
  outer_func = FunctionDefHelper::Create(
      "outer_func", {"x:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"inner_func"}, "inner_func", {"x", "x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "inner_func:z:0"}});
}

void SetUpCallFromAttr(FunctionDef& outer_func) {
  // Call `inner_func` from an attribute in a node in `outer_func`.
  outer_func = FunctionDefHelper::Create(
      "outer_func", {"x:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"identity"},
        "Identity",
        {"x"},
        {{"T", DT_FLOAT},
         {"f", FunctionDefHelper::FunctionRef("inner_func", {})}}}},
      /*ret_def=*/
      {{"z", "x"}});
}

void SetUpCallFromList(FunctionDef& outer_func) {
  // Call `inner_func` from a list attribute in a node in `outer_func`.
  outer_func = FunctionDefHelper::Create(
      "outer_func", {"x:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"identity"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "x"}});

  // Add a list containing `inner_func` to the `identity` node.
  // `list_value` will be deallocated automatically since it is passed as
  // allocated list below.
  AttrValue_ListValue* list_value =
      (*outer_func.mutable_node_def(0)->mutable_attr())["list"].mutable_list();
  NameAttrList* entry = list_value->add_func();
  entry->set_name("inner_func");
}

TEST_P(TfDataTestFixture, TfDataTests) {
  using test::function::NDef;

  // Define function library with `outer_func` and `inner_func`.

  FunctionDef inner_func = FunctionDefHelper::Create(
      "inner_func", {"x:float", "y:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});
  (*inner_func.mutable_attr())[data::kTFDataFunction].set_b(
      is_inner_func_tf_data_);

  FunctionDef outer_func;
  switch (func_nesting_type_) {
    case FuncNestingType::CallFromNode:
      SetUpCallFromNode(outer_func);
      break;
    case FuncNestingType::CallFromAttr:
      SetUpCallFromAttr(outer_func);
      break;
    case FuncNestingType::CallFromList:
      SetUpCallFromList(outer_func);
      break;
    default:
      break;
  }
  (*outer_func.mutable_attr())[data::kTFDataFunction].set_b(
      is_outer_func_tf_data_);

  // Tensorflow graph:
  //
  //   a = tf.Placeholder(tf.float);
  //   result = outer_func(a);
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("outer_func_node", "outer_func", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       // Forward outputs
       NDef("out_s", "Identity", {"outer_func_node:0"}, {{"T", DT_FLOAT}},
            kDevice)},
      /*funcs=*/
      {inner_func, outer_func});

  // Use only custom optimizer which counts its calls.
  TfDataTestOptimizer::InitCount();
  ConfigProto config_proto;
  auto& rewriter_config =
      *(config_proto.mutable_graph_options()->mutable_rewrite_options());
  rewriter_config.add_optimizers("TfDataTestOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect one graph optimization + one optimization for each non-tf.data
  // function. Note that if `outer_func` is flagged as a tf.data function, then
  // `inner_func` is implicitly also considered a tf.data function because it
  // is called from `outer_func`.
  int expected_count = 3;
  if (is_outer_func_tf_data_)
    expected_count = 1;
  else if (is_inner_func_tf_data_)
    expected_count = 2;
  EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count);

  // We expect that the tf.data attribute has been propagated from `outer_func`
  // to its callee `inner_func` if the value is `true`. Otherwise, the
  // attribute values should be unchanged.
  FunctionLibraryDefinition flib(OpRegistry::Global(), output.library());
  const FunctionDef* outer_func_after_opt = flib.Find("outer_func");
  const FunctionDef* inner_func_after_opt = flib.Find("inner_func");

  EXPECT_EQ(data::IsTFDataFunction(*outer_func_after_opt),
            is_outer_func_tf_data_);
  if (is_outer_func_tf_data_ || is_inner_func_tf_data_) {
    EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), true);
  } else {
    EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), false);
  }
}

INSTANTIATE_TEST_SUITE_P(
    MetaOptimizerTest, TfDataTestFixture,
    ::testing::Combine(::testing::Bool(), ::testing::Bool(),
                       ::testing::Values(FuncNestingType::CallFromNode,
                                         FuncNestingType::CallFromAttr,
                                         FuncNestingType::CallFromList)),
    [](const ::testing::TestParamInfo<TfDataTestFixture::ParamType>& info) {
      bool is_inner_func_tf_data = std::get<0>(info.param);
      bool is_outer_func_tf_data = std::get<1>(info.param);
      FuncNestingType func_nesting_type = std::get<2>(info.param);

      std::string test_name;
      if (is_inner_func_tf_data && is_outer_func_tf_data)
        test_name = "both_funcs_tf_data";
      else if (is_inner_func_tf_data)
        test_name = "inner_func_tf_data";
      else if (is_outer_func_tf_data)
        test_name = "outer_func_tf_data";
      else
        test_name = "no_func_tf_data";
      switch (func_nesting_type) {
        case FuncNestingType::CallFromNode:
          test_name += "_call_from_node";
          break;
        case FuncNestingType::CallFromAttr:
          test_name += "_call_from_attribute";
          break;
        case FuncNestingType::CallFromList:
          test_name += "_call_from_list";
          break;
        default:
          break;
      }
      return test_name;
    });

}  // namespace
}  // namespace grappler
}  // namespace tensorflow