/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"

#include <atomic>

#include "absl/strings/match.h"
#include "absl/strings/substitute.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kDevice[] = "/device:CPU:0";

class TestOptimizer : public CustomGraphOptimizer {
 public:
  static void SetOptimized(const bool flag_value) { optimized_ = flag_value; }
  static bool IsOptimized() { return optimized_; }

  TestOptimizer() {}
  string name() const override { return "test_optimizer"; }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
                  nullptr) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    optimized_ = true;
    *optimized_graph = item.graph;
    return OkStatus();
  }

 private:
  static bool optimized_;
};

bool TestOptimizer::optimized_;

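// REGISTER_GRAPH_OPTIMIZER registers the optimizer with the custom graph
// optimizer registry under its class name, so the tests below can enable it
// through RewriterConfig by the string "TestOptimizer".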
REGISTER_GRAPH_OPTIMIZER(TestOptimizer);

class TestGraphOptimizer : public TestOptimizer {
 public:
  string name() const override { return "test_graph_optimizer"; }
};

REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);

class TestOptimizerWithParams : public TestOptimizer {
 public:
  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    CHECK(config != nullptr);
    return OkStatus();
  }
};

REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);

// Record various properties of the GrapplerItems passed for optimization.
class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
 public:
  static void SetOptimizationOptions(
      gtl::FlatMap<string, GrapplerItem::OptimizationOptions>*
          optimization_options) {
    optimization_options_ = optimization_options;
  }
  static void ResetOptimizationOptions() { optimization_options_ = nullptr; }

  GrapplerItemPropertiesAccumulator() {}
  string name() const override {
    return "grappler_item_properties_accumulator";
  }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    *optimized_graph = item.graph;
    if (optimization_options_) {
      optimization_options_->insert({item.id, item.optimization_options()});
    }
    return OkStatus();
  }

 private:
  static gtl::FlatMap<string, GrapplerItem::OptimizationOptions>*
      optimization_options_;
};

gtl::FlatMap<string, GrapplerItem::OptimizationOptions>*
    GrapplerItemPropertiesAccumulator::optimization_options_;

REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);

class MetaOptimizerTest : public GrapplerTest {};

TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("TestOptimizer");
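  // min_graph_nodes = -1 disables the minimum-graph-size check, so even this
  // tiny test graph gets optimized.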
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("TestOptimizerWithParams");
  auto* custom_config = rewriter_config.add_custom_optimizers();
  custom_config->set_name("TestOptimizerWithParams");
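  // Registering a custom optimizer config (with a dummy parameter) ensures
  // Init() receives a non-null config, which TestOptimizerWithParams CHECKs.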
  (*custom_config->mutable_parameter_map())["foo"] = AttrValue();

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  TestGraphOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("TestOptimizer");
  auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
  customGraphOptimizer->set_name("TestGraphOptimizer");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
  EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunsPluginOptimizer) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"/device:GPU:0"});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  TestOptimizer::SetOptimized(false);
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.set_min_graph_nodes(-1);

  const auto creator = []() { return new TestOptimizer; };
  ConfigList config_list;
  config_list.disable_model_pruning = true;
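  // Plugin optimizers are registered per device type; the meta optimizer
  // picks them up for items whose nodes are placed on that device type
  // ("GPU" here, matching the fake input above).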
  PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(creator, "GPU",
                                                             config_list);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
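  // RewriterConfig::TWO makes the meta optimizer run its full pass pipeline
  // for two iterations.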
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
}

TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
  customGraphOptimizer->set_name("TestGraphOptimizer");
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
  using test::function::NDef;

  // Enable only function optimization.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_function_optimization(RewriterConfig::ON);
  rewriter_config.add_optimizers("function");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);

  // Define function library:
  //
  //   MyMul(x, y)    = x * y
  //  *MySquare(x)    = MyMul(x, x)
  //  *MyQuadratic(x) = MySquare(MySquare(x))
  //
  //  * - marked as noinline

  FunctionDef mul_func = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  FunctionDef square_func = FunctionDefHelper::Create(
      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z", "my_mul:z:0"}});
  (*square_func.mutable_attr())["_noinline"].set_b(true);

  FunctionDef quadratic_func = FunctionDefHelper::Create(
      "MyQuadratic", {"x:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"square"}, "MySquare", {"x"}, {{"T", "$T"}}},
       {{"quadratic"}, "MySquare", {"square:z"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z", "quadratic:z:0"}});
  (*quadratic_func.mutable_attr())["_noinline"].set_b(true);

  // Tensorflow graph:
  //
  //   a = tf.Placeholder(tf.float);
  //   b = tf.Placeholder(tf.int32);
  //
  //   square = MySquare(a);        // a^2
  //   quadratic = MyQuadratic(b);  // b^4
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       // Calls into function library
       NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("quadratic", "MyQuadratic", {"b"}, {{"T", DT_INT32}}, kDevice),
       // Forward outputs
       NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out_q", "Identity", {"quadratic:0"}, {{"T", DT_INT32}}, kDevice)},
      /*funcs=*/
      {mul_func, square_func, quadratic_func});

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
                                           output.library());

  // Specialized and optimized functions should be added to the graph.
  EXPECT_EQ(3, optimized_flib.num_functions());

  // Get a specialized function name.
  const auto specialized_name = [](const string& fn, const string& node,
                                   const string& id) {
    return absl::Substitute("$0_specialized_for_$1_at_$2", fn, node, id);
  };

  // MyQuadratic should be specialized once:
  //   0. 'quadratic' node in the main graph
  const string optimized_0 =
      specialized_name("MyQuadratic", "quadratic", "tf_graph");

  // MySquare should be specialized and optimized for 3 instantiations:
  //   1.  'square' node in the main graph
  //   2.  'square' node in the MyQuadratic specialization
  //   3*. 'quadratic' node in the MyQuadratic specialization
  //        has identical instantiation context to #2

  const string optimized_1 = specialized_name("MySquare", "square", "tf_graph");
  const string optimized_2 =
      specialized_name("MySquare", "square", optimized_0);

  const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
  const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
  const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);

  ASSERT_NE(optimized_func_0, nullptr);
  ASSERT_NE(optimized_func_1, nullptr);
  ASSERT_NE(optimized_func_2, nullptr);

  // Graph should call optimized function.
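  // Note: `++count` inside the condition is always truthy here; it counts the
  // matched node as a side effect while the EXPECT runs on the same branch.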
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "square" && ++count) {
      EXPECT_EQ(optimized_1, node.op());
    } else if (node.name() == "quadratic" && ++count) {
      EXPECT_EQ(optimized_0, node.op());
    }
  }
  EXPECT_EQ(2, count);

  // Specialized MySquare should call specialized functions.
  count = 0;
  for (const NodeDef& node : optimized_func_0->node_def()) {
    if (node.name() == "square" && ++count) {
      EXPECT_EQ(optimized_2, node.op());
    } else if (node.name() == "quadratic" && ++count) {
      EXPECT_EQ(optimized_2, node.op());
    }
  }
  EXPECT_EQ(2, count);

  const std::vector<const FunctionDef*> optimized_funcs = {optimized_func_1,
                                                           optimized_func_2};

  // MyMul should be inlined into all optimized versions of MySquare.
  for (const FunctionDef* optimized_func : optimized_funcs) {
    count = 0;
    for (const NodeDef& node : optimized_func->node_def()) {
      if (node.name() == "Func/my_mul/input/_0" && ++count) {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ(1, node.input_size());
        EXPECT_EQ("x", node.input(0));
      } else if (node.name() == "Func/my_mul/input/_1" && ++count) {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ(1, node.input_size());
        EXPECT_EQ("x", node.input(0));
      } else if (node.name() == "my_mul/mul" && ++count) {
        EXPECT_EQ("Mul", node.op());
        EXPECT_EQ(2, node.input_size());
        EXPECT_EQ("Func/my_mul/input/_0:output:0", node.input(0));
        EXPECT_EQ("Func/my_mul/input/_1:output:0", node.input(1));
      }
      EXPECT_TRUE(node.device().empty());
    }
    EXPECT_EQ(3, count);
    ASSERT_EQ(1, optimized_func->ret().size());
    EXPECT_EQ("Func/my_mul/output/_2:output:0", optimized_func->ret().at("z"));
  }

  item.fetch = {"out_s", "out_q"};
  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  item.feed.emplace_back("b", test::AsScalar<int>(4));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryPruneUnusedOutputs) {
  using test::function::NDef;

  ConfigProto config_proto;
  MetaOptimizer optimizer(nullptr, config_proto);

  // MyMul computes x*y three times and has three output values.
  FunctionDef my_mul = FunctionDefHelper::Create(
      "MyMul", {"x:T", "y:T"}, {"z0:T", "z1:T", "z2:T"}, {"T: {float, int32}"},
      {{{"output0"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"output2"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z0", "output0:z:0"}, {"z1", "output1:z:0"}, {"z2", "output2:z:0"}});

  // Call MyMul and forward all three outputs.
  FunctionDef my_fwd = FunctionDefHelper::Create(
      "Fwd", {"x:T", "y:T"}, {"z0:T", "z1:T", "z2:T"}, {"T: {float, int32}"},
      {{{"output"}, "MyMul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z0", "output:z0:0"}, {"z1", "output:z1:0"}, {"z2", "output:z2:0"}});

  // Mark both functions as `_noinline` to trigger specialization.
  (*my_mul.mutable_attr())["_noinline"].set_b(true);
  (*my_fwd.mutable_attr())["_noinline"].set_b(true);
  /*funcs=*/
  std::vector<FunctionDef> function_library = {my_mul, my_fwd};

  // Tensorflow graph:
  //   a = Placeholder[T=float]
  //   b = Placeholder[T=float]
  //   fwd = Fwd(a, b)
  //
  // Fetch fwd:2 via Identity node.
  GrapplerItem item;
  item.id = "tf_graph";
  item.fetch = {"ret"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("fwd", "Fwd", {"a", "b"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("ret", "Identity", {"fwd:2"}, {{"T", DT_FLOAT}}, kDevice)},
      function_library);

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
                                           output.library());

  // Specialized functions should be added to the graph.
  EXPECT_EQ(2, optimized_flib.num_functions());

  // Expected names of the specialized functions.
  const string specialized_my_fwd = "Fwd_specialized_for_fwd_at_tf_graph";
  const string specialized_my_mul =
      absl::StrCat("MyMul_specialized_for_output_at_", specialized_my_fwd);

  // Specialized MyMul should have just one output argument.
  FunctionDef expected_my_mul = FunctionDefHelper::Create(
      specialized_my_mul, {"x:float", "y:float"}, {"z2:float"}, {},
      {{{"output2"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z2", "output2:z:0"}});

  // Specialized Fwd should also have just one output argument.
  FunctionDef expected_my_fwd = FunctionDefHelper::Create(
      specialized_my_fwd, {"x:float", "y:float"}, {"z2:float"}, {},
      {{{"output"}, specialized_my_mul, {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z2", "output:z2:0"}});

  const FunctionDef* my_mul_spec = optimized_flib.Find(specialized_my_mul);
  const FunctionDef* my_fwd_spec = optimized_flib.Find(specialized_my_fwd);

  ASSERT_NE(my_mul_spec, nullptr);
  ASSERT_NE(my_fwd_spec, nullptr);

  CompareFunctions(expected_my_mul, *my_mul_spec);
  CompareFunctions(expected_my_fwd, *my_fwd_spec);

  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  item.feed.emplace_back("b", test::AsScalar<float>(4.0f));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryPruneFunctionBody) {
  using test::function::NDef;

  // Enable function optimization and pruning.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.set_function_optimization(RewriterConfig::ON);
  rewriter_config.add_optimizers("function");
  rewriter_config.add_optimizers("pruning");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);

  // MyFunc defines two Mul nodes inside function body and two corresponding
  // function outputs.
  FunctionDef my_func = FunctionDefHelper::Create(
      "MyFunc", {"x:T", "y:T"}, {"z1:T", "z2:T"}, {"T: {float, double}"},
      {{{"mul1"}, "Mul", {"x", "y"}, {{"T", "$T"}}},
       {{"mul2"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /*ret_def=*/
      {{"z1", "mul1:z:0"}, {"z2", "mul2:z:0"}});
  (*my_func.mutable_attr())["_noinline"].set_b(true);

  // Tensorflow graph:
  //
  //   a = tf.Placeholder(tf.float);
  //   b = tf.Placeholder(tf.float);
  //
  //   fn1 = MyFunc(a, b);
  //   fn2 = MyFunc(a, b);
  //
  // Fetch: fn1:0 and fn2:1 via Identity nodes.
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("fn1", "MyFunc", {"a", "b"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("fn2", "MyFunc", {"a", "b"}, {{"T", DT_FLOAT}}, kDevice),
       // Read outputs of function call nodes
       NDef("out_fn1", "Identity", {"fn1:0"}, {{"T", DT_FLOAT}}, kDevice),
       NDef("out_fn2", "Identity", {"fn2:1"}, {{"T", DT_FLOAT}}, kDevice)},
      /*funcs=*/
      {my_func});

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
                                           output.library());

  // Specialized and optimized functions should be added to the graph.
  EXPECT_EQ(2, optimized_flib.num_functions());

  // Expected names of the specialized and optimized functions.
  const string optimized_fn1 = "MyFunc_specialized_for_fn1_at_tf_graph";
  const string optimized_fn2 = "MyFunc_specialized_for_fn2_at_tf_graph";

  const FunctionDef* optimized_func_fn1 = optimized_flib.Find(optimized_fn1);
  const FunctionDef* optimized_func_fn2 = optimized_flib.Find(optimized_fn2);

  ASSERT_NE(optimized_func_fn1, nullptr);
  ASSERT_NE(optimized_func_fn2, nullptr);

  // Graph should call optimized function.
  int count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "fn1" && ++count) {
      EXPECT_EQ(optimized_fn1, node.op());
    } else if (node.name() == "fn2" && ++count) {
      EXPECT_EQ(optimized_fn2, node.op());
    }
  }
  EXPECT_EQ(2, count);

  // Specialized MyFuncs should have just one Mul node and single output arg.

  // 1. Specialized for fn1:0.
  ASSERT_EQ(1, optimized_func_fn1->node_def_size());
  EXPECT_EQ(1, optimized_func_fn1->signature().output_arg_size());
  EXPECT_EQ("z1", optimized_func_fn1->signature().output_arg(0).name());
  EXPECT_EQ("mul1", optimized_func_fn1->node_def(0).name());

  // 2. Specialized for fn2:1.
  ASSERT_EQ(1, optimized_func_fn2->node_def_size());
  EXPECT_EQ(1, optimized_func_fn2->signature().output_arg_size());
  EXPECT_EQ("z2", optimized_func_fn2->signature().output_arg(0).name());
  EXPECT_EQ("mul2", optimized_func_fn2->node_def(0).name());

  // Verify that output tensors are equal.
  item.fetch = {"out_fn1", "out_fn2"};
  item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
  item.feed.emplace_back("b", test::AsScalar<float>(3.123f));
  auto tensors_expected = EvaluateFetchNodes(item);

  GrapplerItem optimized = item.WithGraph(std::move(output));
  auto tensors = EvaluateFetchNodes(optimized);

  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
}

TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  // We will record what type of optimizations the meta optimizer allows for
  // each GrapplerItem (main graph and graphs for each function).
  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Just record properties of optimized Grappler items.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
  rewriter_config.set_min_graph_nodes(-1);

  MetaOptimizer optimizer(nullptr, config_proto);

  // Define simple function library with two identical mul functions.
  FunctionDef mul_func_1 = FunctionDefHelper::Create(
      "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  FunctionDef mul_func_2 = FunctionDefHelper::Create(
      "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   x0 = tf.Placeholder(tf.float);
  //   x1 = tf.Placeholder(tf.float);
  //   dy = tf.Placeholder(tf.float);
  //
  //   mul_1 = MyMul1(x0, x1);
  //   mul_2 = MyMul2(x0, x1);
  //   dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
  GrapplerItem item;
  item.id = "main";
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
       NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
       // Symbolic gradient of a MyMul2
       NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
            {{"f", FDH::FunctionRef("MyMul2", {})},
             {"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
            kDevice)},
      /*funcs=*/
      {mul_func_1, mul_func_2});
  item.fetch = {"mul_1", "mul_2", "dx"};

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // Our custom optimizer must be called for the main graph and for the two
  // functions.
  ASSERT_EQ(optimization_options.size(), 3);

  auto optimization_options_main =
      gtl::FindOrNull(optimization_options, "main");
  ASSERT_NE(optimization_options_main, nullptr);
  EXPECT_TRUE(optimization_options_main->allow_non_differentiable_rewrites);

  auto optimization_options_my_mul_1 =
      gtl::FindOrNull(optimization_options, "MyMul1");
  ASSERT_NE(optimization_options_my_mul_1, nullptr);
  EXPECT_TRUE(optimization_options_my_mul_1->allow_non_differentiable_rewrites);

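  // MyMul2 is instantiated by the SymbolicGradient node, so rewrites that do
  // not preserve differentiability must be disallowed for its function body.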
  auto optimization_options_my_mul_2 =
      gtl::FindOrNull(optimization_options, "MyMul2");
  ASSERT_NE(optimization_options_my_mul_2, nullptr);
  EXPECT_FALSE(
      optimization_options_my_mul_2->allow_non_differentiable_rewrites);
}

class SleepingOptimizer : public CustomGraphOptimizer {
 public:
  SleepingOptimizer() {}
  string name() const override { return "test_optimizer"; }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    *optimized_graph = item.graph;
    Env::Default()->SleepForMicroseconds(1000000);
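    // Bail out of Optimize() with a deadline-exceeded error if the deadline
    // set by the meta optimizer has already passed.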
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    optimized_graph->add_node();
    return OkStatus();
  }
};

REGISTER_GRAPH_OPTIMIZER(SleepingOptimizer);

TEST_F(MetaOptimizerTest, OptimizerTimesOut) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config;
  RewriterConfig& rewriter_config =
      *config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("SleepingOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_timeout_ms(500);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);

  GraphDef output;
  GraphDef original = item.graph;
  const Status status =
      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
  EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
  // Make sure the graph was reverted to the original regardless of when the
  // optimizer timed out.
  CompareGraphs(original, output);
}

TEST_F(MetaOptimizerTest, MetaOptimizerTimesOut) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config;
  RewriterConfig& rewriter_config =
      *config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("SleepingOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_timeout_ms(1500);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);

  GraphDef output;
  const int original_node_size = item.graph.node_size();
  const Status status =
      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
  EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
  // The meta optimizer should manage to finish one iteration.
  EXPECT_EQ(original_node_size + 1, output.node_size());
}

TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config;
  RewriterConfig& rewriter_config =
      *config.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.add_optimizers("SleepingOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_timeout_ms(2500);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  GraphDef output;
  const int original_node_size = item.graph.node_size();
  const Status status =
      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
  TF_EXPECT_OK(status);
  // The meta optimizer should manage to finish two iterations.
  EXPECT_EQ(original_node_size + 2, output.node_size());
}

TEST_F(MetaOptimizerTest, RunPostOptimizationVerifiersOnValidGraph) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& post_optimization_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_post_optimization_verifier_config();
  post_optimization_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
}

TEST_F(MetaOptimizerTest, RunInterOptimizerVerifiersOnValidGraph) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {kDevice});
  GrapplerItem item;
  ASSERT_TRUE(fake_input.NextItem(&item));

  ConfigProto config_proto;
  auto& inter_optimizer_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_inter_optimizer_verifier_config();
  inter_optimizer_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
}

TEST_F(MetaOptimizerTest, RunPostOptimizationVerifiersOnInvalidGraph) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define simple function library with two identical mul functions.
  FunctionDef mul_func_1 =
      FunctionDefHelper::Create("MyMul1", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  FunctionDef mul_func_2 =
      FunctionDefHelper::Create("MyMul2", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   x0 = tf.Placeholder(tf.float);
  //   x1 = tf.Placeholder(tf.float);
  //   dy = tf.Placeholder(tf.float);
  //
  //   mul_1 = MyMul1(x0, x1);
  //   mul_2 = MyMul2(x0, x1);
  //   dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
  GrapplerItem item;
  item.id = "main";
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
       NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
       // Symbolic gradient of a MyMul2
       NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
            {{"f", FDH::FunctionRef("MyMul2", {})},
             {"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
            kDevice)},
      /*funcs=*/
      {mul_func_1, mul_func_2});
  item.fetch = {"mul_1", "mul_2", "dx"};

  GraphDef output;

  // Call Optimize with post optimization verifiers.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();

  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
  rewriter_config.set_min_graph_nodes(-1);
  auto& post_optimization_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_post_optimization_verifier_config();
  post_optimization_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer_with_post_verifiers(nullptr, config_proto);
  Status status =
      optimizer_with_post_verifiers.Optimize(nullptr, item, &output);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "NodeDef expected inputs 'float' do not match 3 inputs specified"));
}

TEST_F(MetaOptimizerTest, RunInterOptimizerVerifiersOnInvalidGraph) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define simple function library with two identical mul functions.
  FunctionDef mul_func_1 =
      FunctionDefHelper::Create("MyMul1", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  FunctionDef mul_func_2 =
      FunctionDefHelper::Create("MyMul2", {"x:float", "y:float"}, {"z:float"},
                                {}, {{{"mul"}, "Mul", {"x", "y"}, {}}},
                                /*ret_def=*/
                                {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   x0 = tf.Placeholder(tf.float);
  //   x1 = tf.Placeholder(tf.float);
  //   dy = tf.Placeholder(tf.float);
  //
  //   mul_1 = MyMul1(x0, x1);
  //   mul_2 = MyMul2(x0, x1);
  //   dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
  GrapplerItem item;
  item.id = "main";
  item.graph = test::function::GDef(
      {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
       NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
       // Symbolic gradient of a MyMul2
       NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
            {{"f", FDH::FunctionRef("MyMul2", {})},
             {"Tin", DataTypeSlice{DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
            kDevice)},
      /*funcs=*/
      {mul_func_1, mul_func_2});
  item.fetch = {"mul_1", "mul_2", "dx"};

  GraphDef output;

  // Call Optimize with inter optimizer verifiers.
  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
  rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
  rewriter_config.set_min_graph_nodes(-1);
  auto& inter_optimizer_verifier_config =
      *config_proto.mutable_graph_options()
           ->mutable_rewrite_options()
           ->mutable_inter_optimizer_verifier_config();
  inter_optimizer_verifier_config.set_structure_verifier(VerifierConfig::ON);

  MetaOptimizer optimizer_with_inter_verifiers(nullptr, config_proto);
  Status status =
      optimizer_with_inter_verifiers.Optimize(nullptr, item, &output);
  EXPECT_EQ(status.code(), errors::Code::INVALID_ARGUMENT);
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "NodeDef expected inputs 'float' do not match 3 inputs specified"));
}

TEST_F(MetaOptimizerTest, CompressConstants) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Tensor zeros_t(DT_FLOAT, TensorShape({64}));
  Tensor ones_t(DT_FLOAT, TensorShape({64}));
  for (int i = 0; i < 64; ++i) {
    zeros_t.flat<float>()(i) = 0.0f;
    ones_t.flat<float>()(i) = 1.0f;
  }
  Output zeros = ops::Const(scope.WithOpName("zeros"), zeros_t);
  Output host_ones = ops::Const(scope.WithOpName("host_ones"), ones_t);
  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  ASSERT_EQ(item.graph.node(1).name(), "host_ones");
  // There is no C++ API for HostConst, so we manually change the node type
  // here.
  item.graph.mutable_node(1)->set_op("HostConst");
  item.fetch = {"zeros", "host_ones"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConfigProto config_proto;
  auto& rewriter_config =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  rewriter_config.set_min_graph_nodes(-1);
  MetaOptimizer optimizer(/*cpu_device=*/nullptr, config_proto);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));

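  // As the expectations below verify, constant compression stores the
  // all-zeros tensor with an empty float_val (zeros are implicit in the
  // proto) and the all-ones tensor as a single splat element.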
  bool found_zeros = false;
  bool found_host_ones = false;
  ASSERT_EQ(output.node_size(), 2);
  for (const auto& node : output.node()) {
    if (node.name() == "zeros") {
      found_zeros = true;
      EXPECT_EQ(node.op(), "Const");
      const TensorProto& zeroes_t = node.attr().at("value").tensor();
      EXPECT_EQ(zeroes_t.float_val_size(), 0);
    } else if (node.name() == "host_ones") {
      found_host_ones = true;
      EXPECT_EQ(node.op(), "HostConst");
      const TensorProto& ones_t = node.attr().at("value").tensor();
      EXPECT_EQ(ones_t.float_val_size(), 1);
      EXPECT_EQ(ones_t.float_val(0), 1.0f);
    }
  }

  EXPECT_TRUE(found_zeros);
  EXPECT_TRUE(found_host_ones);

  auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  for (int i = 0; i < 2; ++i) {
    test::ExpectTensorEqual<float>(tensors[i], tensors_expected[i]);
  }
}

TEST_F(MetaOptimizerTest, TestTFGRemoveDeadArguments) {
  using test::function::NDef;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define a simple function library with one branch function.
  //   def branch_func(x, y):
  //     z = tf.Mul(x, x)
  //     return z
  FunctionDef case_func = FunctionDefHelper::Create(
      "branch_func", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  // Tensorflow graph:
  //
  //   idx = tf.Placeholder(tf.int32);
  //   x = tf.Placeholder(tf.float);
  //   y = tf.Placeholder(tf.float);
  //
  //   case = tf.Case(idx, x, y, branches=[branch_func])
  GrapplerItem item;
  item.id = "main";

  AttrValue branches;
  branches.mutable_list()->add_func()->set_name("branch_func");
  AttrValue output_shapes;
  output_shapes.mutable_list()->add_shape();
  item.graph = test::function::GDef(
      {NDef("idx", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("y", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("case", "Case", {"idx", "x", "y"},
            {{"branches", std::move(branches)},
             {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"output_shapes", std::move(output_shapes)}},
            kDevice)},
      /*funcs=*/
      {case_func});
  item.fetch = {"case"};

  GraphDef output;
  ConfigProto config_proto;

  MetaOptimizer optimizer(nullptr, config_proto);
  Status status = optimizer.Optimize(nullptr, item, &output);
  EXPECT_TRUE(status.ok());
  EXPECT_EQ(output.library().function_size(), 1);
  // The unused argument `y` was removed; the TFG round trip renames the
  // surviving argument.
  auto& func = output.library().function(0);
  EXPECT_EQ(func.signature().input_arg_size(), 1);
  EXPECT_EQ(func.signature().input_arg(0).name(), "x_tfg_result_0");
}

TEST_F(MetaOptimizerTest, TestTFGControlFlowSink) {
  using test::function::NDef;

  gtl::FlatMap<string, GrapplerItem::OptimizationOptions> optimization_options;
  GrapplerItemPropertiesAccumulator::SetOptimizationOptions(
      &optimization_options);

  // Define a branch function.
  //   def branch_func(x, y):
  //     z = tf.Mul(x, y)
  //     return z
  FunctionDef case_func = FunctionDefHelper::Create(
      "branch_func", {"x:float", "y:float"}, {"z:float"}, {},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});

  // Define a function with a control-flow op.
  //   def Foo(idx, a, b):
  //     x_foo = Add(a, b)
  //     y_foo = Mul(a, b)
  //     case = Case(idx, x_foo, y_foo, branches=[branch_func])
  //     return case
  AttrValue branches;
  branches.mutable_list()->add_func()->set_name("branch_func");
  AttrValue output_shapes;
  output_shapes.mutable_list()->add_shape();
  FunctionDef foo_func = FunctionDefHelper::Create(
      "Foo", {"idx:int32", "a:float", "b:float"}, {"c:float"}, {},
      {{{"add"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
       {{"mul"}, "Mul", {"a", "b"}, {{"T", DT_FLOAT}}},
       {{"case"},
        "Case",
        {"idx", "add:z:0", "mul:z:0"},
        {{"branches", std::move(branches)},
         {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
         {"Tout", DataTypeSlice{DT_FLOAT}},
         {"output_shapes", std::move(output_shapes)}}}},
      /*ret_def=*/
      {{"c", "case:output:0"}});
  (*foo_func.mutable_attr())["_noinline"].set_b(true);

  // Tensorflow graph:
  //
  //   idx = tf.Placeholder(tf.int32);
  //   a = tf.Placeholder(tf.float);
  //   b = tf.Placeholder(tf.float);
  //
  //   foo_val = Foo(idx, a, b)
  GrapplerItem item;
  item.id = "main";

  item.graph = test::function::GDef(
      {NDef("idx", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
       NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("foo", "Foo", {"idx", "a", "b"}, {}, kDevice)},
      /*funcs=*/
      {case_func, foo_func});
  item.fetch = {"foo"};

  GraphDef output;
  ConfigProto config_proto;

  MetaOptimizer optimizer(nullptr, config_proto);
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(output.library().function_size(), 2);

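  // Control-flow sinking moves `add` and `mul` from Foo into the specialized
  // branch function: Foo should keep only the Case node, while the branch
  // grows from one node to three.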
  const FunctionDef* optimized_foo_func = nullptr;
  const FunctionDef* specialized_branch_func = nullptr;
  for (const FunctionDef& func : output.library().function()) {
    if (func.signature().name() == "Foo")
      optimized_foo_func = &func;
    else if (absl::StartsWith(func.signature().name(), "branch_func"))
      specialized_branch_func = &func;
  }
  ASSERT_TRUE(optimized_foo_func);
  EXPECT_EQ(optimized_foo_func->node_def_size(), 1);
  ASSERT_TRUE(specialized_branch_func);
  EXPECT_EQ(specialized_branch_func->node_def_size(), 3);
}

// Tests for checking expected behavior when skipping tf.data functions in
// meta optimizer.

// Custom optimizer which counts the number of calls of its method `Optimize`
// across all class instances.
class TfDataTestOptimizer : public CustomGraphOptimizer {
 public:
  static void InitCount() { count_ = 0; }
  static int GetCount() { return count_; }

  TfDataTestOptimizer() = default;
  ~TfDataTestOptimizer() override = default;
  TfDataTestOptimizer(const TfDataTestOptimizer&) = delete;
  TfDataTestOptimizer& operator=(const TfDataTestOptimizer& other) = delete;

  std::string name() const override { return "tf_data_test_optimizer"; }
  bool UsesFunctionLibrary() const override { return false; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return OkStatus();
  }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    ++count_;
    *optimized_graph = item.graph;
    return OkStatus();
  }

 private:
  static std::atomic<int> count_;
};

std::atomic<int> TfDataTestOptimizer::count_;

REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer);

// Type for specifying how the inner function is nested inside the outer
// function.
enum class FuncNestingType {
  CallFromNode = 0,
  CallFromAttr = 1,
  CallFromList = 2
};

// Test fixture for parametrized testing.
class TfDataTestFixture
    : public ::testing::TestWithParam<std::tuple<bool, bool, FuncNestingType>> {
 protected:
  void SetUp() override {
    is_inner_func_tf_data_ = std::get<0>(GetParam());
    is_outer_func_tf_data_ = std::get<1>(GetParam());
    func_nesting_type_ = std::get<2>(GetParam());
  }
  // Controls which of the functions is flagged as tf.data function.
  bool is_inner_func_tf_data_ = false;
  bool is_outer_func_tf_data_ = false;
  // Controls how the inner function is nested inside the outer function.
  FuncNestingType func_nesting_type_ = FuncNestingType::CallFromNode;
};

// Helper functions for setting up the call of `inner_func` inside of
// `outer_func`.

void SetUpCallFromNode(FunctionDef& outer_func) {
  // Call `inner_func` from a node in `outer_func`.
  outer_func = FunctionDefHelper::Create(
      "outer_func", {"x:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"inner_func"}, "inner_func", {"x", "x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "inner_func:z:0"}});
}

void SetUpCallFromAttr(FunctionDef& outer_func) {
  // Call `inner_func` from an attribute in a node in `outer_func`.
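  // Note that Identity does not consume the "f" attr; attaching it here only
  // exercises the meta optimizer's traversal of function-valued attributes.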
  outer_func = FunctionDefHelper::Create(
      "outer_func", {"x:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"identity"},
        "Identity",
        {"x"},
        {{"T", DT_FLOAT},
         {"f", FunctionDefHelper::FunctionRef("inner_func", {})}}}},
      /*ret_def=*/
      {{"z", "x"}});
}

void SetUpCallFromList(FunctionDef& outer_func) {
  // Call `inner_func` from a list attribute in a node in `outer_func`.
  outer_func = FunctionDefHelper::Create(
      "outer_func", {"x:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"identity"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "x"}});

  // Add a list containing `inner_func` to the `identity` node. `list_value`
  // is owned by `outer_func`, so it does not need to be deallocated manually.
  AttrValue_ListValue* list_value =
      (*outer_func.mutable_node_def(0)->mutable_attr())["list"].mutable_list();
  NameAttrList* entry = list_value->add_func();
  entry->set_name("inner_func");
}

TEST_P(TfDataTestFixture, TfDataTests) {
  using test::function::NDef;

  // Define function library with `outer_func` and `inner_func`.

  FunctionDef inner_func = FunctionDefHelper::Create(
      "inner_func", {"x:float", "y:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "mul:z:0"}});
  (*inner_func.mutable_attr())[data::kTFDataFunction].set_b(
      is_inner_func_tf_data_);

  FunctionDef outer_func;
  switch (func_nesting_type_) {
    case FuncNestingType::CallFromNode:
      SetUpCallFromNode(outer_func);
      break;
    case FuncNestingType::CallFromAttr:
      SetUpCallFromAttr(outer_func);
      break;
    case FuncNestingType::CallFromList:
      SetUpCallFromList(outer_func);
      break;
    default:
      break;
  }
  (*outer_func.mutable_attr())[data::kTFDataFunction].set_b(
      is_outer_func_tf_data_);

  // Tensorflow graph:
  //
  //   a = tf.Placeholder(tf.float);
  //   result = outer_func(a);
  GrapplerItem item;
  item.id = "tf_graph";
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
       // Calls into function library
       NDef("outer_func_node", "outer_func", {"a"}, {{"T", DT_FLOAT}}, kDevice),
       // Forward outputs
       NDef("out_s", "Identity", {"outer_func_node:0"}, {{"T", DT_FLOAT}},
            kDevice)},
      /*funcs=*/
      {inner_func, outer_func});

  // Use only a custom optimizer which counts its calls.
  TfDataTestOptimizer::InitCount();
  ConfigProto config_proto;
  auto& rewriter_config =
      *(config_proto.mutable_graph_options()->mutable_rewrite_options());
  rewriter_config.add_optimizers("TfDataTestOptimizer");
  rewriter_config.set_min_graph_nodes(-1);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);

  MetaOptimizer optimizer(nullptr, config_proto);
  GraphDef output;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // We expect one graph optimization + one optimization for each non-tf.data
  // function. Note that if `outer_func` is flagged as a tf.data function,
  // then `inner_func` is implicitly also considered a tf.data function
  // because it is called from `outer_func`.
  int expected_count = 3;
  if (is_outer_func_tf_data_)
    expected_count = 1;
  else if (is_inner_func_tf_data_)
    expected_count = 2;
  EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count);

  // We expect that the tf.data attribute has been propagated from
  // `outer_func` to its callee `inner_func` if the value is `true`.
  // Otherwise, the attribute values should be unchanged.
  FunctionLibraryDefinition flib(OpRegistry::Global(), output.library());
  const FunctionDef* outer_func_after_opt = flib.Find("outer_func");
  const FunctionDef* inner_func_after_opt = flib.Find("inner_func");

  EXPECT_EQ(data::IsTFDataFunction(*outer_func_after_opt),
            is_outer_func_tf_data_);
  if (is_outer_func_tf_data_ || is_inner_func_tf_data_) {
    EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), true);
  } else {
    EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), false);
  }
}

INSTANTIATE_TEST_SUITE_P(
    MetaOptimizerTest, TfDataTestFixture,
    ::testing::Combine(::testing::Bool(), ::testing::Bool(),
                       ::testing::Values(FuncNestingType::CallFromNode,
                                         FuncNestingType::CallFromAttr,
                                         FuncNestingType::CallFromList)),
    [](const ::testing::TestParamInfo<TfDataTestFixture::ParamType>& info) {
      bool is_inner_func_tf_data = std::get<0>(info.param);
      bool is_outer_func_tf_data = std::get<1>(info.param);
      FuncNestingType func_nesting_type = std::get<2>(info.param);

      std::string test_name;
      if (is_inner_func_tf_data && is_outer_func_tf_data)
        test_name = "both_funcs_tf_data";
      else if (is_inner_func_tf_data)
        test_name = "inner_func_tf_data";
      else if (is_outer_func_tf_data)
        test_name = "outer_func_tf_data";
      else
        test_name = "no_func_tf_data";
      switch (func_nesting_type) {
        case FuncNestingType::CallFromNode:
          test_name += "_call_from_node";
          break;
        case FuncNestingType::CallFromAttr:
          test_name += "_call_from_attribute";
          break;
        case FuncNestingType::CallFromList:
          test_name += "_call_from_list";
          break;
        default:
          break;
      }
      return test_name;
    });

}  // namespace
}  // namespace grappler
}  // namespace tensorflow