// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"

#include <memory>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/util/equal_graph_def.h"

49 namespace tensorflow {
50 namespace tfrt_stub {
51 namespace {
52 
53 using ::testing::_;
54 using ::testing::ElementsAre;
55 using ::testing::EqualsProto;
56 using ::testing::HasSubstr;
57 using ::testing::IsEmpty;
58 using ::testing::NotNull;
59 using ::testing::Pair;
60 using ::testing::SizeIs;
61 using ::testing::proto::IgnoringFieldPaths;
62 using ::testing::proto::IgnoringRepeatedFieldOrdering;
63 
64 class PruneGraphDefTest : public grappler::GrapplerTest {};
65 
TEST_F(PruneGraphDefTest,ConstFeedWithInput)66 TEST_F(PruneGraphDefTest, ConstFeedWithInput) {
67   GraphDef graphdef;
68   {
69     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
70 
71     Output a = ops::Const(scope.WithOpName("a"), 0.0f, {10, 10});
72 
73     Output b = ops::Const(scope.WithControlDependencies(a).WithOpName("b"),
74                           0.0f, {10, 10});
75     Output c = ops::Identity(scope.WithOpName("c"), b);
76 
77     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
78   }
79 
80   CallableOptions callable_options;
81   callable_options.add_feed("b");
82   callable_options.add_fetch("c");
83 
84   TF_ASSERT_OK(PruneGraphDef(graphdef, callable_options));
85 
86   GraphDef expected;
87   {
88     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
89 
90     Output b = ops::Const(scope.WithOpName("b"), 0.0f, {10, 10});
91     Output c = ops::Identity(scope.WithOpName("c"), b);
92 
93     TF_ASSERT_OK(scope.ToGraphDef(&expected));
94   }
95 
96   CompareGraphs(expected, graphdef);
97 }
98 
LessThanTenCond(const Scope & scope,const std::vector<Output> & inputs,Output * output)99 Status LessThanTenCond(const Scope& scope, const std::vector<Output>& inputs,
100                        Output* output) {
101   *output = ops::Less(scope, inputs[0], 10);
102   return scope.status();
103 }
104 
AddOneBody(const Scope & scope,const std::vector<Output> & inputs,std::vector<Output> * outputs)105 Status AddOneBody(const Scope& scope, const std::vector<Output>& inputs,
106                   std::vector<Output>* outputs) {
107   outputs->push_back(ops::AddN(scope, {inputs[0], 1}));
108   return scope.status();
109 }
110 
TEST_F(PruneGraphDefTest,InsertIdentityForLoopExitFeed)111 TEST_F(PruneGraphDefTest, InsertIdentityForLoopExitFeed) {
112   GraphDef graphdef;
113   {
114     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
115 
116     std::vector<Output> inputs;
117     inputs.push_back(ops::Placeholder(scope.WithOpName("input"), DT_INT32));
118     std::vector<Output> outputs;
119     TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), inputs,
120                                      LessThanTenCond, AddOneBody, "test_loop",
121                                      &outputs));
122 
123     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
124   }
125 
126   CallableOptions callable_options;
127   callable_options.add_feed("input");
128   callable_options.add_fetch("while/Exit");
129 
130   TF_ASSERT_OK(PruneGraphDef(graphdef, callable_options));
131 
132   for (const auto& node : graphdef.node()) {
133     if (node.op() == "Exit") {
134       EXPECT_EQ(node.name(), "while/Exit/tfrt_renamed");
135     }
136     if (node.name() == "while/Exit") {
137       EXPECT_EQ(node.op(), "Identity");
138       ASSERT_EQ(node.input().size(), 1);
139       EXPECT_EQ(node.input(0), "while/Exit/tfrt_renamed");
140     }
141   }
142 }
143 
TEST_F(PruneGraphDefTest,EliminateRefEntersFromControlFlow)144 TEST_F(PruneGraphDefTest, EliminateRefEntersFromControlFlow) {
145   GraphDef graphdef;
146   absl::flat_hash_map<std::string, NodeDef> name_to_node;
147   {
148     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
149 
150     std::vector<Output> inputs;
151     inputs.push_back(ops::Placeholder(scope.WithOpName("input"), DT_INT32));
152     std::vector<Output> outputs1;
153     std::vector<Output> outputs2;
154     TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), inputs,
155                                      LessThanTenCond, AddOneBody, "test_loop",
156                                      &outputs1));
157     TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), inputs,
158                                      LessThanTenCond, AddOneBody, "test_loop2",
159                                      &outputs2));
160 
161     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
162 
163     // Simply replace Enter with RefEnter. Note this is not valid graph though.
164     for (auto& node : *graphdef.mutable_node()) {
165       if (node.op() == "Enter") {
166         node.set_op("RefEnter");
167       }
168       name_to_node.insert({node.name(), node});
169     }
170   }
171 
172   TF_ASSERT_OK(EliminateRefVariablesFromV1ControlFlow(graphdef));
173 
174   int num_identity_op = 0;
175   int num_enter_op = 0;
176   int num_ref_enter_op = 0;
177   for (const auto& node : graphdef.node()) {
178     if (node.op() == "Identity") {
179       num_identity_op++;
180       EXPECT_EQ(node.name(), "input/identity");
181       ASSERT_EQ(node.input().size(), 1);
182       EXPECT_EQ(node.input(0), "input");
183       EXPECT_THAT(node.attr(), ElementsAre(Pair("T", _)));
184     } else if (node.op() == "RefEnter") {
185       num_ref_enter_op++;
186     } else if (node.op() == "Enter") {
187       // Identity op should be placed before Enter.
188       EXPECT_EQ(num_identity_op, 1);
189       num_enter_op++;
190       ASSERT_EQ(node.input().size(), 1);
191       EXPECT_EQ(node.input(0), "input/identity");
192       EXPECT_THAT(
193           node, IgnoringFieldPaths({"input", "op"},
194                                    EqualsProto(name_to_node.at(node.name()))));
195     } else {
196       EXPECT_THAT(node, EqualsProto(name_to_node.at(node.name())));
197     }
198     name_to_node.erase(node.name());
199   }
200   EXPECT_EQ(num_identity_op, 1);
201   EXPECT_EQ(num_enter_op, 2);
202   EXPECT_EQ(num_ref_enter_op, 0);
203   EXPECT_THAT(name_to_node, IsEmpty());
204 }
205 
TEST_F(PruneGraphDefTest,EliminateRefSwitchesFromControlFlow)206 TEST_F(PruneGraphDefTest, EliminateRefSwitchesFromControlFlow) {
207   GraphDef graphdef;
208   absl::flat_hash_map<std::string, NodeDef> name_to_node;
209   {
210     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
211 
212     Output cond_a = ops::Placeholder(scope.WithOpName("cond_a"), DT_BOOL);
213     Output cond_b = ops::Placeholder(scope.WithOpName("cond_b"), DT_BOOL);
214     Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
215 
216     ops::Switch switch_a(scope.WithOpName("switch_a"), input, cond_a);
217     ops::Switch switch_b(scope.WithOpName("switch_b"), input, cond_b);
218 
219     Output switch_a_true =
220         ops::Identity(scope.WithOpName("switch_a_true"), switch_a.output_true);
221     Output switch_b_true =
222         ops::Identity(scope.WithOpName("switch_b_true"), switch_b.output_true);
223 
224     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
225 
226     // Simply replace Switch with RefSwitch. Note this is not valid graph
227     // though.
228     for (auto& node : *graphdef.mutable_node()) {
229       if (node.op() == "Switch") {
230         node.set_op("RefSwitch");
231       }
232       name_to_node.insert({node.name(), node});
233     }
234   }
235 
236   TF_ASSERT_OK(EliminateRefVariablesFromV1ControlFlow(graphdef));
237 
238   int num_identity_op = 0;
239   int num_switch_op = 0;
240   int num_ref_switch_op = 0;
241   for (const auto& node : graphdef.node()) {
242     if (node.name() == "switch_a_true" || node.name() == "switch_b_true") {
243       EXPECT_THAT(node, EqualsProto(name_to_node.at(node.name())));
244     } else if (node.op() == "Identity") {
245       num_identity_op++;
246       EXPECT_EQ(node.name(), "input/identity");
247       ASSERT_EQ(node.input().size(), 1);
248       EXPECT_EQ(node.input(0), "input");
249       EXPECT_THAT(node.attr(), ElementsAre(Pair("T", _)));
250     } else if (node.op() == "RefSwitch") {
251       num_ref_switch_op++;
252     } else if (node.op() == "Switch") {
253       // Identity op should be placed before Switch.
254       EXPECT_EQ(num_identity_op, 1);
255       num_switch_op++;
256       ASSERT_EQ(node.input().size(), 2);
257       EXPECT_TRUE(node.input(0) == "input/identity" ||
258                   node.input(1) == "input/identity");
259       EXPECT_THAT(
260           node, IgnoringFieldPaths({"input", "op"},
261                                    EqualsProto(name_to_node.at(node.name()))));
262     } else {
263       EXPECT_THAT(node, EqualsProto(name_to_node.at(node.name())));
264     }
265     name_to_node.erase(node.name());
266   }
267   EXPECT_EQ(num_identity_op, 1);
268   EXPECT_EQ(num_switch_op, 2);
269   EXPECT_EQ(num_ref_switch_op, 0);
270   EXPECT_THAT(name_to_node, IsEmpty());
271 }
272 
TEST_F(PruneGraphDefTest,EliminateRefVariablesFromV1ControlFlowFailed)273 TEST_F(PruneGraphDefTest, EliminateRefVariablesFromV1ControlFlowFailed) {
274   GraphDef graphdef;
275   absl::flat_hash_map<std::string, NodeDef> name_to_node;
276   {
277     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
278 
279     Output cond = ops::Placeholder(scope.WithOpName("cond"), DT_BOOL);
280     Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
281 
282     ops::Switch switch_op(scope.WithOpName("switch"), input, cond);
283     Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
284     Output assign =
285         ops::Assign(scope.WithOpName("assign"), var, switch_op.output_true);
286 
287     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
288 
289     // Simply replace Switch with RefSwitch. Note this is not valid graph
290     // though.
291     for (auto& node : *graphdef.mutable_node()) {
292       if (node.op() == "Switch") {
293         node.set_op("RefSwitch");
294       }
295       name_to_node.insert({node.name(), node});
296     }
297   }
298 
299   const auto status = EliminateRefVariablesFromV1ControlFlow(graphdef);
300   EXPECT_FALSE(status.ok());
301   EXPECT_THAT(status.error_message(),
302               HasSubstr("requires its input to be refs"));
303 }
304 
TEST_F(PruneGraphDefTest,KeepLoopStructureComplete)305 TEST_F(PruneGraphDefTest, KeepLoopStructureComplete) {
306   GraphDef graphdef;
307   {
308     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
309 
310     std::vector<Output> inputs;
311     inputs.push_back(ops::Placeholder(scope.WithOpName("input"), DT_INT32));
312     std::vector<Output> outputs;
313     TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), inputs,
314                                      LessThanTenCond, AddOneBody, "test_loop",
315                                      &outputs));
316 
317     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
318   }
319 
320   CallableOptions callable_options;
321   callable_options.add_feed("input");
322   // Sets the fetch node such that traversing from there will miss part of the
323   // while loop structure.
324   callable_options.add_fetch("while/LoopCond");
325 
326   GraphDef original_graphdef = graphdef;
327   TF_ASSERT_OK(PruneGraphDef(graphdef, callable_options));
328   EXPECT_THAT(graphdef,
329               IgnoringRepeatedFieldOrdering(EqualsProto(original_graphdef)));
330 }
331 
332 class OptimizeGraphTest : public grappler::GrapplerTest {};
333 
TEST_F(OptimizeGraphTest,OptimizeFunctions)334 TEST_F(OptimizeGraphTest, OptimizeFunctions) {
335   GraphDef graphdef;
336   tensorflow::FunctionDefLibrary fdef_lib;
337   {
338     auto scope = tensorflow::Scope::NewRootScope().WithDevice(
339         "/job:localhost/replica:0/task:0/device:CPU:0");
340 
341     const Tensor kThree = test::AsScalar<float>(3.0);
342     auto fdef = tensorflow::FunctionDefHelper::Create(
343         "Pow3", {"x: float"}, {"y: float"}, {},
344         {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
345          {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
346         {{"y", "pow3:z:0"}});
347 
348     tensorflow::FunctionDefLibrary fdef_lib;
349     *fdef_lib.add_function() = fdef;
350     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
351 
352     Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});
353 
354     std::vector<tensorflow::Output> inputs = {a};
355     std::vector<tensorflow::DataType> output_dtypes = {
356         fdef.signature().output_arg(0).type()};
357     tensorflow::NameAttrList func_attr;
358     func_attr.set_name(fdef.signature().name());
359     auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
360     Output b = pcall.output.front();
361 
362     Output c = ops::Identity(scope.WithOpName("c"), b);
363 
364     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
365   }
366 
367   TF_ASSERT_OK_AND_ASSIGN(
368       auto fallback_state,
369       tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));
370 
371   TfrtGraphExecutionState::Options options;
372   options.run_placer_grappler_on_functions = true;
373   TF_ASSERT_OK_AND_ASSIGN(
374       auto graph_execution_state,
375       TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));
376 
377   tensorflow::GraphImportConfig graph_import_config;
378   graph_import_config.prune_unused_nodes = true;
379   graph_import_config.enable_shape_inference = false;
380   tensorflow::ArrayInfo array_info;
381   array_info.imported_dtype = DT_FLOAT;
382   array_info.shape.set_unknown_rank(true);
383   graph_import_config.inputs["a"] = array_info;
384   graph_import_config.outputs = {"c"};
385 
386   TF_ASSERT_OK_AND_ASSIGN(
387       auto optimized_graph,
388       graph_execution_state->CreateOptimizedGraph(graph_import_config));
389   GraphDef optimized_graph_def;
390   optimized_graph.graph->ToGraphDef(&optimized_graph_def);
391 
392   GraphDef expected;
393   {
394     auto scope = tensorflow::Scope::NewRootScope().WithDevice(
395         "/job:localhost/replica:0/task:0/device:CPU:0");
396 
397     const Tensor kThree = test::AsScalar<float>(3.0);
398     // After optimization, "x^3" will be transformed to "(x^2)*x".
399     auto fdef = tensorflow::FunctionDefHelper::Create(
400         "Pow3", {"x: float"}, {"y_retval: float"}, {},
401         {{{"ArithmeticOptimizer/ConvertPow__inner_pow3"},
402           "Square",
403           {"x"},
404           {{"dtype", DT_FLOAT}},
405           /*dep=*/{},
406           "/job:localhost/replica:0/task:0/device:CPU:0"},
407          {{"pow3"},
408           "Mul",
409           {"ArithmeticOptimizer/ConvertPow__inner_pow3:y:0", "x"},
410           {{"T", DT_FLOAT}},
411           /*dep=*/{},
412           "/job:localhost/replica:0/task:0/device:CPU:0"}},
413         {{"y_retval", "pow3:z:0"}});
414 
415     tensorflow::FunctionDefLibrary fdef_lib;
416     *fdef_lib.add_function() = fdef;
417     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
418 
419     Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});
420 
421     std::vector<tensorflow::Output> inputs = {a};
422     std::vector<tensorflow::DataType> output_dtypes = {
423         fdef.signature().output_arg(0).type()};
424     tensorflow::NameAttrList func_attr;
425     func_attr.set_name(fdef.signature().name());
426     auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
427     Output b = pcall.output.front();
428 
429     Output c = ops::Identity(scope.WithOpName("c"), b);
430 
431     TF_ASSERT_OK(scope.ToGraphDef(&expected));
432   }
433 
434   CompareGraphs(expected, optimized_graph_def);
435   CompareFunctions(expected.library().function(0),
436                    optimized_graph_def.library().function(0));
437 }
438 
TEST_F(OptimizeGraphTest,OptimizeFunctionsUsedByFunctionNodes)439 TEST_F(OptimizeGraphTest, OptimizeFunctionsUsedByFunctionNodes) {
440   GraphDef graphdef;
441   tensorflow::FunctionDefLibrary fdef_lib;
442   {
443     auto scope = tensorflow::Scope::NewRootScope().WithDevice(
444         "/job:localhost/replica:0/task:0/device:CPU:0");
445 
446     const Tensor kThree = test::AsScalar<float>(3.0);
447     auto pow3_fdef = tensorflow::FunctionDefHelper::Create(
448         "Pow3", {"x: float"}, {"y: float"}, {},
449         {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
450          {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
451         {{"y", "pow3:z:0"}});
452 
453     const Tensor kOne = test::AsScalar<float>(1.0);
454     auto base2pow3_fdef = tensorflow::FunctionDefHelper::Create(
455         "Add1Pow3", {"x: float"}, {"y: float"}, {},
456         {{{"one"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kOne}}},
457          {{"add"}, "Add", {"x", "one:output:0"}, {{"T", DT_FLOAT}}},
458          {{"pcall"},
459           "PartitionedCall",
460           {"add:z:0"},
461           {{"Tin", DataTypeSlice({DT_FLOAT})},
462            {"Tout", DataTypeSlice({DT_FLOAT})},
463            {"f", tensorflow::FunctionDefHelper::FunctionRef(
464                      "Pow3", {{"T", DT_FLOAT}})}}}},
465         {{"y", "pcall:output:0"}});
466 
467     tensorflow::FunctionDefLibrary fdef_lib;
468     *fdef_lib.add_function() = pow3_fdef;
469     *fdef_lib.add_function() = base2pow3_fdef;
470     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
471 
472     Output a = ops::Const(scope.WithOpName("a"), 1.0, {1, 1});
473 
474     std::vector<tensorflow::Output> inputs = {a};
475     std::vector<tensorflow::DataType> output_dtypes = {
476         base2pow3_fdef.signature().output_arg(0).type()};
477     tensorflow::NameAttrList func_attr;
478     func_attr.set_name(base2pow3_fdef.signature().name());
479     auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
480     Output b = pcall.output.front();
481 
482     Output c = ops::Identity(scope.WithOpName("c"), b);
483 
484     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
485   }
486 
487   TF_ASSERT_OK_AND_ASSIGN(
488       auto fallback_state,
489       tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));
490 
491   TfrtGraphExecutionState::Options options;
492   options.run_placer_grappler_on_functions = true;
493   TF_ASSERT_OK_AND_ASSIGN(
494       auto graph_execution_state,
495       TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));
496 
497   tensorflow::GraphImportConfig graph_import_config;
498   graph_import_config.prune_unused_nodes = true;
499   graph_import_config.enable_shape_inference = false;
500   tensorflow::ArrayInfo array_info;
501   array_info.imported_dtype = DT_FLOAT;
502   array_info.shape.set_unknown_rank(true);
503   graph_import_config.inputs["a"] = array_info;
504   graph_import_config.outputs = {"c"};
505 
506   TF_ASSERT_OK_AND_ASSIGN(
507       auto optimized_graph,
508       graph_execution_state->CreateOptimizedGraph(graph_import_config));
509   GraphDef optimized_graph_def;
510   optimized_graph.graph->ToGraphDef(&optimized_graph_def);
511 
512   GraphDef expected;
513   {
514     auto scope = tensorflow::Scope::NewRootScope().WithDevice(
515         "/job:localhost/replica:0/task:0/device:CPU:0");
516 
517     const Tensor kThree = test::AsScalar<float>(3.0);
518     // After optimization, "x^3" will be transformed to "(x^2)*x".
519     auto pow3_fdef = tensorflow::FunctionDefHelper::Create(
520         "Pow3", {"x: float"}, {"y_retval: float"}, {},
521         {{{"ArithmeticOptimizer/ConvertPow__inner_pow3"},
522           "Square",
523           {"x"},
524           {{"dtype", DT_FLOAT}},
525           /*dep=*/{},
526           "/job:localhost/replica:0/task:0/device:CPU:0"},
527          {{"pow3"},
528           "Mul",
529           {"ArithmeticOptimizer/ConvertPow__inner_pow3:y:0", "x"},
530           {{"T", DT_FLOAT}},
531           /*dep=*/{},
532           "/job:localhost/replica:0/task:0/device:CPU:0"}},
533         {{"y_retval", "pow3:z:0"}});
534 
535     const Tensor kOne = test::AsScalar<float>(1.0);
536     auto base2pow3_fdef = tensorflow::FunctionDefHelper::Create(
537         "Add1Pow3", {"x: float"}, {"y: float"}, {},
538         {{{"one"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kOne}}},
539          {{"add"}, "Add", {"x", "one:output:0"}, {{"T", DT_FLOAT}}},
540          {{"pcall"},
541           "PartitionedCall",
542           {"add:z:0"},
543           {{"Tin", DataTypeSlice({DT_FLOAT})},
544            {"Tout", DataTypeSlice({DT_FLOAT})},
545            {"f", tensorflow::FunctionDefHelper::FunctionRef(
546                      "Pow3", {{"T", DT_FLOAT}})}}}},
547         {{"y", "pcall:output:0"}});
548 
549     tensorflow::FunctionDefLibrary fdef_lib;
550     *fdef_lib.add_function() = pow3_fdef;
551     *fdef_lib.add_function() = base2pow3_fdef;
552     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
553 
554     Output a = ops::Const(scope.WithOpName("a"), 1.0, {1, 1});
555 
556     std::vector<tensorflow::Output> inputs = {a};
557     std::vector<tensorflow::DataType> output_dtypes = {
558         base2pow3_fdef.signature().output_arg(0).type()};
559     tensorflow::NameAttrList func_attr;
560     func_attr.set_name(base2pow3_fdef.signature().name());
561     auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
562     Output b = pcall.output.front();
563 
564     Output c = ops::Identity(scope.WithOpName("c"), b);
565 
566     TF_ASSERT_OK(scope.ToGraphDef(&expected));
567   }
568 
569   // Since `Pow3` is called by `Add1Pow3`, it is optimized.
570   CompareFunctions(expected.library().function(1),
571                    optimized_graph_def.library().function(1));
572   ASSERT_EQ("Pow3",
573             optimized_graph_def.library().function(1).signature().name());
574 }
575 
TEST_F(OptimizeGraphTest,DontOptimizeUnsafeFunction)576 TEST_F(OptimizeGraphTest, DontOptimizeUnsafeFunction) {
577   GraphDef graphdef;
578   tensorflow::FunctionDefLibrary fdef_lib;
579   {
580     auto scope = tensorflow::Scope::NewRootScope().WithDevice(
581         "/job:localhost/replica:0/task:0/device:CPU:0");
582 
583     const Tensor kThree = test::AsScalar<float>(3.0);
584     auto fdef = tensorflow::FunctionDefHelper::Create(
585         "Pow3", {"x: float"}, {"y: float"}, {},
586         {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
587          {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
588         {{"y", "pow3:z:0"}});
589 
590     tensorflow::FunctionDefLibrary fdef_lib;
591     *fdef_lib.add_function() = fdef;
592     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
593 
594     Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});
595 
596     Output cond = ops::Const(scope.WithOpName("cond"), true, {1, 1});
597     std::vector<tensorflow::Output> inputs = {a};
598     std::vector<tensorflow::DataType> output_dtypes = {
599         fdef.signature().output_arg(0).type()};
600     tensorflow::NameAttrList func_attr;
601     func_attr.set_name(fdef.signature().name());
602     auto if_op =
603         ops::If(scope, cond, inputs, output_dtypes, func_attr, func_attr);
604     Output b = if_op.output.front();
605 
606     Output c = ops::Identity(scope.WithOpName("c"), b);
607 
608     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
609   }
610 
611   TF_ASSERT_OK_AND_ASSIGN(
612       auto fallback_state,
613       tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));
614 
615   TfrtGraphExecutionState::Options options;
616   options.run_placer_grappler_on_functions = true;
617   TF_ASSERT_OK_AND_ASSIGN(
618       auto graph_execution_state,
619       TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));
620 
621   tensorflow::GraphImportConfig graph_import_config;
622   graph_import_config.prune_unused_nodes = true;
623   graph_import_config.enable_shape_inference = false;
624   tensorflow::ArrayInfo array_info;
625   array_info.imported_dtype = DT_FLOAT;
626   array_info.shape.set_unknown_rank(true);
627   graph_import_config.inputs["a"] = array_info;
628   graph_import_config.outputs = {"c"};
629 
630   TF_ASSERT_OK_AND_ASSIGN(
631       auto optimized_graph,
632       graph_execution_state->CreateOptimizedGraph(graph_import_config));
633   GraphDef optimized_graph_def;
634   optimized_graph.graph->ToGraphDef(&optimized_graph_def);
635 
636   // The optimized graph remains the same as the original one, because the
637   // function used by `If` op is not optimized.
638   CompareGraphs(graphdef, optimized_graph_def);
639   CompareFunctions(graphdef.library().function(0),
640                    optimized_graph_def.library().function(0));
641 }
642 
TEST_F(OptimizeGraphTest,FunctionBecomeUnsafeIfAnyOpIsUnsafe)643 TEST_F(OptimizeGraphTest, FunctionBecomeUnsafeIfAnyOpIsUnsafe) {
644   GraphDef graphdef;
645   tensorflow::FunctionDefLibrary fdef_lib;
646   {
647     auto scope = tensorflow::Scope::NewRootScope().WithDevice(
648         "/job:localhost/replica:0/task:0/device:CPU:0");
649 
650     const Tensor kThree = test::AsScalar<float>(3.0);
651     auto fdef = tensorflow::FunctionDefHelper::Create(
652         "Pow3", {"x: float"}, {"y: float"}, {},
653         {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
654          {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
655         {{"y", "pow3:z:0"}});
656 
657     tensorflow::FunctionDefLibrary fdef_lib;
658     *fdef_lib.add_function() = fdef;
659     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
660 
661     Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});
662 
663     Output cond = ops::Const(scope.WithOpName("cond"), true, {1, 1});
664     std::vector<tensorflow::Output> inputs = {a};
665     std::vector<tensorflow::DataType> output_dtypes = {
666         fdef.signature().output_arg(0).type()};
667     tensorflow::NameAttrList func_attr;
668     func_attr.set_name(fdef.signature().name());
669     auto if_op =
670         ops::If(scope, cond, inputs, output_dtypes, func_attr, func_attr);
671     Output b = if_op.output.front();
672 
673     inputs = {b};
674     auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
675     Output c = pcall.output.front();
676 
677     Output d = ops::Identity(scope.WithOpName("d"), c);
678 
679     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
680   }
681 
682   TF_ASSERT_OK_AND_ASSIGN(
683       auto fallback_state,
684       tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));
685 
686   TfrtGraphExecutionState::Options options;
687   options.run_placer_grappler_on_functions = true;
688   TF_ASSERT_OK_AND_ASSIGN(
689       auto graph_execution_state,
690       TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));
691 
692   tensorflow::GraphImportConfig graph_import_config;
693   graph_import_config.prune_unused_nodes = true;
694   graph_import_config.enable_shape_inference = false;
695   tensorflow::ArrayInfo array_info;
696   array_info.imported_dtype = DT_FLOAT;
697   array_info.shape.set_unknown_rank(true);
698   graph_import_config.inputs["a"] = array_info;
699   graph_import_config.outputs = {"d"};
700 
701   TF_ASSERT_OK_AND_ASSIGN(
702       auto optimized_graph,
703       graph_execution_state->CreateOptimizedGraph(graph_import_config));
704   GraphDef optimized_graph_def;
705   optimized_graph.graph->ToGraphDef(&optimized_graph_def);
706 
707   // Both `If` and `PartitionedCall` ops use the function, so the function
708   // remains unoptimized.
709   CompareFunctions(graphdef.library().function(0),
710                    optimized_graph_def.library().function(0));
711 }
712 
713 class ExtendGraphTest : public grappler::GrapplerTest {};
714 
TEST_F(ExtendGraphTest,ExtendGraph)715 TEST_F(ExtendGraphTest, ExtendGraph) {
716   GraphDef graphdef;
717   {
718     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
719 
720     Output a = ops::Const(scope.WithOpName("a"), 0.0f, {10, 10});
721 
722     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
723   }
724 
725   TF_ASSERT_OK_AND_ASSIGN(auto fallback_state,
726                           tensorflow::tfrt_stub::FallbackState::Create({}, {}));
727 
728   TfrtGraphExecutionState::Options options;
729   options.run_placer_grappler_on_functions = false;
730   TF_ASSERT_OK_AND_ASSIGN(
731       auto graph_execution_state,
732       TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));
733 
734   GraphDef extension;
735   {
736     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
737 
738     Output b = ops::Const(scope.WithOpName("b"), 0.0f, {10, 10});
739 
740     TF_ASSERT_OK(scope.ToGraphDef(&extension));
741   }
742 
743   TF_ASSERT_OK(graph_execution_state->Extend(extension));
744 
745   GraphDef expected;
746   {
747     auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");
748 
749     Output a = ops::Const(scope.WithOpName("a"), 0.0f, {10, 10});
750 
751     Output b = ops::Const(scope.WithOpName("b"), 0.0f, {10, 10});
752 
753     TF_ASSERT_OK(scope.ToGraphDef(&expected));
754   }
755 
756   ASSERT_NE(graph_execution_state->original_graph_def(), nullptr);
757   CompareGraphs(expected, *graph_execution_state->original_graph_def());
758 }
759 
760 // An auxiliary struct to verify the graph after partitioning and inserting
761 // transfer ops.
struct GraphInfo {
  // Top-level node whose name matched the requested input name. Not owned.
  NodeDef* input_node = nullptr;
  // Top-level node whose name matched the requested output name. Not owned.
  NodeDef* output_node = nullptr;
  // The StatefulPartitionedCall node found at the top level (the last one, if
  // several). Not owned.
  NodeDef* stateful_partitioned_call_node = nullptr;
  // All top-level PartitionedCall nodes, in graph order. Not owned.
  std::vector<NodeDef*> partitioned_call_nodes;
  // Copies of the FunctionDefs invoked (via the "f" attr) by the
  // PartitionedCall nodes above, in the same order.
  std::vector<FunctionDef> fdefs;
};
769 
770 class InsertTransferOpsTest : public grappler::GrapplerTest {
771  protected:
SetUp()772   void SetUp() override {
773     SessionOptions options;
774     auto* device_count = options.config.mutable_device_count();
775     device_count->insert({"CPU", 2});
776     std::vector<std::unique_ptr<Device>> devices;
777     TF_ASSERT_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
778                                            &devices));
779     device0_ = devices[0].get();
780     device1_ = devices[1].get();
781 
782     fallback_state_ =
783         std::make_unique<FallbackState>(options, std::move(devices), fdef_lib_);
784   }
785 
GetGraphInfo(const std::string & input,const std::string & output,GraphDef & graphdef)786   GraphInfo GetGraphInfo(const std::string& input, const std::string& output,
787                          GraphDef& graphdef) {
788     GraphInfo graph_info;
789     for (NodeDef& node : *graphdef.mutable_node()) {
790       if (node.op() == "PartitionedCall") {
791         graph_info.partitioned_call_nodes.push_back(&node);
792       } else if (node.op() == "StatefulPartitionedCall") {
793         graph_info.stateful_partitioned_call_node = &node;
794       } else if (node.name() == input) {
795         graph_info.input_node = &node;
796       } else if (node.name() == output) {
797         graph_info.output_node = &node;
798       }
799     }
800 
801     // Find the corresponding function called by the PartitionedCall nodes.
802     absl::flat_hash_map<std::string, FunctionDef> func_name_to_func;
803     for (const FunctionDef& fdef : graphdef.library().function()) {
804       func_name_to_func[fdef.signature().name()] = fdef;
805     }
806     for (NodeDef* node : graph_info.partitioned_call_nodes) {
807       CHECK(node->attr().contains("f"));
808       CHECK(func_name_to_func.contains(node->attr().at("f").func().name()));
809       const FunctionDef& fdef =
810           func_name_to_func.at(node->attr().at("f").func().name());
811       graph_info.fdefs.push_back(fdef);
812     }
813     return graph_info;
814   }
815 
816   std::unique_ptr<FallbackState> fallback_state_;
817   Device* device0_ = nullptr;  // Not owned.
818   Device* device1_ = nullptr;  // Not owned.
819   tensorflow::FunctionDefLibrary fdef_lib_;
820 };
821 
TEST_F(InsertTransferOpsTest,InsertTransferOps)822 TEST_F(InsertTransferOpsTest, InsertTransferOps) {
823   GraphDef graphdef;
824   {
825     Scope scope = Scope::NewRootScope();
826     Scope scope1 = scope.WithDevice(device0_->name());
827     Scope scope2 = scope.WithDevice(device1_->name());
828 
829     // A graph whose nodes are on different devices.
830     // a(Const, on device0) -> b(Abs, on device1) -> c(Identity, on device0)
831     Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1});
832     Output b = ops::Abs(scope2.WithOpName("b"), a);
833     Output c = ops::Identity(scope1.WithOpName("c"), b);
834 
835     // Before partitioning, there is no send/recv nodes.
836     int send_count = 0, recv_count = 0;
837     for (const auto* op : scope.graph()->op_nodes()) {
838       if (op->IsSend())
839         ++send_count;
840       else if (op->IsRecv())
841         ++recv_count;
842     }
843     ASSERT_EQ(scope.graph()->num_op_nodes(), 3);
844     ASSERT_EQ(send_count, 0);
845     ASSERT_EQ(recv_count, 0);
846 
847     TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
848   }
849 
850   TfrtGraphExecutionState::Options options;
851   options.run_placer_grappler_on_functions = false;
852   options.enable_tfrt_gpu = true;
853   TF_ASSERT_OK_AND_ASSIGN(
854       auto graph_execution_state,
855       TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_));
856 
857   tensorflow::GraphImportConfig graph_import_config;
858   graph_import_config.prune_unused_nodes = true;
859   graph_import_config.enable_shape_inference = false;
860   tensorflow::ArrayInfo array_info;
861   array_info.imported_dtype = DT_FLOAT;
862   array_info.shape.set_unknown_rank(true);
863   graph_import_config.inputs["a"] = array_info;
864   graph_import_config.outputs = {"c"};
865 
866   TF_ASSERT_OK_AND_ASSIGN(
867       auto optimized_graph,
868       graph_execution_state->CreateOptimizedGraph(graph_import_config));
869 
870   GraphDef new_graphdef;
871   optimized_graph.graph->ToGraphDef(&new_graphdef);
872 
873   GraphInfo graph_info =
874       GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);
875 
876   ASSERT_THAT(graph_info.input_node, NotNull());
877   ASSERT_THAT(graph_info.output_node, NotNull());
878   ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
879   ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());
880 
881   // Verify that each partition contains a _Send op and a _Recv op.
882   for (const FunctionDef& fdef : graph_info.fdefs) {
883     int send_count = 0, recv_count = 0;
884     for (const NodeDef& node : fdef.node_def()) {
885       if (node.op() == "_Send")
886         ++send_count;
887       else if (node.op() == "_Recv")
888         ++recv_count;
889     }
890     EXPECT_EQ(send_count, 1);
891     EXPECT_EQ(recv_count, 1);
892   }
893 }
894 
// Verifies that when the multi-device computation is hidden inside a
// PartitionedCall function (rather than in top-level nodes), partitioning
// still ends up inserting exactly one _Send/_Recv pair per partition.
TEST_F(InsertTransferOpsTest, InsertTransferOpsWithFunctionInlining) {
  GraphDef graphdef;
  {
    Scope scope = Scope::NewRootScope();
    Scope scope1 = scope.WithDevice(device0_->name());
    Scope scope2 = scope.WithDevice(device1_->name());

    // A graph whose nodes are on different devices.
    // a(Const, on device0) -> b(PartitionedCall) -> c(Identity, on device0)
    // where PartitionedCall invokes a function with two nodes assigned to
    // different devices.
    const Tensor kThree = test::AsScalar<float>(3.0);
    // _Pow3(x) = x^3, built so that its Const lives on device0 and its Pow on
    // device1 (the last positional field of each node entry is the device).
    auto fdef = tensorflow::FunctionDefHelper::Create(
        "_Pow3", {"x: float"}, {"y: float"}, {},
        {// The two nodes in the function are assigned to different devices.
         {{"three"},
          "Const",
          {},
          {{"dtype", DT_FLOAT}, {"value", kThree}},
          /*dep=*/{},
          device0_->name()},
         {{"pow3"},
          "Pow",
          {"x", "three:output:0"},
          {{"T", DT_FLOAT}},
          /*dep=*/{},
          device1_->name()}},
        {{"y", "pow3:z:0"}});

    // Register _Pow3 so the PartitionedCall below can reference it.
    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const<float>(scope1.WithOpName("a"), 2.0, {1, 1});

    // Call _Pow3(a) via a PartitionedCall placed on device1.
    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(fdef.signature().name());
    auto pcall = ops::PartitionedCall(scope2, inputs, output_dtypes, func_attr);
    Output b = pcall.output.front();

    Output c = ops::Identity(scope1.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));

    // Before partitioning, there is no send/recv nodes.
    int partitioned_call_count = 0, mul_count = 0, send_count = 0,
        recv_count = 0;
    for (const auto* op : scope.graph()->op_nodes()) {
      if (op->IsPartitionedCall())
        ++partitioned_call_count;
      else if (op->IsSend())
        ++send_count;
      else if (op->IsRecv())
        ++recv_count;
      else if (op->type_string() == "Mul")
        ++mul_count;
    }
    // Exactly one call node and no transfer ops yet. The Mul check presumably
    // guards against the Pow already having been inlined/optimized into
    // multiplies at construction time — TODO confirm intent.
    ASSERT_EQ(partitioned_call_count, 1);
    ASSERT_EQ(send_count, 0);
    ASSERT_EQ(recv_count, 0);
    ASSERT_EQ(mul_count, 0);
  }

  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = false;
  // NOTE(review): enable_tfrt_gpu appears to gate the partitioning path
  // exercised by CreateOptimizedGraph — confirm against its implementation.
  options.enable_tfrt_gpu = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_));

  // Import config: feed "a" (float, unknown rank) and fetch "c".
  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  tensorflow::ArrayInfo array_info;
  array_info.imported_dtype = DT_FLOAT;
  array_info.shape.set_unknown_rank(true);
  graph_import_config.inputs["a"] = array_info;
  graph_import_config.outputs = {"c"};

  TF_ASSERT_OK_AND_ASSIGN(
      auto optimized_graph,
      graph_execution_state->CreateOptimizedGraph(graph_import_config));

  GraphDef new_graphdef;
  optimized_graph.graph->ToGraphDef(&new_graphdef);

  GraphInfo graph_info =
      GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);

  // Feed/fetch nodes survive, and the graph is split into two PartitionedCall
  // partitions plus one StatefulPartitionedCall.
  ASSERT_THAT(graph_info.input_node, NotNull());
  ASSERT_THAT(graph_info.output_node, NotNull());
  ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
  ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());

  // Verify that each partition contains a _Send op and a _Recv op.
  for (const FunctionDef& fdef : graph_info.fdefs) {
    int send_count = 0, recv_count = 0;
    for (const NodeDef& node : fdef.node_def()) {
      if (node.op() == "_Send")
        ++send_count;
      else if (node.op() == "_Recv")
        ++recv_count;
    }
    EXPECT_EQ(send_count, 1);
    EXPECT_EQ(recv_count, 1);
  }
}
1005 
MakeOuterGraph(const FunctionLibraryDefinition & flib_def,const std::string & function_name)1006 std::unique_ptr<Graph> MakeOuterGraph(const FunctionLibraryDefinition& flib_def,
1007                                       const std::string& function_name) {
1008   Scope scope = Scope::NewRootScope().ExitOnError();
1009   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
1010 
1011   auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
1012   auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
1013   auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
1014   auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
1015   auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
1016   auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
1017   auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
1018 
1019   std::vector<tensorflow::NodeDefBuilder::NodeOut> func_inputs;
1020   func_inputs.push_back(
1021       tensorflow::NodeDefBuilder::NodeOut(a.node()->name(), 0, DT_INT32));
1022   func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(b.node()->name(), 0,
1023                                                             b.output.type()));
1024   func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(c.node()->name(), 0,
1025                                                             c.output.type()));
1026   func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(d.node()->name(), 0,
1027                                                             d.output.type()));
1028   func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(u.node()->name(), 0,
1029                                                             u.output.type()));
1030   func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(v.node()->name(), 0,
1031                                                             v.output.type()));
1032   func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(w.node()->name(), 0,
1033                                                             w.output.type()));
1034 
1035   std::vector<DataType> input_dtypes;
1036   for (const NodeDefBuilder::NodeOut& func_input : func_inputs) {
1037     input_dtypes.push_back(func_input.data_type);
1038   }
1039 
1040   std::vector<DataType> output_dtypes = {DT_FLOAT, DT_INT32, DT_FLOAT,
1041                                          DT_FLOAT};
1042 
1043   NameAttrList f;
1044   f.set_name(function_name);
1045 
1046   NodeDef def;
1047   TF_CHECK_OK(NodeDefBuilder("xla_call_0", "StatefulPartitionedCall", &flib_def)
1048                   .Input(func_inputs)
1049                   .Attr("Tin", input_dtypes)
1050                   .Attr("Tout", output_dtypes)
1051                   .Attr("f", f)
1052                   .Device("/gpu:0")
1053                   .Attr(kXlaMustCompileAttr, true)
1054                   .Finalize(&def));
1055 
1056   Status status;
1057   Node* launch = scope.graph()->AddNode(def, &status);
1058   TF_CHECK_OK(status);
1059   TF_CHECK_OK(scope.DoShapeInference(launch));
1060   scope.graph()->AddEdge(a.node(), 0, launch, 0);
1061   scope.graph()->AddEdge(b.node(), 0, launch, 1);
1062   scope.graph()->AddEdge(c.node(), 0, launch, 2);
1063   scope.graph()->AddEdge(d.node(), 0, launch, 3);
1064   scope.graph()->AddEdge(u.node(), 0, launch, 4);
1065   scope.graph()->AddEdge(v.node(), 0, launch, 5);
1066   scope.graph()->AddEdge(w.node(), 0, launch, 6);
1067 
1068   auto consumer0_a =
1069       ops::Identity(scope.WithOpName("consumer0_a"), Output(launch, 0));
1070   auto consumer0_b =
1071       ops::Identity(scope.WithOpName("consumer0_b"), Output(launch, 0));
1072   auto consumer0_c =
1073       ops::Identity(scope.WithOpName("consumer0_c"), Output(launch, 0));
1074   auto consumer1 =
1075       ops::Identity(scope.WithOpName("consumer1"), Output(launch, 1));
1076   auto consumer2 =
1077       ops::Identity(scope.WithOpName("consumer2"), Output(launch, 2));
1078   auto consumer3 =
1079       ops::Identity(scope.WithOpName("consumer3"), Output(launch, 3));
1080 
1081   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1082   TF_CHECK_OK(scope.ToGraph(graph.get()));
1083   return graph;
1084 }
1085 
1086 // Makes an encapsulate body graph for use in tests.
MakeBodyGraph()1087 std::unique_ptr<Graph> MakeBodyGraph() {
1088   Scope scope = Scope::NewRootScope().ExitOnError();
1089 
1090   auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
1091   auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
1092   auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
1093   auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
1094 
1095   auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
1096   auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
1097   auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
1098 
1099   auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
1100   auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
1101   auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
1102   auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
1103 
1104   auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
1105   auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
1106   auto g = ops::Add(scope.WithOpName("G"), f, arg3);
1107 
1108   auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
1109                            b_identity, 0);
1110   auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
1111   auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
1112   auto out3 =
1113       ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
1114 
1115   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1116   TF_CHECK_OK(scope.ToGraph(graph.get()));
1117   return graph;
1118 }
1119 
// BuildXlaLaunchOps must rewrite the must-compile StatefulPartitionedCall from
// MakeOuterGraph into an XlaLaunch node with grouped (constant, arg, resource)
// inputs, leaving the rest of the graph intact.
TEST(BuildXlaOpsTest, BuildXlaLaunchOp) {
  // Register the body graph as function "xla_func_0".
  std::unique_ptr<Graph> body_graph = MakeBodyGraph();
  FunctionDefLibrary fdef_lib;
  TF_ASSERT_OK(
      GraphToFunctionDef(*body_graph, "xla_func_0", fdef_lib.add_function()));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);

  // Build the outer graph and apply the rewrite under test.
  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "xla_func_0");
  TF_ASSERT_OK(BuildXlaLaunchOps(graph.get()));

  // Construct the expected graph: the same placeholders, but with an XlaLaunch
  // node in place of the StatefulPartitionedCall.
  Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);

  NameAttrList function;
  function.set_name("xla_func_0");
  // Inputs are grouped as (constants, args, resources).
  auto launch = ops::XlaLaunch(
      scope.WithOpName("xla_call_0").WithDevice("/gpu:0"),
      std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d},
      std::initializer_list<Input>{u, v, w},
      DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);

  // Mirror the consumers created by MakeOuterGraph, in the same order.
  const std::pair<const char*, int> consumers[] = {
      {"consumer0_a", 0}, {"consumer0_b", 0}, {"consumer0_c", 0},
      {"consumer1", 1},   {"consumer2", 2},   {"consumer3", 3}};
  for (const auto& consumer : consumers) {
    ops::Identity(scope.WithOpName(consumer.first),
                  launch.results[consumer.second]);
  }

  GraphDef expected_def;
  TF_ASSERT_OK(scope.ToGraphDef(&expected_def));

  GraphDef actual_def;
  graph->ToGraphDef(&actual_def);
  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
}
1170 
1171 }  // namespace
1172 }  // namespace tfrt_stub
1173 }  // namespace tensorflow
1174