/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"
16
17 #include <memory>
18 #include <utility>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/cc/framework/scope.h"
23 #include "tensorflow/cc/ops/array_ops.h"
24 #include "tensorflow/cc/ops/const_op.h"
25 #include "tensorflow/cc/ops/function_ops.h"
26 #include "tensorflow/cc/ops/functional_ops.h"
27 #include "tensorflow/cc/ops/math_ops.h"
28 #include "tensorflow/cc/ops/resource_variable_ops.h"
29 #include "tensorflow/cc/ops/sendrecv_ops.h"
30 #include "tensorflow/cc/ops/standard_ops.h"
31 #include "tensorflow/cc/ops/while_loop.h"
32 #include "tensorflow/compiler/jit/defs.h"
33 #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
34 #include "tensorflow/core/framework/attr_value.pb.h"
35 #include "tensorflow/core/framework/device_factory.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/graph.pb.h"
38 #include "tensorflow/core/framework/graph_to_functiondef.h"
39 #include "tensorflow/core/framework/node_def.pb.h"
40 #include "tensorflow/core/framework/node_def_builder.h"
41 #include "tensorflow/core/framework/tensor_testutil.h"
42 #include "tensorflow/core/framework/types.pb.h"
43 #include "tensorflow/core/grappler/utils/grappler_test.h"
44 #include "tensorflow/core/kernels/resource_variable_ops.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
47 #include "tensorflow/core/util/equal_graph_def.h"
48
49 namespace tensorflow {
50 namespace tfrt_stub {
51 namespace {
52
53 using ::testing::_;
54 using ::testing::ElementsAre;
55 using ::testing::EqualsProto;
56 using ::testing::HasSubstr;
57 using ::testing::IsEmpty;
58 using ::testing::NotNull;
59 using ::testing::Pair;
60 using ::testing::SizeIs;
61 using ::testing::proto::IgnoringFieldPaths;
62 using ::testing::proto::IgnoringRepeatedFieldOrdering;
63
// Fixture for the PruneGraphDef / ref-variable-elimination tests below;
// inherits graph-comparison helpers (CompareGraphs) from GrapplerTest.
class PruneGraphDefTest : public grappler::GrapplerTest {};
65
// Verifies that feeding a Const node makes PruneGraphDef drop everything the
// feed no longer needs: the feed's control dependency (and the node it
// pointed at) disappear, leaving only the feed and fetch nodes.
TEST_F(PruneGraphDefTest, ConstFeedWithInput) {
  GraphDef graph_def;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    Output a = ops::Const(scope.WithOpName("a"), 0.0f, {10, 10});
    // "b" carries a control dependency on "a"; pruning should strip it.
    Output b = ops::Const(scope.WithControlDependencies(a).WithOpName("b"),
                          0.0f, {10, 10});
    Output c = ops::Identity(scope.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
  }

  CallableOptions callable_options;
  callable_options.add_feed("b");
  callable_options.add_fetch("c");

  TF_ASSERT_OK(PruneGraphDef(graph_def, callable_options));

  // Expected result: just "b" (without the control dependency) and "c".
  GraphDef want;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    Output b = ops::Const(scope.WithOpName("b"), 0.0f, {10, 10});
    Output c = ops::Identity(scope.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&want));
  }

  CompareGraphs(want, graph_def);
}
98
LessThanTenCond(const Scope & scope,const std::vector<Output> & inputs,Output * output)99 Status LessThanTenCond(const Scope& scope, const std::vector<Output>& inputs,
100 Output* output) {
101 *output = ops::Less(scope, inputs[0], 10);
102 return scope.status();
103 }
104
AddOneBody(const Scope & scope,const std::vector<Output> & inputs,std::vector<Output> * outputs)105 Status AddOneBody(const Scope& scope, const std::vector<Output>& inputs,
106 std::vector<Output>* outputs) {
107 outputs->push_back(ops::AddN(scope, {inputs[0], 1}));
108 return scope.status();
109 }
110
// Verifies that fetching a while-loop Exit node makes PruneGraphDef rename
// the Exit op and insert an Identity carrying the fetch name that forwards
// the renamed Exit's output.
TEST_F(PruneGraphDefTest, InsertIdentityForLoopExitFeed) {
  GraphDef graph_def;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    std::vector<Output> loop_inputs;
    loop_inputs.push_back(
        ops::Placeholder(scope.WithOpName("input"), DT_INT32));
    std::vector<Output> loop_outputs;
    TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), loop_inputs,
                                     LessThanTenCond, AddOneBody, "test_loop",
                                     &loop_outputs));

    TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
  }

  CallableOptions callable_options;
  callable_options.add_feed("input");
  callable_options.add_fetch("while/Exit");

  TF_ASSERT_OK(PruneGraphDef(graph_def, callable_options));

  // After pruning: the Exit op is renamed, and an Identity named after the
  // fetch reads from the renamed Exit.
  for (const auto& node : graph_def.node()) {
    if (node.op() == "Exit") {
      EXPECT_EQ(node.name(), "while/Exit/tfrt_renamed");
    }
    if (node.name() == "while/Exit") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input().size(), 1);
      EXPECT_EQ(node.input(0), "while/Exit/tfrt_renamed");
    }
  }
}
143
// Verifies that EliminateRefVariablesFromV1ControlFlow rewrites RefEnter ops
// into regular Enter ops, inserting a single shared Identity node between the
// ref input and the (formerly Ref) Enter ops; all other nodes are untouched.
TEST_F(PruneGraphDefTest, EliminateRefEntersFromControlFlow) {
  GraphDef graphdef;
  // Maps node name -> pre-rewrite NodeDef, for node-by-node comparison below.
  absl::flat_hash_map<std::string, NodeDef> name_to_node;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    std::vector<Output> inputs;
    inputs.push_back(ops::Placeholder(scope.WithOpName("input"), DT_INT32));
    std::vector<Output> outputs1;
    std::vector<Output> outputs2;
    // Two while loops sharing the same input, so two Enter nodes read "input".
    TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), inputs,
                                     LessThanTenCond, AddOneBody, "test_loop",
                                     &outputs1));
    TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), inputs,
                                     LessThanTenCond, AddOneBody, "test_loop2",
                                     &outputs2));

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));

    // Simply replace Enter with RefEnter. Note this is not valid graph though.
    for (auto& node : *graphdef.mutable_node()) {
      if (node.op() == "Enter") {
        node.set_op("RefEnter");
      }
      name_to_node.insert({node.name(), node});
    }
  }

  TF_ASSERT_OK(EliminateRefVariablesFromV1ControlFlow(graphdef));

  int num_identity_op = 0;
  int num_enter_op = 0;
  int num_ref_enter_op = 0;
  for (const auto& node : graphdef.node()) {
    if (node.op() == "Identity") {
      num_identity_op++;
      // Exactly one Identity is expected, de-reffing "input"; it should carry
      // only the "T" attribute.
      EXPECT_EQ(node.name(), "input/identity");
      ASSERT_EQ(node.input().size(), 1);
      EXPECT_EQ(node.input(0), "input");
      EXPECT_THAT(node.attr(), ElementsAre(Pair("T", _)));
    } else if (node.op() == "RefEnter") {
      num_ref_enter_op++;
    } else if (node.op() == "Enter") {
      // Identity op should be placed before Enter.
      EXPECT_EQ(num_identity_op, 1);
      num_enter_op++;
      // The Enter now reads through the inserted Identity; everything except
      // the "input" and "op" fields must match the original RefEnter node.
      ASSERT_EQ(node.input().size(), 1);
      EXPECT_EQ(node.input(0), "input/identity");
      EXPECT_THAT(
          node, IgnoringFieldPaths({"input", "op"},
                                   EqualsProto(name_to_node.at(node.name()))));
    } else {
      // All remaining nodes must be byte-identical to their originals.
      EXPECT_THAT(node, EqualsProto(name_to_node.at(node.name())));
    }
    name_to_node.erase(node.name());
  }
  EXPECT_EQ(num_identity_op, 1);
  EXPECT_EQ(num_enter_op, 2);
  EXPECT_EQ(num_ref_enter_op, 0);
  // Every original node was visited, i.e. nothing was dropped by the rewrite.
  EXPECT_THAT(name_to_node, IsEmpty());
}
205
// Verifies that EliminateRefVariablesFromV1ControlFlow rewrites RefSwitch ops
// into regular Switch ops, inserting a single shared Identity between the ref
// input and the (formerly Ref) Switch ops; all other nodes are untouched.
TEST_F(PruneGraphDefTest, EliminateRefSwitchesFromControlFlow) {
  GraphDef graphdef;
  // Maps node name -> pre-rewrite NodeDef, for node-by-node comparison below.
  absl::flat_hash_map<std::string, NodeDef> name_to_node;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    Output cond_a = ops::Placeholder(scope.WithOpName("cond_a"), DT_BOOL);
    Output cond_b = ops::Placeholder(scope.WithOpName("cond_b"), DT_BOOL);
    Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);

    // Two Switch nodes sharing the same data input "input".
    ops::Switch switch_a(scope.WithOpName("switch_a"), input, cond_a);
    ops::Switch switch_b(scope.WithOpName("switch_b"), input, cond_b);

    Output switch_a_true =
        ops::Identity(scope.WithOpName("switch_a_true"), switch_a.output_true);
    Output switch_b_true =
        ops::Identity(scope.WithOpName("switch_b_true"), switch_b.output_true);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));

    // Simply replace Switch with RefSwitch. Note this is not valid graph
    // though.
    for (auto& node : *graphdef.mutable_node()) {
      if (node.op() == "Switch") {
        node.set_op("RefSwitch");
      }
      name_to_node.insert({node.name(), node});
    }
  }

  TF_ASSERT_OK(EliminateRefVariablesFromV1ControlFlow(graphdef));

  int num_identity_op = 0;
  int num_switch_op = 0;
  int num_ref_switch_op = 0;
  for (const auto& node : graphdef.node()) {
    if (node.name() == "switch_a_true" || node.name() == "switch_b_true") {
      // The pre-existing Identity consumers must be untouched.
      EXPECT_THAT(node, EqualsProto(name_to_node.at(node.name())));
    } else if (node.op() == "Identity") {
      num_identity_op++;
      // Exactly one Identity is expected, de-reffing "input"; it should carry
      // only the "T" attribute.
      EXPECT_EQ(node.name(), "input/identity");
      ASSERT_EQ(node.input().size(), 1);
      EXPECT_EQ(node.input(0), "input");
      EXPECT_THAT(node.attr(), ElementsAre(Pair("T", _)));
    } else if (node.op() == "RefSwitch") {
      num_ref_switch_op++;
    } else if (node.op() == "Switch") {
      // Identity op should be placed before Switch.
      EXPECT_EQ(num_identity_op, 1);
      num_switch_op++;
      // One of the two inputs (the data input) must now read through the
      // inserted Identity; everything except "input"/"op" must match the
      // original RefSwitch node.
      ASSERT_EQ(node.input().size(), 2);
      EXPECT_TRUE(node.input(0) == "input/identity" ||
                  node.input(1) == "input/identity");
      EXPECT_THAT(
          node, IgnoringFieldPaths({"input", "op"},
                                   EqualsProto(name_to_node.at(node.name()))));
    } else {
      EXPECT_THAT(node, EqualsProto(name_to_node.at(node.name())));
    }
    name_to_node.erase(node.name());
  }
  EXPECT_EQ(num_identity_op, 1);
  EXPECT_EQ(num_switch_op, 2);
  EXPECT_EQ(num_ref_switch_op, 0);
  // Every original node was visited, i.e. nothing was dropped by the rewrite.
  EXPECT_THAT(name_to_node, IsEmpty());
}
272
// Verifies that EliminateRefVariablesFromV1ControlFlow returns an error when
// the graph contains an op (here, Assign) that requires its input to be a
// ref, so the ref chain cannot be broken by inserting an Identity.
//
// Fix: removed the local `name_to_node` map, which was populated in the build
// loop but never read afterwards (this test only checks the error status).
TEST_F(PruneGraphDefTest, EliminateRefVariablesFromV1ControlFlowFailed) {
  GraphDef graphdef;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    Output cond = ops::Placeholder(scope.WithOpName("cond"), DT_BOOL);
    Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);

    ops::Switch switch_op(scope.WithOpName("switch"), input, cond);
    Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
    Output assign =
        ops::Assign(scope.WithOpName("assign"), var, switch_op.output_true);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));

    // Simply replace Switch with RefSwitch. Note this is not valid graph
    // though.
    for (auto& node : *graphdef.mutable_node()) {
      if (node.op() == "Switch") {
        node.set_op("RefSwitch");
      }
    }
  }

  const auto status = EliminateRefVariablesFromV1ControlFlow(graphdef);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("requires its input to be refs"));
}
304
// Verifies that pruning never truncates a while loop: even when the fetch
// reaches only part of the loop, the complete loop structure is retained and
// the graph is left unchanged.
TEST_F(PruneGraphDefTest, KeepLoopStructureComplete) {
  GraphDef graph_def;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    std::vector<Output> loop_inputs;
    loop_inputs.push_back(
        ops::Placeholder(scope.WithOpName("input"), DT_INT32));
    std::vector<Output> loop_outputs;
    TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("while"), loop_inputs,
                                     LessThanTenCond, AddOneBody, "test_loop",
                                     &loop_outputs));

    TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
  }

  CallableOptions callable_options;
  callable_options.add_feed("input");
  // Sets the fetch node such that traversing from there will miss part of the
  // while loop structure.
  callable_options.add_fetch("while/LoopCond");

  GraphDef original_graph_def = graph_def;
  TF_ASSERT_OK(PruneGraphDef(graph_def, callable_options));
  // Nothing may be pruned: the result must equal the input graph.
  EXPECT_THAT(graph_def,
              IgnoringRepeatedFieldOrdering(EqualsProto(original_graph_def)));
}
331
// Fixture for the grappler-on-functions optimization tests below; inherits
// graph/function comparison helpers from GrapplerTest.
class OptimizeGraphTest : public grappler::GrapplerTest {};
333
// Verifies that a function called through PartitionedCall is optimized by
// grappler when `run_placer_grappler_on_functions` is enabled: "x^3" inside
// Pow3 is rewritten to "(x^2)*x", while the top-level graph keeps its
// PartitionedCall structure.
TEST_F(OptimizeGraphTest, OptimizeFunctions) {
  GraphDef graphdef;
  // NOTE(review): this outer `fdef_lib` is shadowed by the one declared in
  // the scope below, so FallbackState::Create receives an empty function
  // library; the Pow3 function travels inside `graphdef`'s own library.
  tensorflow::FunctionDefLibrary fdef_lib;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice(
        "/job:localhost/replica:0/task:0/device:CPU:0");

    const Tensor kThree = test::AsScalar<float>(3.0);
    // Pow3(x) = x^3, implemented as Const(3.0) + Pow.
    auto fdef = tensorflow::FunctionDefHelper::Create(
        "Pow3", {"x: float"}, {"y: float"}, {},
        {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
         {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
        {{"y", "pow3:z:0"}});

    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});

    // Top-level graph: a -> PartitionedCall(Pow3) -> c.
    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(fdef.signature().name());
    auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
    Output b = pcall.output.front();

    Output c = ops::Identity(scope.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      auto fallback_state,
      tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));

  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));

  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  tensorflow::ArrayInfo array_info;
  array_info.imported_dtype = DT_FLOAT;
  array_info.shape.set_unknown_rank(true);
  graph_import_config.inputs["a"] = array_info;
  graph_import_config.outputs = {"c"};

  TF_ASSERT_OK_AND_ASSIGN(
      auto optimized_graph,
      graph_execution_state->CreateOptimizedGraph(graph_import_config));
  GraphDef optimized_graph_def;
  optimized_graph.graph->ToGraphDef(&optimized_graph_def);

  // Expected graph: same top-level structure, but the Pow3 body now contains
  // Square + Mul instead of Const + Pow.
  GraphDef expected;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice(
        "/job:localhost/replica:0/task:0/device:CPU:0");

    const Tensor kThree = test::AsScalar<float>(3.0);
    // After optimization, "x^3" will be transformed to "(x^2)*x".
    auto fdef = tensorflow::FunctionDefHelper::Create(
        "Pow3", {"x: float"}, {"y_retval: float"}, {},
        {{{"ArithmeticOptimizer/ConvertPow__inner_pow3"},
          "Square",
          {"x"},
          {{"dtype", DT_FLOAT}},
          /*dep=*/{},
          "/job:localhost/replica:0/task:0/device:CPU:0"},
         {{"pow3"},
          "Mul",
          {"ArithmeticOptimizer/ConvertPow__inner_pow3:y:0", "x"},
          {{"T", DT_FLOAT}},
          /*dep=*/{},
          "/job:localhost/replica:0/task:0/device:CPU:0"}},
        {{"y_retval", "pow3:z:0"}});

    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});

    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(fdef.signature().name());
    auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
    Output b = pcall.output.front();

    Output c = ops::Identity(scope.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&expected));
  }

  CompareGraphs(expected, optimized_graph_def);
  CompareFunctions(expected.library().function(0),
                   optimized_graph_def.library().function(0));
}
438
// Verifies that a function (`Pow3`) is optimized even when it is reachable
// only through a PartitionedCall node *inside another function* (`Add1Pow3`),
// not directly from the top-level graph.
TEST_F(OptimizeGraphTest, OptimizeFunctionsUsedByFunctionNodes) {
  GraphDef graphdef;
  // NOTE(review): shadowed by the `fdef_lib` declared in the scope below, so
  // FallbackState::Create receives an empty library; the functions travel in
  // `graphdef`'s own library.
  tensorflow::FunctionDefLibrary fdef_lib;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice(
        "/job:localhost/replica:0/task:0/device:CPU:0");

    const Tensor kThree = test::AsScalar<float>(3.0);
    // Pow3(x) = x^3.
    auto pow3_fdef = tensorflow::FunctionDefHelper::Create(
        "Pow3", {"x: float"}, {"y: float"}, {},
        {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
         {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
        {{"y", "pow3:z:0"}});

    const Tensor kOne = test::AsScalar<float>(1.0);
    // Add1Pow3(x) = Pow3(x + 1); Pow3 is invoked via a nested PartitionedCall.
    auto base2pow3_fdef = tensorflow::FunctionDefHelper::Create(
        "Add1Pow3", {"x: float"}, {"y: float"}, {},
        {{{"one"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kOne}}},
         {{"add"}, "Add", {"x", "one:output:0"}, {{"T", DT_FLOAT}}},
         {{"pcall"},
          "PartitionedCall",
          {"add:z:0"},
          {{"Tin", DataTypeSlice({DT_FLOAT})},
           {"Tout", DataTypeSlice({DT_FLOAT})},
           {"f", tensorflow::FunctionDefHelper::FunctionRef(
                     "Pow3", {{"T", DT_FLOAT}})}}}},
        {{"y", "pcall:output:0"}});

    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = pow3_fdef;
    *fdef_lib.add_function() = base2pow3_fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const(scope.WithOpName("a"), 1.0, {1, 1});

    // Top-level graph only calls Add1Pow3 directly.
    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        base2pow3_fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(base2pow3_fdef.signature().name());
    auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
    Output b = pcall.output.front();

    Output c = ops::Identity(scope.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      auto fallback_state,
      tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));

  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));

  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  tensorflow::ArrayInfo array_info;
  array_info.imported_dtype = DT_FLOAT;
  array_info.shape.set_unknown_rank(true);
  graph_import_config.inputs["a"] = array_info;
  graph_import_config.outputs = {"c"};

  TF_ASSERT_OK_AND_ASSIGN(
      auto optimized_graph,
      graph_execution_state->CreateOptimizedGraph(graph_import_config));
  GraphDef optimized_graph_def;
  optimized_graph.graph->ToGraphDef(&optimized_graph_def);

  // Expected library: Pow3 optimized (Square + Mul), Add1Pow3 unchanged.
  GraphDef expected;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice(
        "/job:localhost/replica:0/task:0/device:CPU:0");

    const Tensor kThree = test::AsScalar<float>(3.0);
    // After optimization, "x^3" will be transformed to "(x^2)*x".
    auto pow3_fdef = tensorflow::FunctionDefHelper::Create(
        "Pow3", {"x: float"}, {"y_retval: float"}, {},
        {{{"ArithmeticOptimizer/ConvertPow__inner_pow3"},
          "Square",
          {"x"},
          {{"dtype", DT_FLOAT}},
          /*dep=*/{},
          "/job:localhost/replica:0/task:0/device:CPU:0"},
         {{"pow3"},
          "Mul",
          {"ArithmeticOptimizer/ConvertPow__inner_pow3:y:0", "x"},
          {{"T", DT_FLOAT}},
          /*dep=*/{},
          "/job:localhost/replica:0/task:0/device:CPU:0"}},
        {{"y_retval", "pow3:z:0"}});

    const Tensor kOne = test::AsScalar<float>(1.0);
    auto base2pow3_fdef = tensorflow::FunctionDefHelper::Create(
        "Add1Pow3", {"x: float"}, {"y: float"}, {},
        {{{"one"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kOne}}},
         {{"add"}, "Add", {"x", "one:output:0"}, {{"T", DT_FLOAT}}},
         {{"pcall"},
          "PartitionedCall",
          {"add:z:0"},
          {{"Tin", DataTypeSlice({DT_FLOAT})},
           {"Tout", DataTypeSlice({DT_FLOAT})},
           {"f", tensorflow::FunctionDefHelper::FunctionRef(
                     "Pow3", {{"T", DT_FLOAT}})}}}},
        {{"y", "pcall:output:0"}});

    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = pow3_fdef;
    *fdef_lib.add_function() = base2pow3_fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const(scope.WithOpName("a"), 1.0, {1, 1});

    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        base2pow3_fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(base2pow3_fdef.signature().name());
    auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
    Output b = pcall.output.front();

    Output c = ops::Identity(scope.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&expected));
  }

  // Since `Pow3` is called by `Add1Pow3`, it is optimized.
  CompareFunctions(expected.library().function(1),
                   optimized_graph_def.library().function(1));
  ASSERT_EQ("Pow3",
            optimized_graph_def.library().function(1).signature().name());
}
575
// Verifies that a function referenced only by an `If` op is NOT optimized:
// the optimized graph and function body equal the originals.
TEST_F(OptimizeGraphTest, DontOptimizeUnsafeFunction) {
  GraphDef graphdef;
  // NOTE(review): shadowed by the `fdef_lib` declared in the scope below, so
  // FallbackState::Create receives an empty library; the function travels in
  // `graphdef`'s own library.
  tensorflow::FunctionDefLibrary fdef_lib;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice(
        "/job:localhost/replica:0/task:0/device:CPU:0");

    const Tensor kThree = test::AsScalar<float>(3.0);
    // Pow3(x) = x^3.
    auto fdef = tensorflow::FunctionDefHelper::Create(
        "Pow3", {"x: float"}, {"y: float"}, {},
        {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
         {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
        {{"y", "pow3:z:0"}});

    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});

    // Unlike OptimizeFunctions above, invoke Pow3 as both branches of an If
    // op rather than through a PartitionedCall.
    Output cond = ops::Const(scope.WithOpName("cond"), true, {1, 1});
    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(fdef.signature().name());
    auto if_op =
        ops::If(scope, cond, inputs, output_dtypes, func_attr, func_attr);
    Output b = if_op.output.front();

    Output c = ops::Identity(scope.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      auto fallback_state,
      tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));

  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));

  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  tensorflow::ArrayInfo array_info;
  array_info.imported_dtype = DT_FLOAT;
  array_info.shape.set_unknown_rank(true);
  graph_import_config.inputs["a"] = array_info;
  graph_import_config.outputs = {"c"};

  TF_ASSERT_OK_AND_ASSIGN(
      auto optimized_graph,
      graph_execution_state->CreateOptimizedGraph(graph_import_config));
  GraphDef optimized_graph_def;
  optimized_graph.graph->ToGraphDef(&optimized_graph_def);

  // The optimized graph remains the same as the original one, because the
  // function used by `If` op is not optimized.
  CompareGraphs(graphdef, optimized_graph_def);
  CompareFunctions(graphdef.library().function(0),
                   optimized_graph_def.library().function(0));
}
642
// Verifies that sharing one function between a PartitionedCall (normally
// optimizable) and an `If` op (not optimizable) makes the function unsafe to
// optimize everywhere: its body stays identical to the original.
TEST_F(OptimizeGraphTest, FunctionBecomeUnsafeIfAnyOpIsUnsafe) {
  GraphDef graphdef;
  // NOTE(review): shadowed by the `fdef_lib` declared in the scope below, so
  // FallbackState::Create receives an empty library; the function travels in
  // `graphdef`'s own library.
  tensorflow::FunctionDefLibrary fdef_lib;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice(
        "/job:localhost/replica:0/task:0/device:CPU:0");

    const Tensor kThree = test::AsScalar<float>(3.0);
    // Pow3(x) = x^3.
    auto fdef = tensorflow::FunctionDefHelper::Create(
        "Pow3", {"x: float"}, {"y: float"}, {},
        {{{"three"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", kThree}}},
         {{"pow3"}, "Pow", {"x", "three:output:0"}, {{"T", DT_FLOAT}}}},
        {{"y", "pow3:z:0"}});

    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const(scope.WithOpName("a"), 2.0, {1, 1});

    // First use of Pow3: both branches of an If op.
    Output cond = ops::Const(scope.WithOpName("cond"), true, {1, 1});
    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(fdef.signature().name());
    auto if_op =
        ops::If(scope, cond, inputs, output_dtypes, func_attr, func_attr);
    Output b = if_op.output.front();

    // Second use of the same function: a PartitionedCall.
    inputs = {b};
    auto pcall = ops::PartitionedCall(scope, inputs, output_dtypes, func_attr);
    Output c = pcall.output.front();

    Output d = ops::Identity(scope.WithOpName("d"), c);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      auto fallback_state,
      tensorflow::tfrt_stub::FallbackState::Create({}, fdef_lib));

  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state));

  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  tensorflow::ArrayInfo array_info;
  array_info.imported_dtype = DT_FLOAT;
  array_info.shape.set_unknown_rank(true);
  graph_import_config.inputs["a"] = array_info;
  graph_import_config.outputs = {"d"};

  TF_ASSERT_OK_AND_ASSIGN(
      auto optimized_graph,
      graph_execution_state->CreateOptimizedGraph(graph_import_config));
  GraphDef optimized_graph_def;
  optimized_graph.graph->ToGraphDef(&optimized_graph_def);

  // Both `If` and `PartitionedCall` ops use the function, so the function
  // remains unoptimized.
  CompareFunctions(graphdef.library().function(0),
                   optimized_graph_def.library().function(0));
}
712
// Fixture for the TfrtGraphExecutionState::Extend tests below.
class ExtendGraphTest : public grappler::GrapplerTest {};
714
// Verifies that Extend() merges a new GraphDef into the graph the execution
// state was created with, and that original_graph_def() reflects the union.
TEST_F(ExtendGraphTest, ExtendGraph) {
  GraphDef initial_graph;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    Output a = ops::Const(scope.WithOpName("a"), 0.0f, {10, 10});

    TF_ASSERT_OK(scope.ToGraphDef(&initial_graph));
  }

  TF_ASSERT_OK_AND_ASSIGN(auto fallback_state,
                          tensorflow::tfrt_stub::FallbackState::Create({}, {}));

  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = false;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, initial_graph, *fallback_state));

  // Build a second graph holding only "b" and extend the state with it.
  GraphDef extension;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    Output b = ops::Const(scope.WithOpName("b"), 0.0f, {10, 10});

    TF_ASSERT_OK(scope.ToGraphDef(&extension));
  }

  TF_ASSERT_OK(graph_execution_state->Extend(extension));

  // The recorded original graph must now contain both "a" and "b".
  GraphDef want;
  {
    auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0");

    Output a = ops::Const(scope.WithOpName("a"), 0.0f, {10, 10});
    Output b = ops::Const(scope.WithOpName("b"), 0.0f, {10, 10});

    TF_ASSERT_OK(scope.ToGraphDef(&want));
  }

  ASSERT_NE(graph_execution_state->original_graph_def(), nullptr);
  CompareGraphs(want, *graph_execution_state->original_graph_def());
}
759
// An auxiliary struct to verify the graph after partitioning and inserting
// transfer ops. Populated by InsertTransferOpsTest::GetGraphInfo below.
struct GraphInfo {
  NodeDef* input_node = nullptr;   // Node whose name matches the feed.
  NodeDef* output_node = nullptr;  // Node whose name matches the fetch.
  // The (single) StatefulPartitionedCall node found, if any.
  NodeDef* stateful_partitioned_call_node = nullptr;
  // All PartitionedCall nodes found in the graph.
  std::vector<NodeDef*> partitioned_call_nodes;
  // Functions invoked by `partitioned_call_nodes` (via their "f" attr), in
  // the same order.
  std::vector<FunctionDef> fdefs;
};
769
// Fixture that registers two CPU devices and builds a FallbackState owning
// them, for the transfer-op insertion tests below.
class InsertTransferOpsTest : public grappler::GrapplerTest {
 protected:
  void SetUp() override {
    SessionOptions options;
    auto* device_count = options.config.mutable_device_count();
    device_count->insert({"CPU", 2});
    std::vector<std::unique_ptr<Device>> devices;
    TF_ASSERT_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
                                           &devices));
    // Keep raw observers; ownership of the devices moves into
    // fallback_state_ below.
    device0_ = devices[0].get();
    device1_ = devices[1].get();

    fallback_state_ =
        std::make_unique<FallbackState>(options, std::move(devices), fdef_lib_);
  }

  // Scans `graphdef` and collects: the nodes named `input`/`output`, the
  // StatefulPartitionedCall node, all PartitionedCall nodes, and the
  // FunctionDefs those PartitionedCall nodes invoke via their "f" attr.
  // The returned pointers alias `graphdef`, which must outlive the result.
  GraphInfo GetGraphInfo(const std::string& input, const std::string& output,
                         GraphDef& graphdef) {
    GraphInfo graph_info;
    for (NodeDef& node : *graphdef.mutable_node()) {
      if (node.op() == "PartitionedCall") {
        graph_info.partitioned_call_nodes.push_back(&node);
      } else if (node.op() == "StatefulPartitionedCall") {
        graph_info.stateful_partitioned_call_node = &node;
      } else if (node.name() == input) {
        graph_info.input_node = &node;
      } else if (node.name() == output) {
        graph_info.output_node = &node;
      }
    }

    // Find the corresponding function called by the PartitionedCall nodes.
    absl::flat_hash_map<std::string, FunctionDef> func_name_to_func;
    for (const FunctionDef& fdef : graphdef.library().function()) {
      func_name_to_func[fdef.signature().name()] = fdef;
    }
    for (NodeDef* node : graph_info.partitioned_call_nodes) {
      // Crash-on-failure is acceptable in this test-only helper.
      CHECK(node->attr().contains("f"));
      CHECK(func_name_to_func.contains(node->attr().at("f").func().name()));
      const FunctionDef& fdef =
          func_name_to_func.at(node->attr().at("f").func().name());
      graph_info.fdefs.push_back(fdef);
    }
    return graph_info;
  }

  std::unique_ptr<FallbackState> fallback_state_;
  Device* device0_ = nullptr;  // Not owned.
  Device* device1_ = nullptr;  // Not owned.
  tensorflow::FunctionDefLibrary fdef_lib_;
};
821
// Verifies that after partitioning a graph whose nodes span two devices, each
// per-device partition function contains exactly one _Send and one _Recv
// transfer op.
TEST_F(InsertTransferOpsTest, InsertTransferOps) {
  GraphDef graphdef;
  {
    Scope scope = Scope::NewRootScope();
    Scope scope1 = scope.WithDevice(device0_->name());
    Scope scope2 = scope.WithDevice(device1_->name());

    // A graph whose nodes are on different devices.
    // a(Const, on device0) -> b(Abs, on device1) -> c(Identity, on device0)
    Output a = ops::Const(scope1.WithOpName("a"), 2.0, {1, 1});
    Output b = ops::Abs(scope2.WithOpName("b"), a);
    Output c = ops::Identity(scope1.WithOpName("c"), b);

    // Before partitioning, there is no send/recv nodes.
    int send_count = 0, recv_count = 0;
    for (const auto* op : scope.graph()->op_nodes()) {
      if (op->IsSend())
        ++send_count;
      else if (op->IsRecv())
        ++recv_count;
    }
    ASSERT_EQ(scope.graph()->num_op_nodes(), 3);
    ASSERT_EQ(send_count, 0);
    ASSERT_EQ(recv_count, 0);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));
  }

  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = false;
  // NOTE(review): enable_tfrt_gpu appears to be what triggers the
  // partitioning / transfer-op insertion path exercised here — confirm
  // against TfrtGraphExecutionState::CreateOptimizedGraph.
  options.enable_tfrt_gpu = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_));

  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  tensorflow::ArrayInfo array_info;
  array_info.imported_dtype = DT_FLOAT;
  array_info.shape.set_unknown_rank(true);
  graph_import_config.inputs["a"] = array_info;
  graph_import_config.outputs = {"c"};

  TF_ASSERT_OK_AND_ASSIGN(
      auto optimized_graph,
      graph_execution_state->CreateOptimizedGraph(graph_import_config));

  GraphDef new_graphdef;
  optimized_graph.graph->ToGraphDef(&new_graphdef);

  GraphInfo graph_info =
      GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);

  // Two devices -> two PartitionedCall partitions plus one
  // StatefulPartitionedCall node.
  ASSERT_THAT(graph_info.input_node, NotNull());
  ASSERT_THAT(graph_info.output_node, NotNull());
  ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
  ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());

  // Verify that each partition contains a _Send op and a _Recv op.
  for (const FunctionDef& fdef : graph_info.fdefs) {
    int send_count = 0, recv_count = 0;
    for (const NodeDef& node : fdef.node_def()) {
      if (node.op() == "_Send")
        ++send_count;
      else if (node.op() == "_Recv")
        ++recv_count;
    }
    EXPECT_EQ(send_count, 1);
    EXPECT_EQ(recv_count, 1);
  }
}
894
// Verifies that transfer ops are inserted even when the cross-device nodes
// live inside a function invoked via PartitionedCall: the function is inlined
// during optimization and each resulting device partition gets one _Send and
// one _Recv.
TEST_F(InsertTransferOpsTest, InsertTransferOpsWithFunctionInlining) {
  GraphDef graphdef;
  {
    Scope scope = Scope::NewRootScope();
    Scope scope1 = scope.WithDevice(device0_->name());
    Scope scope2 = scope.WithDevice(device1_->name());

    // A graph whose nodes are on different devices.
    // a(Const, on device0) -> b(PartitionedCall) -> c(Identity, on device0)
    // where PartitionedCall invokes a function with two nodes assigned to
    // different devices.
    const Tensor kThree = test::AsScalar<float>(3.0);
    auto fdef = tensorflow::FunctionDefHelper::Create(
        "_Pow3", {"x: float"}, {"y: float"}, {},
        {// The two nodes in the function are assigned to different devices.
         {{"three"},
          "Const",
          {},
          {{"dtype", DT_FLOAT}, {"value", kThree}},
          /*dep=*/{},
          device0_->name()},
         {{"pow3"},
          "Pow",
          {"x", "three:output:0"},
          {{"T", DT_FLOAT}},
          /*dep=*/{},
          device1_->name()}},
        {{"y", "pow3:z:0"}});

    tensorflow::FunctionDefLibrary fdef_lib;
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));

    Output a = ops::Const<float>(scope1.WithOpName("a"), 2.0, {1, 1});

    // Call _Pow3 through a PartitionedCall node placed on device1.
    std::vector<tensorflow::Output> inputs = {a};
    std::vector<tensorflow::DataType> output_dtypes = {
        fdef.signature().output_arg(0).type()};
    tensorflow::NameAttrList func_attr;
    func_attr.set_name(fdef.signature().name());
    auto pcall = ops::PartitionedCall(scope2, inputs, output_dtypes, func_attr);
    Output b = pcall.output.front();

    Output c = ops::Identity(scope1.WithOpName("c"), b);

    TF_ASSERT_OK(scope.ToGraphDef(&graphdef));

    // Before partitioning, there are no send/recv nodes, and the function
    // body has not been inlined: the "Pow" node defined inside _Pow3 must not
    // appear in the top-level graph yet. (This previously counted "Mul"
    // nodes, a leftover from an earlier version of the test function, which
    // made the check vacuous since no op here ever uses Mul.)
    int partitioned_call_count = 0, pow_count = 0, send_count = 0,
        recv_count = 0;
    for (const auto* op : scope.graph()->op_nodes()) {
      if (op->IsPartitionedCall())
        ++partitioned_call_count;
      else if (op->IsSend())
        ++send_count;
      else if (op->IsRecv())
        ++recv_count;
      else if (op->type_string() == "Pow")
        ++pow_count;
    }
    ASSERT_EQ(partitioned_call_count, 1);
    ASSERT_EQ(send_count, 0);
    ASSERT_EQ(recv_count, 0);
    ASSERT_EQ(pow_count, 0);
  }

  // Enable TFRT GPU support so the optimized graph is partitioned per device.
  TfrtGraphExecutionState::Options options;
  options.run_placer_grappler_on_functions = false;
  options.enable_tfrt_gpu = true;
  TF_ASSERT_OK_AND_ASSIGN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(options, graphdef, *fallback_state_));

  // Import config: feed "a" (unknown-rank float), fetch "c".
  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  tensorflow::ArrayInfo array_info;
  array_info.imported_dtype = DT_FLOAT;
  array_info.shape.set_unknown_rank(true);
  graph_import_config.inputs["a"] = array_info;
  graph_import_config.outputs = {"c"};

  TF_ASSERT_OK_AND_ASSIGN(
      auto optimized_graph,
      graph_execution_state->CreateOptimizedGraph(graph_import_config));

  GraphDef new_graphdef;
  optimized_graph.graph->ToGraphDef(&new_graphdef);

  GraphInfo graph_info =
      GetGraphInfo(/*input=*/"a", /*output=*/"c", new_graphdef);

  ASSERT_THAT(graph_info.input_node, NotNull());
  ASSERT_THAT(graph_info.output_node, NotNull());
  ASSERT_THAT(graph_info.partitioned_call_nodes, SizeIs(2));
  ASSERT_THAT(graph_info.stateful_partitioned_call_node, NotNull());

  // Verify that each partition contains a _Send op and a _Recv op.
  for (const FunctionDef& fdef : graph_info.fdefs) {
    int send_count = 0, recv_count = 0;
    for (const NodeDef& node : fdef.node_def()) {
      if (node.op() == "_Send")
        ++send_count;
      else if (node.op() == "_Recv")
        ++recv_count;
    }
    EXPECT_EQ(send_count, 1);
    EXPECT_EQ(recv_count, 1);
  }
}
1005
// Builds the "outer" test graph for BuildXlaOpsTest: seven Placeholder inputs
// (data inputs A-D and resource inputs U-W) feed a hand-built
// StatefulPartitionedCall node named "xla_call_0" that invokes
// `function_name` from `flib_def`; each of the call's four outputs is
// consumed by Identity nodes. The call node is placed on "/gpu:0" and tagged
// with kXlaMustCompileAttr so that BuildXlaLaunchOps will rewrite it.
std::unique_ptr<Graph> MakeOuterGraph(const FunctionLibraryDefinition& flib_def,
                                      const std::string& function_name) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));

  // Data inputs (positions 0-3 of the call node).
  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
  // Resource inputs (positions 4-6).
  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);

  // Inputs of the call node, in positional order.
  std::vector<tensorflow::NodeDefBuilder::NodeOut> func_inputs;
  func_inputs.push_back(
      tensorflow::NodeDefBuilder::NodeOut(a.node()->name(), 0, DT_INT32));
  func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(b.node()->name(), 0,
                                                            b.output.type()));
  func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(c.node()->name(), 0,
                                                            c.output.type()));
  func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(d.node()->name(), 0,
                                                            d.output.type()));
  func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(u.node()->name(), 0,
                                                            u.output.type()));
  func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(v.node()->name(), 0,
                                                            v.output.type()));
  func_inputs.push_back(tensorflow::NodeDefBuilder::NodeOut(w.node()->name(), 0,
                                                            w.output.type()));

  std::vector<DataType> input_dtypes;
  for (const NodeDefBuilder::NodeOut& func_input : func_inputs) {
    input_dtypes.push_back(func_input.data_type);
  }

  std::vector<DataType> output_dtypes = {DT_FLOAT, DT_INT32, DT_FLOAT,
                                         DT_FLOAT};

  NameAttrList f;
  f.set_name(function_name);

  // The call node is built manually (instead of via an ops:: wrapper) so the
  // device assignment and the must-compile attribute can be attached.
  NodeDef def;
  TF_CHECK_OK(NodeDefBuilder("xla_call_0", "StatefulPartitionedCall", &flib_def)
                  .Input(func_inputs)
                  .Attr("Tin", input_dtypes)
                  .Attr("Tout", output_dtypes)
                  .Attr("f", f)
                  .Device("/gpu:0")
                  .Attr(kXlaMustCompileAttr, true)
                  .Finalize(&def));

  Status status;
  Node* launch = scope.graph()->AddNode(def, &status);
  TF_CHECK_OK(status);
  TF_CHECK_OK(scope.DoShapeInference(launch));
  // Wire the placeholders to the call node's seven inputs, positionally.
  scope.graph()->AddEdge(a.node(), 0, launch, 0);
  scope.graph()->AddEdge(b.node(), 0, launch, 1);
  scope.graph()->AddEdge(c.node(), 0, launch, 2);
  scope.graph()->AddEdge(d.node(), 0, launch, 3);
  scope.graph()->AddEdge(u.node(), 0, launch, 4);
  scope.graph()->AddEdge(v.node(), 0, launch, 5);
  scope.graph()->AddEdge(w.node(), 0, launch, 6);

  // Output 0 is consumed three times; outputs 1-3 once each.
  auto consumer0_a =
      ops::Identity(scope.WithOpName("consumer0_a"), Output(launch, 0));
  auto consumer0_b =
      ops::Identity(scope.WithOpName("consumer0_b"), Output(launch, 0));
  auto consumer0_c =
      ops::Identity(scope.WithOpName("consumer0_c"), Output(launch, 0));
  auto consumer1 =
      ops::Identity(scope.WithOpName("consumer1"), Output(launch, 1));
  auto consumer2 =
      ops::Identity(scope.WithOpName("consumer2"), Output(launch, 2));
  auto consumer3 =
      ops::Identity(scope.WithOpName("consumer3"), Output(launch, 3));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(scope.ToGraph(graph.get()));
  return graph;
}
1085
// Makes an encapsulate body graph for use in tests: the function body invoked
// by the call node built in MakeOuterGraph. It takes four data args and three
// resource args, reads the resource variables, and produces four retvals.
std::unique_ptr<Graph> MakeBodyGraph() {
  Scope scope = Scope::NewRootScope().ExitOnError();

  // Data arguments (indices 0-3), matching the outer graph's A-D inputs.
  auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
  auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
  auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
  auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);

  // Resource arguments (indices 4-6), matching the outer graph's U-W inputs.
  auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
  auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
  auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);

  auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
  auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
  auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
  auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);

  auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
  auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
  auto g = ops::Add(scope.WithOpName("G"), f, arg3);

  // Four return values: B's identity, E, G, and the value read from U.
  auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
                           b_identity, 0);
  auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
  auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
  auto out3 =
      ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(scope.ToGraph(graph.get()));
  return graph;
}
1119
// Verifies that BuildXlaLaunchOps rewrites the must-compile
// StatefulPartitionedCall node built by MakeOuterGraph into an XlaLaunch op,
// by comparing the rewritten graph against an expected graph constructed
// directly with ops::XlaLaunch.
TEST(BuildXlaOpsTest, BuildXlaLaunchOp) {
  std::unique_ptr<Graph> body_graph = MakeBodyGraph();
  FunctionDefLibrary flib;
  TF_ASSERT_OK(
      GraphToFunctionDef(*body_graph, "xla_func_0", flib.add_function()));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  // Build the input graph and apply the rewrite under test.
  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "xla_func_0");
  TF_ASSERT_OK(BuildXlaLaunchOps(graph.get()));

  // Construct the expected graph: identical placeholders and consumers, but
  // with an XlaLaunch node in place of the StatefulPartitionedCall. Shape
  // inference is disabled since XlaLaunch outputs have no inferred shapes.
  Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));

  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);

  NameAttrList function;
  function.set_name("xla_func_0");
  // XlaLaunch takes (constants, data args, resource args) and keeps the
  // original node name and device placement.
  auto launch = ops::XlaLaunch(
      scope.WithOpName("xla_call_0").WithDevice("/gpu:0"),
      std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d},
      std::initializer_list<Input>{u, v, w},
      DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);

  auto consumer0_a =
      ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
  auto consumer0_b =
      ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
  auto consumer0_c =
      ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
  auto consumer1 =
      ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
  auto consumer2 =
      ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
  auto consumer3 =
      ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);

  GraphDef expected_def;
  TF_ASSERT_OK(scope.ToGraphDef(&expected_def));

  GraphDef actual_def;
  graph->ToGraphDef(&actual_def);
  // Structural comparison of the rewritten graph against the expected one.
  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
}
1170
1171 } // namespace
1172 } // namespace tfrt_stub
1173 } // namespace tensorflow
1174