/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"

#include "absl/strings/match.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
37
38 namespace tensorflow {
39
TEST(RewriteOutsideCompilationSubgraphFnTest,Basic)40 TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) {
41 // Build the graph:
42 // "add" = "arg0" + "arg1"
43 // "ret0" = "add"
44 // "ret1" = "arg1"
45 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
46 Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_INT32, 0);
47 Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_FLOAT, 1);
48 Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
49 Output add = ops::Add(s.WithOpName("add"), arg0, arg0);
50 auto ret0 = ops::_Retval(s.WithOpName("ret0"), add, 0);
51 auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 1);
52 std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
53 TF_CHECK_OK(s.ToGraph(g.get()));
54 auto node_name_image = g->BuildNodeNameIndex();
55 Node *add_node = node_name_image["add"];
56 EXPECT_NE(add_node, nullptr);
57 add_node->AddAttr(kXlaConnectedToXlaComputationAttrName, "cluster");
58 add_node->AddAttr(kXlaConnectedFromXlaComputationAttrName, "cluster");
59
60 RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
61 std::vector<OutputTensor> arg_source_tensors;
62 NodeDef call_node_def;
63 call_node_def.set_op("0");
64 TF_CHECK_OK(
65 rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
66 node_name_image = g->BuildNodeNameIndex();
67
68 // Verify step 1: add key placeholder node.
69 Node *key_placeholder = node_name_image["cluster_key_placeholder"];
70 EXPECT_NE(key_placeholder, nullptr);
71 // Verify step 2: replace _Arg nodes with XlaRecvAtHost.
72 for (Node *n : g->nodes()) {
73 EXPECT_NE(n->type_string(), "_Arg");
74 }
75 Node *recv_at_host = node_name_image["outside_compilation_cluster__0_recv"];
76 EXPECT_NE(recv_at_host, nullptr);
77 std::vector<DataType> recv_at_host_dtypes;
78 TF_CHECK_OK(
79 GetNodeAttr(recv_at_host->attrs(), "Toutputs", &recv_at_host_dtypes));
80 EXPECT_EQ(recv_at_host_dtypes.size(), 3);
81 EXPECT_EQ(recv_at_host_dtypes[0], DT_INT32);
82 EXPECT_EQ(recv_at_host_dtypes[1], DT_FLOAT);
83 EXPECT_EQ(recv_at_host_dtypes[2], DT_INT32);
84 // Verify step 3: replace _Retval nodes with XlaSendFromHost.
85 for (Node *n : g->nodes()) {
86 EXPECT_NE(n->type_string(), "_Retval");
87 }
88 Node *send_from_host = node_name_image["outside_compilation_cluster__0_send"];
89 EXPECT_NE(send_from_host, nullptr);
90 std::vector<DataType> send_from_host_dtypes;
91 TF_CHECK_OK(
92 GetNodeAttr(send_from_host->attrs(), "Tinputs", &send_from_host_dtypes));
93 EXPECT_EQ(send_from_host_dtypes.size(), 2);
94 EXPECT_EQ(send_from_host_dtypes[0], DT_INT32);
95 EXPECT_EQ(send_from_host_dtypes[1], DT_FLOAT);
96 // Verify step 4: nodes marked with XLA cluster and outside compilation attr.
97 add_node = node_name_image["add"];
98 EXPECT_NE(add_node, nullptr);
99 EXPECT_TRUE(HasNodeAttr(add_node->def(), "_xla"));
100 EXPECT_TRUE(HasNodeAttr(add_node->def(), "_oc"));
101 // Verify step 5: control edges added.
102 bool has_control_edge_from_recv_at_host = false;
103 for (auto e : add_node->in_edges()) {
104 if (e->IsControlEdge() && e->src() == recv_at_host) {
105 has_control_edge_from_recv_at_host = true;
106 }
107 }
108 EXPECT_TRUE(has_control_edge_from_recv_at_host);
109 bool has_control_edge_to_send_from_host = false;
110 for (auto e : add_node->out_edges()) {
111 if (e->IsControlEdge() && e->dst() == send_from_host) {
112 has_control_edge_to_send_from_host = true;
113 }
114 }
115 EXPECT_TRUE(has_control_edge_to_send_from_host);
116 // Verify step 7: necessary attrs added to call_node_def.
117 NameAttrList shape_inference_graph;
118 TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()),
119 "shape_inference_graph", &shape_inference_graph));
120 EXPECT_EQ(shape_inference_graph.name(),
121 "_outside_compilation_shape_inference_cluster__0");
122 }
123
// A graph with only an _Arg (no _Retval) should produce a key placeholder and
// a RecvAtHost node, but no SendFromHost node (there is nothing to send back).
TEST(RewriteOutsideCompilationSubgraphFnTest, NoSendFromHost) {
  // Build the graph: only 1 node: "arg0"
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_INT32, 0);
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(s.ToGraph(g.get()));

  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  std::vector<OutputTensor> arg_source_tensors;
  NodeDef call_node_def;
  call_node_def.set_op("0");
  TF_CHECK_OK(
      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
  auto node_name_image = g->BuildNodeNameIndex();

  // Check key placeholder and RecvAtHost is present, but SendFromHost is not.
  Node *key_placeholder = node_name_image["cluster_key_placeholder"];
  EXPECT_NE(key_placeholder, nullptr);
  Node *recv_at_host = node_name_image["outside_compilation_cluster__0_recv"];
  EXPECT_NE(recv_at_host, nullptr);
  Node *send_from_host = node_name_image["outside_compilation_cluster__0_send"];
  EXPECT_EQ(send_from_host, nullptr);
}
147
// A graph with only a _Retval (no _Arg) should produce a key placeholder and
// a SendFromHost node, but no RecvAtHost node (there is nothing to receive).
TEST(RewriteOutsideCompilationSubgraphFnTest, NoRecvAtHost) {
  // Build the graph:
  // "ret" = "const0"
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
  auto ret = ops::_Retval(s.WithOpName("ret"), const0, 0);
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(s.ToGraph(g.get()));

  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  std::vector<OutputTensor> arg_source_tensors;
  NodeDef call_node_def;
  call_node_def.set_op("0");
  TF_CHECK_OK(
      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
  auto node_name_image = g->BuildNodeNameIndex();

  // Check key placeholder and SendFromHost is present, but RecvAtHost is not.
  Node *key_placeholder = node_name_image["cluster_key_placeholder"];
  EXPECT_NE(key_placeholder, nullptr);
  Node *recv_at_host = node_name_image["outside_compilation_cluster__0_recv"];
  EXPECT_EQ(recv_at_host, nullptr);
  Node *send_from_host = node_name_image["outside_compilation_cluster__0_send"];
  EXPECT_NE(send_from_host, nullptr);
}
173
// A graph with neither _Arg nor _Retval nodes exchanges no data with the XLA
// computation, so no key placeholder, RecvAtHost, or SendFromHost is created.
TEST(RewriteOutsideCompilationSubgraphFnTest, NoKeyPlaceholder) {
  // Build the graph: only 1 node: "const0"
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(s.ToGraph(g.get()));

  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  std::vector<OutputTensor> arg_source_tensors;
  NodeDef call_node_def;
  call_node_def.set_op("0");
  TF_CHECK_OK(
      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
  auto node_name_image = g->BuildNodeNameIndex();

  // Check key placeholder/RecvAtHost/SendFromHost are not present.
  Node *key_placeholder = node_name_image["cluster_key_placeholder"];
  EXPECT_EQ(key_placeholder, nullptr);
  Node *recv_at_host = node_name_image["outside_compilation_cluster__0_recv"];
  EXPECT_EQ(recv_at_host, nullptr);
  Node *send_from_host = node_name_image["outside_compilation_cluster__0_send"];
  EXPECT_EQ(send_from_host, nullptr);
}
197
// When the output node carries kXlaInferredShapesAttrName, the rewrite should
// record the inferred shapes on the call node's "shapes" attr instead of
// requiring a shape inference graph.
TEST(RewriteOutsideCompilationSubgraphFnTest, ShapesInferred) {
  // Build the graph:
  // "ret" = "const0"
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
  auto ret = ops::_Retval(s.WithOpName("ret"), const0, 0);
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(s.ToGraph(g.get()));
  auto node_name_image = g->BuildNodeNameIndex();
  Node *const0_node = node_name_image["const0"];
  EXPECT_NE(const0_node, nullptr);
  // Pretend shape inference already determined const0's shape is [2].
  PartialTensorShape shape({2});
  const0_node->AddAttr(kXlaInferredShapesAttrName,
                       std::vector<PartialTensorShape>{shape});

  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  std::vector<OutputTensor> arg_source_tensors;
  NodeDef call_node_def;
  call_node_def.set_op("0");
  TF_CHECK_OK(
      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
  node_name_image = g->BuildNodeNameIndex();

  // Check "shape" attr is available in call_node_def.
  std::vector<TensorShapeProto> shapes;
  TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shapes", &shapes));
  EXPECT_EQ(shapes.size(), 1);
  EXPECT_EQ(shapes[0].dim_size(), 1);
}
227
228 class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
229 public:
SetUp()230 void SetUp() override {
231 SessionOptions session_options;
232 std::vector<std::unique_ptr<Device>> devices;
233 TF_CHECK_OK(DeviceFactory::AddDevices(
234 session_options, "/job:localhost/replica:0/task:0", &devices));
235 device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
236 }
237
ExtractOutsideCompilationTest(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const NameAttrList & func_name_attrs,const string & new_func_name,const string & host_graph_func_name,const std::map<string,int> & host_compute_core,FunctionLibraryDefinition * fld,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)238 Status ExtractOutsideCompilationTest(
239 const string &xla_cluster_attr_name,
240 const string &outside_compilation_attr_name,
241 const string &xla_cluster_name, const NameAttrList &func_name_attrs,
242 const string &new_func_name, const string &host_graph_func_name,
243 const std::map<string, int> &host_compute_core,
244 FunctionLibraryDefinition *fld,
245 std::vector<string> *shape_inference_graphs,
246 bool *has_outside_compilation) {
247 OptimizerOptions opts;
248 pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
249 device_mgr_.get(), Env::Default(), /*config=*/nullptr,
250 TF_GRAPH_DEF_VERSION, fld, opts,
251 /*default_thread_pool=*/nullptr);
252 auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
253 return ExtractOutsideCompilationForFunction(
254 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
255 func_name_attrs, new_func_name, host_graph_func_name, host_compute_core,
256 flr, fld, shape_inference_graphs, has_outside_compilation);
257 }
258
259 private:
260 std::unique_ptr<DeviceMgr> device_mgr_;
261 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
262 };
263
// Extracts two outside compilation clusters ("0" and "1") from a simple XLA
// function and verifies both the rewritten XLA function (XlaHostCompute nodes
// with core/shape attrs) and the generated host graph (key placeholder,
// sequencer, and Send/Recv wiring).
TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) {
  // Build the XLA computation func.
  // "const0"
  // "identity0" = "const0" (outside compilation cluster "0")
  // "identity1" = "identity0" (outside compilation cluster "1")
  // "identity2" = "identity1"
  FunctionDefLibrary fdl;
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
    Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
    Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
    Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    node_name_image["identity0"]->AddAttr("_oc", "0");
    node_name_image["identity1"]->AddAttr("_oc", "1");
    // Only cluster "1" gets an inferred shape; cluster "0" does not, so its
    // host compute node should end up with an empty "shapes" attr.
    PartialTensorShape shape({2});
    node_name_image["identity1"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  protobuf::Map<string, tensorflow::AttrValue> attrs;
  std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  *name_attrs.mutable_attr() = attrs;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));

  // Get rewritten XLA computation function.
  std::unique_ptr<FunctionBody> xla_fbody;
  TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
                                      AttrSlice(), &fld, &xla_fbody));
  auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();

  // Check XlaHostCompute nodes.
  Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
  EXPECT_NE(host_compute_0, nullptr);
  Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
  EXPECT_NE(host_compute_1, nullptr);
  // Check XlaHostCompute nodes' "tpu_core" attr.
  int tpu_core;
  TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "tpu_core", &tpu_core));
  EXPECT_EQ(tpu_core, 1);
  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "tpu_core", &tpu_core));
  EXPECT_EQ(tpu_core, 0);
  // Check XlaHostCompute nodes' "shapes" attr. "0" should not have shapes, and
  // "1" should have shapes.
  std::vector<TensorShapeProto> shapes;
  TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shapes", &shapes));
  EXPECT_EQ(shapes.size(), 0);
  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes));
  EXPECT_EQ(shapes.size(), 1);
  EXPECT_EQ(shapes[0].dim_size(), 1);
  // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
  // empty values.
  NameAttrList shape_inference_graph;
  TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
                          &shape_inference_graph));
  EXPECT_EQ(shape_inference_graph.name(), "");
  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
                          &shape_inference_graph));
  EXPECT_EQ(shape_inference_graph.name(), "");

  // Check `shape_inference_graphs`.
  EXPECT_EQ(shape_inference_graphs.size(), 0);

  // Check host graph: verify we have key placeholder and sequencer.
  std::unique_ptr<FunctionBody> host_fbody;
  AttrValue device_ordinal_temp_value;
  device_ordinal_temp_value.set_i(0);
  protobuf::Map<string, AttrValue> host_func_attrs;
  host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
  TF_CHECK_OK(FunctionDefToBodyHelper(
      *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody));
  Graph *host_graph = host_fbody->graph;
  Node *key_placeholder = nullptr, *sequencer = nullptr;
  for (Node *n : host_graph->nodes()) {
    if (n->type_string() == "Placeholder" &&
        absl::EndsWith(n->name(), "_key_placeholder")) {
      EXPECT_EQ(key_placeholder, nullptr);
      key_placeholder = n;
    } else if (HasNodeAttr(n->def(), "_xla_host_transfer_sequencer")) {
      EXPECT_EQ(sequencer, nullptr);
      sequencer = n;
    }
  }
  EXPECT_NE(key_placeholder, nullptr);
  EXPECT_NE(sequencer, nullptr);
  // Check SendFromHost and RecvAtHost has key placeholder as input, and have
  // control edge to sequencer.
  int num_send_from_host = 0, num_recv_at_host = 0;
  std::vector<Node *> send_recv_nodes;
  for (Node *n : host_graph->nodes()) {
    if (n->type_string() == "_XlaSendFromHost") {
      num_send_from_host++;
      send_recv_nodes.push_back(n);
    } else if (n->type_string() == "_XlaRecvAtHost") {
      num_recv_at_host++;
      send_recv_nodes.push_back(n);
    }
  }
  EXPECT_EQ(num_send_from_host, 1);
  EXPECT_EQ(num_recv_at_host, 1);
  for (Node *n : send_recv_nodes) {
    // The dynamic key is always the last input of Send/Recv nodes.
    Node *input_node;
    TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node));
    EXPECT_EQ(input_node, key_placeholder);

    bool has_control_edge_to_sequencer = false;
    for (const Edge *e : n->out_edges()) {
      if (e->IsControlEdge() && e->dst() == sequencer) {
        has_control_edge_to_sequencer = true;
        break;
      }
    }
    EXPECT_TRUE(has_control_edge_to_sequencer);
  }
}
393
// If the XLA function contains no outside compilation nodes, no host graph
// function should be added to the function library.
TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
  // Build the XLA computation func.
  // "const0"
  FunctionDefLibrary fdl;
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));

    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  protobuf::Map<string, tensorflow::AttrValue> attrs;
  std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  *name_attrs.mutable_attr() = attrs;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));

  // Check host graph is not created.
  EXPECT_EQ(fld.Find("host_graph"), nullptr);
}
424
// Outside compilation inside an If op's branch functions: the host graph
// should receive the predicate via RecvAtHost and run a host-side If over
// host versions of the branch functions; the XLA graph should send the
// predicate via XlaSendToHost and record token dependencies on the If node.
TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
  // Build the XLA computation func.
  // "const0" (bool)
  // "const1" (int32)
  // "if0" (pred = "const0", input = "const1", then_branch = "true_fn",
  //        else_branch = "false_fn")
  FunctionDefLibrary fdl;
  {
    // then_branch: identity marked as outside compilation cluster "0".
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
    Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg);
    ops::_Retval retval(s.WithOpName("retval"), identity, 0);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    node_name_image["identity_true_fn"]->AddAttr("_oc", "0");
    PartialTensorShape shape({2});
    node_name_image["identity_true_fn"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *true_fn_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef));
  }
  {
    // else_branch: identity marked as outside compilation cluster "0".
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
    Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg);
    ops::_Retval retval(s.WithOpName("retval"), identity, 0);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    node_name_image["identity_false_fn"]->AddAttr("_oc", "0");
    PartialTensorShape shape({2});
    node_name_image["identity_false_fn"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *false_fn_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef));
  }
  {
    // Outer computation with the If node.
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output cond = ops::Const(s.WithOpName("const0"), true, {2});
    Output input = ops::Const(s.WithOpName("const1"), 1, {2});
    NameAttrList true_fn;
    true_fn.set_name("true_fn");
    NameAttrList false_fn;
    false_fn.set_name("false_fn");
    auto if_op = ops::If(s.WithOpName("if"), cond,
                         std::initializer_list<Input>{cond, input}, {DT_INT32},
                         true_fn, false_fn);
    ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));

    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  protobuf::Map<string, tensorflow::AttrValue> attrs;
  std::map<string, int> host_compute_core;
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  *name_attrs.mutable_attr() = attrs;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));

  // Check host graph.
  {
    std::unique_ptr<FunctionBody> host_fbody;
    AttrValue device_ordinal_temp_value;
    device_ordinal_temp_value.set_i(0);
    protobuf::Map<string, AttrValue> host_func_attrs;
    host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
    TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
                                        AttrSlice(&host_func_attrs), &fld,
                                        &host_fbody));
    Graph *host_graph = host_fbody->graph;
    auto node_name_index = host_graph->BuildNodeNameIndex();

    // Verify we have XlaRecvAtHost to receive "If" predicate.
    Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"];
    EXPECT_NE(recv_if_pred_node, nullptr);

    // Verify we have an "If" to choose outside compilation between then_branch
    // and else_branch, and it has `recv_if_pred_node` as cond input.
    Node *if_oc_node = node_name_index["oc_if_if"];
    EXPECT_NE(if_oc_node, nullptr);
    Node *if_oc_node_cond_input;
    TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input));
    EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node);

    // Check that then_branch outside compilation has node "identity_true_fn".
    const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_true_fn");
    EXPECT_NE(true_def, nullptr);
    bool has_identity_true_fn_node = false;
    for (const auto &node_def : true_def->node_def()) {
      if (node_def.name() == "identity_true_fn") {
        has_identity_true_fn_node = true;
        break;
      }
    }
    EXPECT_TRUE(has_identity_true_fn_node);

    // Check that else_branch outside compilation has node "identity_false_fn".
    const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_false_fn");
    EXPECT_NE(false_def, nullptr);
    bool has_identity_false_fn_node = false;
    for (const auto &node_def : false_def->node_def()) {
      if (node_def.name() == "identity_false_fn") {
        has_identity_false_fn_node = true;
        break;
      }
    }
    EXPECT_TRUE(has_identity_false_fn_node);
  }

  // Check XLA graph.
  {
    std::unique_ptr<FunctionBody> xla_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
                                        AttrSlice(), &fld, &xla_fbody));
    Graph *xla_graph = xla_fbody->graph;
    auto node_name_index = xla_graph->BuildNodeNameIndex();

    // Check that we have XlaSendToHost to send cond predicate to host, and
    // there is a control edge to If node.
    Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"];
    EXPECT_NE(send_if_pred_node, nullptr);
    bool has_control_edge_to_if = false;
    for (const Edge *e : send_if_pred_node->out_edges()) {
      if (e->IsControlEdge() && e->dst()->name() == "if") {
        has_control_edge_to_if = true;
        break;
      }
    }
    EXPECT_TRUE(has_control_edge_to_if);

    // Check that the "If" node now has `send_if_pred_node` as attribute
    // _xla_token_input_nodes.
    Node *if_node = node_name_index["if"];
    EXPECT_NE(if_node, nullptr);
    std::vector<string> token_inputs;
    TF_CHECK_OK(
        GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs));
    EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if"));
  }
}
577
// Outside compilation inside a While op's cond/body functions: the host graph
// should run a host-side While over host versions of cond/body, and the
// rewritten XLA cond function should send the loop predicate to the host.
TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) {
  // Build the XLA computation func.
  // "const0" (bool)
  // "while0" (input = "const0", cond = "cond_fn", body = "body_fn")
  FunctionDefLibrary fdl;
  {
    // cond_fn: identity marked as outside compilation cluster "0".
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
    Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg);
    ops::_Retval retval(s.WithOpName("retval"), identity, 0);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    node_name_image["identity_cond_fn"]->AddAttr("_oc", "0");
    PartialTensorShape shape({2});
    node_name_image["identity_cond_fn"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *cond_fn_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef));
  }
  {
    // body_fn: identity marked as outside compilation cluster "0".
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
    Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg);
    ops::_Retval retval(s.WithOpName("retval"), identity, 0);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    node_name_image["identity_body_fn"]->AddAttr("_oc", "0");
    PartialTensorShape shape({2});
    node_name_image["identity_body_fn"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *body_fn_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef));
  }
  {
    // Outer computation with the While node.
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output input = ops::Const(s.WithOpName("const0"), true, {2});
    NameAttrList cond_fn;
    cond_fn.set_name("cond_fn");
    NameAttrList body_fn;
    body_fn.set_name("body_fn");
    auto while_op =
        ops::While(s.WithOpName("while"), std::initializer_list<Input>{input},
                   cond_fn, body_fn);
    ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));

    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  protobuf::Map<string, tensorflow::AttrValue> attrs;
  std::map<string, int> host_compute_core;
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  *name_attrs.mutable_attr() = attrs;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));

  // Check host graph.
  {
    std::unique_ptr<FunctionBody> host_fbody;
    AttrValue device_ordinal_temp_value;
    device_ordinal_temp_value.set_i(0);
    protobuf::Map<string, AttrValue> host_func_attrs;
    host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
    TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
                                        AttrSlice(&host_func_attrs), &fld,
                                        &host_fbody));
    Graph *host_graph = host_fbody->graph;
    auto node_name_index = host_graph->BuildNodeNameIndex();

    // Verify we have an "While" to execute outside compilation.
    Node *while_oc_node = node_name_index["oc_while_while"];
    EXPECT_NE(while_oc_node, nullptr);

    // Check that cond outside compilation has node "identity_cond_fn".
    const FunctionDef *cond_def = fld.Find("oc_cond_host_while_cond_fn");
    EXPECT_NE(cond_def, nullptr);
    bool has_identity_cond_fn_node = false;
    for (const auto &node_def : cond_def->node_def()) {
      if (node_def.name() == "identity_cond_fn") {
        has_identity_cond_fn_node = true;
        break;
      }
    }
    EXPECT_TRUE(has_identity_cond_fn_node);

    // Check that body outside compilation has node "identity_body_fn".
    const FunctionDef *body_def = fld.Find("oc_body_host_while_body_fn");
    EXPECT_NE(body_def, nullptr);
    bool has_identity_body_fn_node = false;
    for (const auto &node_def : body_def->node_def()) {
      if (node_def.name() == "identity_body_fn") {
        has_identity_body_fn_node = true;
        break;
      }
    }
    EXPECT_TRUE(has_identity_body_fn_node);
  }

  // Check XLA graph.
  {
    // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to
    // host.
    const FunctionDef *cond_def = fld.Find("cond_fn_oc");
    EXPECT_NE(cond_def, nullptr);
    bool has_send_oc_while_cond_node = false;
    for (const auto &node_def : cond_def->node_def()) {
      if (node_def.name() == "send_oc_while_cond_while") {
        has_send_oc_while_cond_node = true;
        break;
      }
    }
    EXPECT_TRUE(has_send_oc_while_cond_node);
  }
}
704
TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
  // Build the XLA computation func.
  // "const0" (int32)
  // "fn" (input = "const0")
  FunctionDefLibrary fdl;
  {
    // Function "fn": identity of its single arg, with the "identity" node
    // marked for outside compilation (attr "_oc") and annotated with an
    // inferred shape of [2] so the host graph can be built statically.
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
    Output identity = ops::Identity(s.WithOpName("identity"), arg);
    ops::_Retval retval(s.WithOpName("retval"), identity, 0);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    node_name_image["identity"]->AddAttr("_oc", "0");
    PartialTensorShape shape({2});
    node_name_image["identity"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *fn_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "fn", fn_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
  {
    // Build the XLA cluster function "cluster": const -> fn -> retval.
    std::unique_ptr<Graph> g(new Graph(&fld));

    // Constant [1, 1] of shape [2] feeding the function call.
    tensorflow::TensorProto tensor_proto;
    tensor_proto.set_dtype(tensorflow::DT_INT32);
    tensorflow::TensorShapeProto shape;
    shape.add_dim()->set_size(2);
    *tensor_proto.mutable_tensor_shape() = shape;
    for (int i = 0; i < 2; ++i) {
      tensor_proto.add_int_val(1);
    }
    NodeDef const_def;
    TF_CHECK_OK(NodeDefBuilder("const", "Const")
                    .Attr("dtype", DT_INT32)
                    .Attr("value", tensor_proto)
                    .Finalize(&const_def));
    Status s;
    Node *const_node = g->AddNode(const_def, &s);
    TF_CHECK_OK(s);

    NodeDef fn_def;
    TF_CHECK_OK(NodeDefBuilder("fn", "fn", &fld)
                    .Input("const", 0, DT_INT32)
                    .Finalize(&fn_def));
    Node *fn_node = g->AddNode(fn_def, &s);
    TF_CHECK_OK(s);
    g->AddEdge(const_node, 0, fn_node, 0);

    NodeDef ret_def;
    TF_CHECK_OK(NodeDefBuilder("ret", "_Retval")
                    .Attr("index", 0)
                    .Attr("T", DT_INT32)
                    .Input("fn", 0, DT_INT32)
                    .Finalize(&ret_def));
    Node *ret_node = g->AddNode(ret_def, &s);
    TF_CHECK_OK(s);
    g->AddEdge(fn_node, 0, ret_node, 0);

    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
    TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef));
  }

  protobuf::Map<string, tensorflow::AttrValue> attrs;
  std::map<string, int> host_compute_core;
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  *name_attrs.mutable_attr() = attrs;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));
  // "identity" inside "fn" is in an outside compilation cluster, so the pass
  // must report that outside compilation was found.
  EXPECT_TRUE(has_outside_compilation);

  // Check host graph.
  {
    std::unique_ptr<FunctionBody> host_fbody;
    AttrValue device_ordinal_temp_value;
    device_ordinal_temp_value.set_i(0);
    protobuf::Map<string, AttrValue> host_func_attrs;
    host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
    TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
                                        AttrSlice(&host_func_attrs), &fld,
                                        &host_fbody));
    Graph *host_graph = host_fbody->graph;
    auto node_name_index = host_graph->BuildNodeNameIndex();

    // Verify we have call node for outside compilation in `fn`.
    Node *call_node = node_name_index["oc_call_fn"];
    EXPECT_NE(call_node, nullptr);

    std::unique_ptr<FunctionBody> call_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("oc_func_call_host_fn"),
                                        AttrSlice(&host_func_attrs), &fld,
                                        &call_fbody));

    // Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes.
    bool has_recv = false, has_send = false;
    for (Node *n : call_fbody->graph->nodes()) {
      if (n->type_string() == "_XlaRecvAtHost") {
        has_recv = true;
      } else if (n->type_string() == "_XlaSendFromHost") {
        has_send = true;
      }
    }
    EXPECT_TRUE(has_recv);
    EXPECT_TRUE(has_send);
  }

  // Check XLA graph.
  {
    std::unique_ptr<FunctionBody> xla_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
                                        AttrSlice(), &fld, &xla_fbody));
    Graph *xla_graph = xla_fbody->graph;
    auto node_name_index = xla_graph->BuildNodeNameIndex();

    // Check that we have call node.
    Node *fn_node = node_name_index["fn"];
    EXPECT_NE(fn_node, nullptr);
    EXPECT_EQ(fn_node->type_string(), "fn_oc");

    std::unique_ptr<FunctionBody> call_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("fn_oc"), AttrSlice(), &fld,
                                        &call_fbody));

    // Verify we have XlaHostCompute nodes.
    bool has_hc = false;
    for (Node *n : call_fbody->graph->nodes()) {
      if (n->type_string() == "XlaHostCompute") {
        has_hc = true;
        break;
      }
    }
    EXPECT_TRUE(has_hc);
  }
}
844
TEST_F(ExtractOutsideCompilationForFunctionTest,
       OutsideCompilationClusterDataDependency) {
  // Build the XLA computation func.
  // "const0"
  // "identity0" = "const0" (outside compilation cluster "0")
  // "identity1" = "identity0" (outside compilation cluster "1")
  // "identity2" = "identity1"
  FunctionDefLibrary fdl;
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
    Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
    Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
    Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    // Cluster "1" consumes the output of cluster "0", creating a data
    // dependency between the two outside compilation clusters.
    node_name_image["identity0"]->AddAttr("_oc", "0");
    node_name_image["identity1"]->AddAttr("_oc", "1");

    PartialTensorShape shape({2});
    node_name_image["identity1"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  protobuf::Map<string, tensorflow::AttrValue> attrs;
  std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  *name_attrs.mutable_attr() = attrs;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));
  // Two outside compilation clusters exist, so the pass must report them.
  EXPECT_TRUE(has_outside_compilation);

  // Get rewritten XLA computation function.
  std::unique_ptr<FunctionBody> xla_fbody;
  TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
                                      AttrSlice(), &fld, &xla_fbody));
  auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();

  // Check XlaHostCompute nodes.
  Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
  EXPECT_NE(host_compute_0, nullptr);
  Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
  EXPECT_NE(host_compute_1, nullptr);

  // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
  // Cluster "0" only depends on the token arg; cluster "1" must additionally
  // list cluster "0"'s host compute so their ordering is preserved.
  std::vector<string> token_input_nodes;
  TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
                          "_xla_token_input_nodes", &token_input_nodes));

  std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
  EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
  token_input_nodes.clear();
  std::vector<string> expected_token_input_nodes_1(
      {"_xla_token_arg_node", "outside_compilation_0_host_compute"});
  TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
                          "_xla_token_input_nodes", &token_input_nodes));
  EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);

  // Check there is a control edge from host_compute_0 to host_compute_1.
  bool has_control_edge = false;
  for (const Edge *e : host_compute_1->in_edges()) {
    if (e->IsControlEdge() && e->src() == host_compute_0) {
      has_control_edge = true;
      break;
    }
  }
  EXPECT_TRUE(has_control_edge);
}
924
TEST_F(ExtractOutsideCompilationForFunctionTest,
       OutsideCompilationClusterControlDependency) {
  // Build the XLA computation func.
  // "const0"
  // "identity0" = "const0" (outside compilation cluster "0")
  // "identity1" = "const0" "^identity0" (outside compilation cluster "1",
  //               control dependent on cluster "0")
  // "identity2" = "identity1"
  FunctionDefLibrary fdl;
  {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
    Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
    // identity1 has only a CONTROL dependency on identity0 (no data edge),
    // which is the scenario under test.
    Output identity1 = ops::Identity(
        s.WithOpName("identity1").WithControlDependencies(identity0), const0);
    Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    TF_CHECK_OK(s.ToGraph(g.get()));
    auto node_name_image = g->BuildNodeNameIndex();
    node_name_image["identity0"]->AddAttr("_oc", "0");
    node_name_image["identity1"]->AddAttr("_oc", "1");

    PartialTensorShape shape({2});
    node_name_image["identity1"]->AddAttr(
        kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});

    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  protobuf::Map<string, tensorflow::AttrValue> attrs;
  std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  *name_attrs.mutable_attr() = attrs;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));
  // Two outside compilation clusters exist, so the pass must report them.
  EXPECT_TRUE(has_outside_compilation);

  // Get rewritten XLA computation function.
  std::unique_ptr<FunctionBody> xla_fbody;
  TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
                                      AttrSlice(), &fld, &xla_fbody));
  auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();

  // Check XlaHostCompute nodes.
  Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
  EXPECT_NE(host_compute_0, nullptr);
  Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
  EXPECT_NE(host_compute_1, nullptr);

  // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
  // The control dependency between clusters must surface as a token
  // dependency: cluster "1" lists cluster "0"'s host compute node.
  std::vector<string> token_input_nodes;
  TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
                          "_xla_token_input_nodes", &token_input_nodes));

  std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
  EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
  token_input_nodes.clear();
  std::vector<string> expected_token_input_nodes_1(
      {"_xla_token_arg_node", "outside_compilation_0_host_compute"});
  TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
                          "_xla_token_input_nodes", &token_input_nodes));
  EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);

  // Check there is a control edge from host_compute_0 to host_compute_1.
  bool has_control_edge = false;
  for (const Edge *e : host_compute_1->in_edges()) {
    if (e->IsControlEdge() && e->src() == host_compute_0) {
      has_control_edge = true;
      break;
    }
  }
  EXPECT_TRUE(has_control_edge);
}
1006 } // namespace tensorflow
1007