xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
17 
18 #include "absl/strings/match.h"
19 #include "tensorflow/cc/framework/scope.h"
20 #include "tensorflow/cc/ops/array_ops.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/functional_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/jit/encapsulate_util.h"
25 #include "tensorflow/compiler/xla/test.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/framework/common_shape_fns.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/graph_to_functiondef.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/public/session_options.h"
36 #include "tensorflow/core/public/version.h"
37 
38 namespace tensorflow {
39 
TEST(RewriteOutsideCompilationSubgraphFnTest,Basic)40 TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) {
41   // Build the graph:
42   // "add" = "arg0" + "arg1"
43   // "ret0" = "add"
44   // "ret1" = "arg1"
45   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
46   Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_INT32, 0);
47   Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_FLOAT, 1);
48   Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
49   Output add = ops::Add(s.WithOpName("add"), arg0, arg0);
50   auto ret0 = ops::_Retval(s.WithOpName("ret0"), add, 0);
51   auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 1);
52   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
53   TF_CHECK_OK(s.ToGraph(g.get()));
54   auto node_name_image = g->BuildNodeNameIndex();
55   Node *add_node = node_name_image["add"];
56   EXPECT_NE(add_node, nullptr);
57   add_node->AddAttr(kXlaConnectedToXlaComputationAttrName, "cluster");
58   add_node->AddAttr(kXlaConnectedFromXlaComputationAttrName, "cluster");
59 
60   RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
61   std::vector<OutputTensor> arg_source_tensors;
62   NodeDef call_node_def;
63   call_node_def.set_op("0");
64   TF_CHECK_OK(
65       rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
66   node_name_image = g->BuildNodeNameIndex();
67 
68   // Verify step 1: add key placeholder node.
69   Node *key_placeholder = node_name_image["cluster_key_placeholder"];
70   EXPECT_NE(key_placeholder, nullptr);
71   // Verify step 2: replace _Arg nodes with XlaRecvAtHost.
72   for (Node *n : g->nodes()) {
73     EXPECT_NE(n->type_string(), "_Arg");
74   }
75   Node *recv_at_host = node_name_image["outside_compilation_cluster__0_recv"];
76   EXPECT_NE(recv_at_host, nullptr);
77   std::vector<DataType> recv_at_host_dtypes;
78   TF_CHECK_OK(
79       GetNodeAttr(recv_at_host->attrs(), "Toutputs", &recv_at_host_dtypes));
80   EXPECT_EQ(recv_at_host_dtypes.size(), 3);
81   EXPECT_EQ(recv_at_host_dtypes[0], DT_INT32);
82   EXPECT_EQ(recv_at_host_dtypes[1], DT_FLOAT);
83   EXPECT_EQ(recv_at_host_dtypes[2], DT_INT32);
84   // Verify step 3: replace _Retval nodes with XlaSendFromHost.
85   for (Node *n : g->nodes()) {
86     EXPECT_NE(n->type_string(), "_Retval");
87   }
88   Node *send_from_host = node_name_image["outside_compilation_cluster__0_send"];
89   EXPECT_NE(send_from_host, nullptr);
90   std::vector<DataType> send_from_host_dtypes;
91   TF_CHECK_OK(
92       GetNodeAttr(send_from_host->attrs(), "Tinputs", &send_from_host_dtypes));
93   EXPECT_EQ(send_from_host_dtypes.size(), 2);
94   EXPECT_EQ(send_from_host_dtypes[0], DT_INT32);
95   EXPECT_EQ(send_from_host_dtypes[1], DT_FLOAT);
96   // Verify step 4: nodes marked with XLA cluster and outside compilation attr.
97   add_node = node_name_image["add"];
98   EXPECT_NE(add_node, nullptr);
99   EXPECT_TRUE(HasNodeAttr(add_node->def(), "_xla"));
100   EXPECT_TRUE(HasNodeAttr(add_node->def(), "_oc"));
101   // Verify step 5: control edges added.
102   bool has_control_edge_from_recv_at_host = false;
103   for (auto e : add_node->in_edges()) {
104     if (e->IsControlEdge() && e->src() == recv_at_host) {
105       has_control_edge_from_recv_at_host = true;
106     }
107   }
108   EXPECT_TRUE(has_control_edge_from_recv_at_host);
109   bool has_control_edge_to_send_from_host = false;
110   for (auto e : add_node->out_edges()) {
111     if (e->IsControlEdge() && e->dst() == send_from_host) {
112       has_control_edge_to_send_from_host = true;
113     }
114   }
115   EXPECT_TRUE(has_control_edge_to_send_from_host);
116   // Verify step 7: necessary attrs added to call_node_def.
117   NameAttrList shape_inference_graph;
118   TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()),
119                           "shape_inference_graph", &shape_inference_graph));
120   EXPECT_EQ(shape_inference_graph.name(),
121             "_outside_compilation_shape_inference_cluster__0");
122 }
123 
// Rewrites a graph that contains only an _Arg node and checks which of the
// key placeholder / RecvAtHost / SendFromHost nodes exist afterwards.
TEST(RewriteOutsideCompilationSubgraphFnTest, NoSendFromHost) {
  // Graph under test: a single node, "arg0".
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output arg0 = ops::_Arg(root.WithOpName("arg0"), DT_INT32, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_CHECK_OK(root.ToGraph(graph.get()));

  // Apply the rewrite with an empty list of argument source tensors.
  NodeDef call_node_def;
  call_node_def.set_op("0");
  std::vector<OutputTensor> arg_source_tensors;
  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  TF_CHECK_OK(rewrite_fn(arg_source_tensors, &graph, nullptr, nullptr,
                         &call_node_def));
  auto nodes_by_name = graph->BuildNodeNameIndex();

  // The key placeholder and RecvAtHost must exist; SendFromHost must not.
  EXPECT_NE(nodes_by_name["cluster_key_placeholder"], nullptr);
  EXPECT_NE(nodes_by_name["outside_compilation_cluster__0_recv"], nullptr);
  EXPECT_EQ(nodes_by_name["outside_compilation_cluster__0_send"], nullptr);
}
147 
// Rewrites a graph that has a _Retval but no _Arg node and checks which of the
// key placeholder / RecvAtHost / SendFromHost nodes exist afterwards.
TEST(RewriteOutsideCompilationSubgraphFnTest, NoRecvAtHost) {
  // Graph under test:
  //   "ret" = "const0"
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output const0 = ops::Const(root.WithOpName("const0"), 1, {2});
  auto ret = ops::_Retval(root.WithOpName("ret"), const0, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_CHECK_OK(root.ToGraph(graph.get()));

  // Apply the rewrite with an empty list of argument source tensors.
  NodeDef call_node_def;
  call_node_def.set_op("0");
  std::vector<OutputTensor> arg_source_tensors;
  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  TF_CHECK_OK(rewrite_fn(arg_source_tensors, &graph, nullptr, nullptr,
                         &call_node_def));
  auto nodes_by_name = graph->BuildNodeNameIndex();

  // The key placeholder and SendFromHost must exist; RecvAtHost must not.
  EXPECT_NE(nodes_by_name["cluster_key_placeholder"], nullptr);
  EXPECT_EQ(nodes_by_name["outside_compilation_cluster__0_recv"], nullptr);
  EXPECT_NE(nodes_by_name["outside_compilation_cluster__0_send"], nullptr);
}
173 
// Rewrites a graph that has neither _Arg nor _Retval nodes and checks that no
// key placeholder, RecvAtHost or SendFromHost node exists afterwards.
TEST(RewriteOutsideCompilationSubgraphFnTest, NoKeyPlaceholder) {
  // Graph under test: a single node, "const0".
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output const0 = ops::Const(root.WithOpName("const0"), 1, {2});
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_CHECK_OK(root.ToGraph(graph.get()));

  // Apply the rewrite with an empty list of argument source tensors.
  NodeDef call_node_def;
  call_node_def.set_op("0");
  std::vector<OutputTensor> arg_source_tensors;
  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  TF_CHECK_OK(rewrite_fn(arg_source_tensors, &graph, nullptr, nullptr,
                         &call_node_def));
  auto nodes_by_name = graph->BuildNodeNameIndex();

  // None of the three host-transfer helper nodes may be present.
  EXPECT_EQ(nodes_by_name["cluster_key_placeholder"], nullptr);
  EXPECT_EQ(nodes_by_name["outside_compilation_cluster__0_recv"], nullptr);
  EXPECT_EQ(nodes_by_name["outside_compilation_cluster__0_send"], nullptr);
}
197 
// Attaches an inferred-shapes attr to the node feeding the _Retval and checks
// that the rewrite records a "shapes" attr on the call node def.
TEST(RewriteOutsideCompilationSubgraphFnTest, ShapesInferred) {
  // Build the graph:
  // "ret" = "const0"
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
  auto ret = ops::_Retval(s.WithOpName("ret"), const0, 0);
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(s.ToGraph(g.get()));
  auto node_name_image = g->BuildNodeNameIndex();
  Node *const0_node = node_name_image["const0"];
  // Fatal check: `const0_node` is dereferenced immediately below.
  ASSERT_NE(const0_node, nullptr);
  PartialTensorShape shape({2});
  const0_node->AddAttr(kXlaInferredShapesAttrName,
                       std::vector<PartialTensorShape>{shape});

  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster", "");
  std::vector<OutputTensor> arg_source_tensors;
  NodeDef call_node_def;
  call_node_def.set_op("0");
  TF_CHECK_OK(
      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));

  // Check "shapes" attr is available in call_node_def.
  std::vector<TensorShapeProto> shapes;
  TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shapes", &shapes));
  // Fatal check: `shapes[0]` is read below.
  ASSERT_EQ(shapes.size(), 1);
  EXPECT_EQ(shapes[0].dim_size(), 1);
}
227 
228 class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
229  public:
SetUp()230   void SetUp() override {
231     SessionOptions session_options;
232     std::vector<std::unique_ptr<Device>> devices;
233     TF_CHECK_OK(DeviceFactory::AddDevices(
234         session_options, "/job:localhost/replica:0/task:0", &devices));
235     device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
236   }
237 
ExtractOutsideCompilationTest(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const NameAttrList & func_name_attrs,const string & new_func_name,const string & host_graph_func_name,const std::map<string,int> & host_compute_core,FunctionLibraryDefinition * fld,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)238   Status ExtractOutsideCompilationTest(
239       const string &xla_cluster_attr_name,
240       const string &outside_compilation_attr_name,
241       const string &xla_cluster_name, const NameAttrList &func_name_attrs,
242       const string &new_func_name, const string &host_graph_func_name,
243       const std::map<string, int> &host_compute_core,
244       FunctionLibraryDefinition *fld,
245       std::vector<string> *shape_inference_graphs,
246       bool *has_outside_compilation) {
247     OptimizerOptions opts;
248     pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
249         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
250         TF_GRAPH_DEF_VERSION, fld, opts,
251         /*default_thread_pool=*/nullptr);
252     auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
253     return ExtractOutsideCompilationForFunction(
254         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
255         func_name_attrs, new_func_name, host_graph_func_name, host_compute_core,
256         flr, fld, shape_inference_graphs, has_outside_compilation);
257   }
258 
259  private:
260   std::unique_ptr<DeviceMgr> device_mgr_;
261   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
262 };
263 
TEST_F(ExtractOutsideCompilationForFunctionTest,Basic)264 TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) {
265   // Build the XLA computation func.
266   // "const0"
267   // "identity0" = "const0" (outside compilation cluster "0")
268   // "identity1" = "identity0" (outside compilation cluster "1")
269   // "identity2" = "identity1"
270   FunctionDefLibrary fdl;
271   {
272     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
273     Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
274     Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
275     Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
276     Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
277     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
278     TF_CHECK_OK(s.ToGraph(g.get()));
279     auto node_name_image = g->BuildNodeNameIndex();
280     node_name_image["identity0"]->AddAttr("_oc", "0");
281     node_name_image["identity1"]->AddAttr("_oc", "1");
282     PartialTensorShape shape({2});
283     node_name_image["identity1"]->AddAttr(
284         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
285 
286     FunctionDef *xla_fdef = fdl.add_function();
287     TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
288   }
289   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
290 
291   protobuf::Map<string, tensorflow::AttrValue> attrs;
292   std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
293   std::vector<string> shape_inference_graphs;
294   bool has_outside_compilation;
295   NameAttrList name_attrs;
296   name_attrs.set_name("cluster");
297   *name_attrs.mutable_attr() = attrs;
298   TF_CHECK_OK(ExtractOutsideCompilationTest(
299       "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
300       host_compute_core, &fld, &shape_inference_graphs,
301       &has_outside_compilation));
302 
303   // Get rewritten XLA computation function.
304   std::unique_ptr<FunctionBody> xla_fbody;
305   TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
306                                       AttrSlice(), &fld, &xla_fbody));
307   auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
308 
309   // Check XlaHostCompute nodes.
310   Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
311   EXPECT_NE(host_compute_0, nullptr);
312   Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
313   EXPECT_NE(host_compute_1, nullptr);
314   // Check XlaHostCompute nodes' "tpu_core" attr.
315   int tpu_core;
316   TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "tpu_core", &tpu_core));
317   EXPECT_EQ(tpu_core, 1);
318   TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "tpu_core", &tpu_core));
319   EXPECT_EQ(tpu_core, 0);
320   // Check XlaHostCompute nodes' "shapes" attr. "0" should not have shapes, and
321   // "1" should have shapes.
322   std::vector<TensorShapeProto> shapes;
323   TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shapes", &shapes));
324   EXPECT_EQ(shapes.size(), 0);
325   TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes));
326   EXPECT_EQ(shapes.size(), 1);
327   EXPECT_EQ(shapes[0].dim_size(), 1);
328   // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
329   // empty values.
330   NameAttrList shape_inference_graph;
331   TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
332                           &shape_inference_graph));
333   EXPECT_EQ(shape_inference_graph.name(), "");
334   TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
335                           &shape_inference_graph));
336   EXPECT_EQ(shape_inference_graph.name(), "");
337 
338   // Check `shape_inference_graphs`.
339   EXPECT_EQ(shape_inference_graphs.size(), 0);
340 
341   // Check host graph: verify we have key placeholder and sequencer.
342   std::unique_ptr<FunctionBody> host_fbody;
343   AttrValue device_ordinal_temp_value;
344   device_ordinal_temp_value.set_i(0);
345   protobuf::Map<string, AttrValue> host_func_attrs;
346   host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
347   TF_CHECK_OK(FunctionDefToBodyHelper(
348       *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody));
349   Graph *host_graph = host_fbody->graph;
350   Node *key_placeholder = nullptr, *sequencer = nullptr;
351   for (Node *n : host_graph->nodes()) {
352     if (n->type_string() == "Placeholder" &&
353         absl::EndsWith(n->name(), "_key_placeholder")) {
354       EXPECT_EQ(key_placeholder, nullptr);
355       key_placeholder = n;
356     } else if (HasNodeAttr(n->def(), "_xla_host_transfer_sequencer")) {
357       EXPECT_EQ(sequencer, nullptr);
358       sequencer = n;
359     }
360   }
361   EXPECT_NE(key_placeholder, nullptr);
362   EXPECT_NE(sequencer, nullptr);
363   // Check SendFromHost and RecvAtHost has key placeholder as input, and have
364   // control edge to sequencer.
365   int num_send_from_host = 0, num_recv_at_host = 0;
366   std::vector<Node *> send_recv_nodes;
367   for (Node *n : host_graph->nodes()) {
368     if (n->type_string() == "_XlaSendFromHost") {
369       num_send_from_host++;
370       send_recv_nodes.push_back(n);
371     } else if (n->type_string() == "_XlaRecvAtHost") {
372       num_recv_at_host++;
373       send_recv_nodes.push_back(n);
374     }
375   }
376   EXPECT_EQ(num_send_from_host, 1);
377   EXPECT_EQ(num_recv_at_host, 1);
378   for (Node *n : send_recv_nodes) {
379     Node *input_node;
380     TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node));
381     EXPECT_EQ(input_node, key_placeholder);
382 
383     bool has_control_edge_to_sequencer = false;
384     for (const Edge *e : n->out_edges()) {
385       if (e->IsControlEdge() && e->dst() == sequencer) {
386         has_control_edge_to_sequencer = true;
387         break;
388       }
389     }
390     EXPECT_TRUE(has_control_edge_to_sequencer);
391   }
392 }
393 
// Runs the extraction on a function that has no outside compilation nodes and
// checks that no "host_graph" function is added to the library.
TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
  // The XLA computation function holds a single node, "const0".
  FunctionDefLibrary function_defs;
  {
    tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    Output const0 = ops::Const(root.WithOpName("const0"), 1, {2});
    auto graph = std::make_unique<Graph>(OpRegistry::Global());
    TF_CHECK_OK(root.ToGraph(graph.get()));

    TF_CHECK_OK(
        GraphToFunctionDef(*graph, "cluster", function_defs.add_function()));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), function_defs);

  // Name the function to rewrite and give it an empty attr map.
  NameAttrList name_attrs;
  name_attrs.set_name("cluster");
  protobuf::Map<string, tensorflow::AttrValue> attrs;
  *name_attrs.mutable_attr() = attrs;

  std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
  std::vector<string> shape_inference_graphs;
  bool has_outside_compilation;
  TF_CHECK_OK(ExtractOutsideCompilationTest(
      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
      host_compute_core, &fld, &shape_inference_graphs,
      &has_outside_compilation));

  // No host graph function may have been created.
  EXPECT_EQ(fld.Find("host_graph"), nullptr);
}
424 
TEST_F(ExtractOutsideCompilationForFunctionTest,OutsideCompilationInIf)425 TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
426   // Build the XLA computation func.
427   // "const0" (bool)
428   // "const1" (int32)
429   // "if0" (pred = "const0", input = "const1", then_branch = "true_fn",
430   //        else_branch = "false_fn")
431   FunctionDefLibrary fdl;
432   {
433     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
434     Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
435     Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg);
436     ops::_Retval retval(s.WithOpName("retval"), identity, 0);
437     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
438     TF_CHECK_OK(s.ToGraph(g.get()));
439     auto node_name_image = g->BuildNodeNameIndex();
440     node_name_image["identity_true_fn"]->AddAttr("_oc", "0");
441     PartialTensorShape shape({2});
442     node_name_image["identity_true_fn"]->AddAttr(
443         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
444 
445     FunctionDef *true_fn_fdef = fdl.add_function();
446     TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef));
447   }
448   {
449     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
450     Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
451     Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg);
452     ops::_Retval retval(s.WithOpName("retval"), identity, 0);
453     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
454     TF_CHECK_OK(s.ToGraph(g.get()));
455     auto node_name_image = g->BuildNodeNameIndex();
456     node_name_image["identity_false_fn"]->AddAttr("_oc", "0");
457     PartialTensorShape shape({2});
458     node_name_image["identity_false_fn"]->AddAttr(
459         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
460 
461     FunctionDef *false_fn_fdef = fdl.add_function();
462     TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef));
463   }
464   {
465     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
466     Output cond = ops::Const(s.WithOpName("const0"), true, {2});
467     Output input = ops::Const(s.WithOpName("const1"), 1, {2});
468     NameAttrList true_fn;
469     true_fn.set_name("true_fn");
470     NameAttrList false_fn;
471     false_fn.set_name("false_fn");
472     auto if_op = ops::If(s.WithOpName("if"), cond,
473                          std::initializer_list<Input>{cond, input}, {DT_INT32},
474                          true_fn, false_fn);
475     ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0);
476     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
477     TF_CHECK_OK(s.ToGraph(g.get()));
478 
479     FunctionDef *xla_fdef = fdl.add_function();
480     TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
481   }
482   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
483 
484   protobuf::Map<string, tensorflow::AttrValue> attrs;
485   std::map<string, int> host_compute_core;
486   std::vector<string> shape_inference_graphs;
487   bool has_outside_compilation;
488   NameAttrList name_attrs;
489   name_attrs.set_name("cluster");
490   *name_attrs.mutable_attr() = attrs;
491   TF_CHECK_OK(ExtractOutsideCompilationTest(
492       "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
493       host_compute_core, &fld, &shape_inference_graphs,
494       &has_outside_compilation));
495 
496   // Check host graph.
497   {
498     std::unique_ptr<FunctionBody> host_fbody;
499     AttrValue device_ordinal_temp_value;
500     device_ordinal_temp_value.set_i(0);
501     protobuf::Map<string, AttrValue> host_func_attrs;
502     host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
503     TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
504                                         AttrSlice(&host_func_attrs), &fld,
505                                         &host_fbody));
506     Graph *host_graph = host_fbody->graph;
507     auto node_name_index = host_graph->BuildNodeNameIndex();
508 
509     // Verify we have XlaRecvAtHost to receive "If" predicate.
510     Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"];
511     EXPECT_NE(recv_if_pred_node, nullptr);
512 
513     // Verify we have an "If" to choose outside compilation between then_branch
514     // and else_branch, and it has `recv_if_pred_node` as cond input.
515     Node *if_oc_node = node_name_index["oc_if_if"];
516     EXPECT_NE(if_oc_node, nullptr);
517     Node *if_oc_node_cond_input;
518     TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input));
519     EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node);
520 
521     // Check that then_branch outside compilation has node "identity_true_fn".
522     const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_true_fn");
523     EXPECT_NE(true_def, nullptr);
524     bool has_identity_true_fn_node = false;
525     for (const auto &node_def : true_def->node_def()) {
526       if (node_def.name() == "identity_true_fn") {
527         has_identity_true_fn_node = true;
528         break;
529       }
530     }
531     EXPECT_TRUE(has_identity_true_fn_node);
532 
533     // Check that else_branch outside compilation has node "identity_false_fn".
534     const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_false_fn");
535     EXPECT_NE(false_def, nullptr);
536     bool has_identity_false_fn_node = false;
537     for (const auto &node_def : false_def->node_def()) {
538       if (node_def.name() == "identity_false_fn") {
539         has_identity_false_fn_node = true;
540         break;
541       }
542     }
543     EXPECT_TRUE(has_identity_false_fn_node);
544   }
545 
546   // Check XLA graph.
547   {
548     std::unique_ptr<FunctionBody> xla_fbody;
549     TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
550                                         AttrSlice(), &fld, &xla_fbody));
551     Graph *xla_graph = xla_fbody->graph;
552     auto node_name_index = xla_graph->BuildNodeNameIndex();
553 
554     // Check that we have XlaSendToHost to send cond predicate to host, and
555     // there is a control edge to If node.
556     Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"];
557     EXPECT_NE(send_if_pred_node, nullptr);
558     bool has_control_edge_to_if = false;
559     for (const Edge *e : send_if_pred_node->out_edges()) {
560       if (e->IsControlEdge() && e->dst()->name() == "if") {
561         has_control_edge_to_if = true;
562         break;
563       }
564     }
565     EXPECT_TRUE(has_control_edge_to_if);
566 
567     // Check that the "If" node now has `send_if_pred_node` as attribute
568     // _xla_token_input_nodes.
569     Node *if_node = node_name_index["if"];
570     EXPECT_NE(if_node, nullptr);
571     std::vector<string> token_inputs;
572     TF_CHECK_OK(
573         GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs));
574     EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if"));
575   }
576 }
577 
TEST_F(ExtractOutsideCompilationForFunctionTest,OutsideCompilationInWhile)578 TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) {
579   // Build the XLA computation func.
580   // "const0" (bool)
581   // "while0" (input = "const0", cond = "cond_fn", body = "body_fn")
582   FunctionDefLibrary fdl;
583   {
584     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
585     Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
586     Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg);
587     ops::_Retval retval(s.WithOpName("retval"), identity, 0);
588     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
589     TF_CHECK_OK(s.ToGraph(g.get()));
590     auto node_name_image = g->BuildNodeNameIndex();
591     node_name_image["identity_cond_fn"]->AddAttr("_oc", "0");
592     PartialTensorShape shape({2});
593     node_name_image["identity_cond_fn"]->AddAttr(
594         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
595 
596     FunctionDef *cond_fn_fdef = fdl.add_function();
597     TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef));
598   }
599   {
600     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
601     Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
602     Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg);
603     ops::_Retval retval(s.WithOpName("retval"), identity, 0);
604     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
605     TF_CHECK_OK(s.ToGraph(g.get()));
606     auto node_name_image = g->BuildNodeNameIndex();
607     node_name_image["identity_body_fn"]->AddAttr("_oc", "0");
608     PartialTensorShape shape({2});
609     node_name_image["identity_body_fn"]->AddAttr(
610         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
611 
612     FunctionDef *body_fn_fdef = fdl.add_function();
613     TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef));
614   }
615   {
616     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
617     Output input = ops::Const(s.WithOpName("const0"), true, {2});
618     NameAttrList cond_fn;
619     cond_fn.set_name("cond_fn");
620     NameAttrList body_fn;
621     body_fn.set_name("body_fn");
622     auto while_op =
623         ops::While(s.WithOpName("while"), std::initializer_list<Input>{input},
624                    cond_fn, body_fn);
625     ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0);
626     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
627     TF_CHECK_OK(s.ToGraph(g.get()));
628 
629     FunctionDef *xla_fdef = fdl.add_function();
630     TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
631   }
632   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
633 
634   protobuf::Map<string, tensorflow::AttrValue> attrs;
635   std::map<string, int> host_compute_core;
636   std::vector<string> shape_inference_graphs;
637   bool has_outside_compilation;
638   NameAttrList name_attrs;
639   name_attrs.set_name("cluster");
640   *name_attrs.mutable_attr() = attrs;
641   TF_CHECK_OK(ExtractOutsideCompilationTest(
642       "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
643       host_compute_core, &fld, &shape_inference_graphs,
644       &has_outside_compilation));
645 
646   // Check host graph.
647   {
648     std::unique_ptr<FunctionBody> host_fbody;
649     AttrValue device_ordinal_temp_value;
650     device_ordinal_temp_value.set_i(0);
651     protobuf::Map<string, AttrValue> host_func_attrs;
652     host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
653     TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
654                                         AttrSlice(&host_func_attrs), &fld,
655                                         &host_fbody));
656     Graph *host_graph = host_fbody->graph;
657     auto node_name_index = host_graph->BuildNodeNameIndex();
658 
659     // Verify we have an "While" to execute outside compilation.
660     Node *while_oc_node = node_name_index["oc_while_while"];
661     EXPECT_NE(while_oc_node, nullptr);
662 
663     // Check that cond outside compilation has node "identity_cond_fn".
664     const FunctionDef *cond_def = fld.Find("oc_cond_host_while_cond_fn");
665     EXPECT_NE(cond_def, nullptr);
666     bool has_identity_cond_fn_node = false;
667     for (const auto &node_def : cond_def->node_def()) {
668       if (node_def.name() == "identity_cond_fn") {
669         has_identity_cond_fn_node = true;
670         break;
671       }
672     }
673     EXPECT_TRUE(has_identity_cond_fn_node);
674 
675     // Check that body outside compilation has node "identity_body_fn".
676     const FunctionDef *body_def = fld.Find("oc_body_host_while_body_fn");
677     EXPECT_NE(body_def, nullptr);
678     bool has_identity_body_fn_node = false;
679     for (const auto &node_def : body_def->node_def()) {
680       if (node_def.name() == "identity_body_fn") {
681         has_identity_body_fn_node = true;
682         break;
683       }
684     }
685     EXPECT_TRUE(has_identity_body_fn_node);
686   }
687 
688   // Check XLA graph.
689   {
690     // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to
691     // host.
692     const FunctionDef *cond_def = fld.Find("cond_fn_oc");
693     EXPECT_NE(cond_def, nullptr);
694     bool has_send_oc_while_cond_node = false;
695     for (const auto &node_def : cond_def->node_def()) {
696       if (node_def.name() == "send_oc_while_cond_while") {
697         has_send_oc_while_cond_node = true;
698         break;
699       }
700     }
701     EXPECT_TRUE(has_send_oc_while_cond_node);
702   }
703 }
704 
TEST_F(ExtractOutsideCompilationForFunctionTest,OutsideCompilationInFunction)705 TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
706   // Build the XLA computation func.
707   // "const0" (int32)
708   // "fn" (input = "const0")
709   FunctionDefLibrary fdl;
710   {
711     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
712     Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
713     Output identity = ops::Identity(s.WithOpName("identity"), arg);
714     ops::_Retval retval(s.WithOpName("retval"), identity, 0);
715     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
716     TF_CHECK_OK(s.ToGraph(g.get()));
717     auto node_name_image = g->BuildNodeNameIndex();
718     node_name_image["identity"]->AddAttr("_oc", "0");
719     PartialTensorShape shape({2});
720     node_name_image["identity"]->AddAttr(
721         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
722 
723     FunctionDef *true_fn_fdef = fdl.add_function();
724     TF_CHECK_OK(GraphToFunctionDef(*g, "fn", true_fn_fdef));
725   }
726   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
727   {
728     std::unique_ptr<Graph> g(new Graph(&fld));
729 
730     tensorflow::TensorProto tensor_proto;
731     tensor_proto.set_dtype(tensorflow::DT_INT32);
732     tensorflow::TensorShapeProto shape;
733     shape.add_dim()->set_size(2);
734     *tensor_proto.mutable_tensor_shape() = shape;
735     for (int i = 0; i < 2; ++i) {
736       tensor_proto.add_int_val(1);
737     }
738     NodeDef const_def;
739     TF_CHECK_OK(NodeDefBuilder("const", "Const")
740                     .Attr("dtype", DT_INT32)
741                     .Attr("value", tensor_proto)
742                     .Finalize(&const_def));
743     Status s;
744     Node *const_node = g->AddNode(const_def, &s);
745     TF_CHECK_OK(s);
746 
747     NodeDef fn_def;
748     TF_CHECK_OK(NodeDefBuilder("fn", "fn", &fld)
749                     .Input("const", 0, DT_INT32)
750                     .Finalize(&fn_def));
751     Node *fn_node = g->AddNode(fn_def, &s);
752     TF_CHECK_OK(s);
753     g->AddEdge(const_node, 0, fn_node, 0);
754 
755     NodeDef ret_def;
756     TF_CHECK_OK(NodeDefBuilder("ret", "_Retval")
757                     .Attr("index", 0)
758                     .Attr("T", DT_INT32)
759                     .Input("fn", 0, DT_INT32)
760                     .Finalize(&ret_def));
761     Node *ret_node = g->AddNode(ret_def, &s);
762     TF_CHECK_OK(s);
763     g->AddEdge(fn_node, 0, ret_node, 0);
764 
765     FunctionDef *xla_fdef = fdl.add_function();
766     TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
767     TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef));
768   }
769 
770   protobuf::Map<string, tensorflow::AttrValue> attrs;
771   std::map<string, int> host_compute_core;
772   std::vector<string> shape_inference_graphs;
773   bool has_outside_compilation;
774   NameAttrList name_attrs;
775   name_attrs.set_name("cluster");
776   *name_attrs.mutable_attr() = attrs;
777   TF_CHECK_OK(ExtractOutsideCompilationTest(
778       "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
779       host_compute_core, &fld, &shape_inference_graphs,
780       &has_outside_compilation));
781 
782   // Check host graph.
783   {
784     std::unique_ptr<FunctionBody> host_fbody;
785     AttrValue device_ordinal_temp_value;
786     device_ordinal_temp_value.set_i(0);
787     protobuf::Map<string, AttrValue> host_func_attrs;
788     host_func_attrs["_device_ordinal"] = device_ordinal_temp_value;
789     TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
790                                         AttrSlice(&host_func_attrs), &fld,
791                                         &host_fbody));
792     Graph *host_graph = host_fbody->graph;
793     auto node_name_index = host_graph->BuildNodeNameIndex();
794 
795     // Verify we have call node for outside compilation in `fn`.
796     Node *call_node = node_name_index["oc_call_fn"];
797     EXPECT_NE(call_node, nullptr);
798 
799     std::unique_ptr<FunctionBody> call_fbody;
800     TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("oc_func_call_host_fn"),
801                                         AttrSlice(&host_func_attrs), &fld,
802                                         &call_fbody));
803 
804     // Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes.
805     bool has_recv = false, has_send = false;
806     for (Node *n : call_fbody->graph->nodes()) {
807       if (n->type_string() == "_XlaRecvAtHost") {
808         has_recv = true;
809       } else if (n->type_string() == "_XlaSendFromHost") {
810         has_send = true;
811       }
812     }
813     EXPECT_TRUE(has_recv);
814     EXPECT_TRUE(has_send);
815   }
816 
817   // Check XLA graph.
818   {
819     std::unique_ptr<FunctionBody> xla_fbody;
820     TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
821                                         AttrSlice(), &fld, &xla_fbody));
822     Graph *xla_graph = xla_fbody->graph;
823     auto node_name_index = xla_graph->BuildNodeNameIndex();
824 
825     // Check that we have call node.
826     Node *fn_node = node_name_index["fn"];
827     EXPECT_NE(fn_node, nullptr);
828     EXPECT_EQ(fn_node->type_string(), "fn_oc");
829 
830     std::unique_ptr<FunctionBody> call_fbody;
831     TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("fn_oc"), AttrSlice(), &fld,
832                                         &call_fbody));
833 
834     // Verify we have XlaHostCompute nodes.
835     bool has_hc = false;
836     for (Node *n : call_fbody->graph->nodes()) {
837       if (n->type_string() == "XlaHostCompute") {
838         has_hc = true;
839       }
840     }
841     EXPECT_TRUE(has_hc);
842   }
843 }
844 
TEST_F(ExtractOutsideCompilationForFunctionTest,OutsideCompilationClusterDataDependency)845 TEST_F(ExtractOutsideCompilationForFunctionTest,
846        OutsideCompilationClusterDataDependency) {
847   // Build the XLA computation func.
848   // "const0"
849   // "identity0" = "const0" (outside compilation cluster "0")
850   // "identity1" = "identity0" (outside compilation cluster "1")
851   // "identity2" = "identity1"
852   FunctionDefLibrary fdl;
853   {
854     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
855     Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
856     Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
857     Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
858     Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
859     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
860     TF_CHECK_OK(s.ToGraph(g.get()));
861     std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
862               << std::endl;
863     auto node_name_image = g->BuildNodeNameIndex();
864     node_name_image["identity0"]->AddAttr("_oc", "0");
865     node_name_image["identity1"]->AddAttr("_oc", "1");
866 
867     PartialTensorShape shape({2});
868     node_name_image["identity1"]->AddAttr(
869         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
870 
871     FunctionDef *xla_fdef = fdl.add_function();
872     TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
873   }
874   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
875 
876   protobuf::Map<string, tensorflow::AttrValue> attrs;
877   std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
878   std::vector<string> shape_inference_graphs;
879   bool has_outside_compilation;
880   NameAttrList name_attrs;
881   name_attrs.set_name("cluster");
882   *name_attrs.mutable_attr() = attrs;
883   TF_CHECK_OK(ExtractOutsideCompilationTest(
884       "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
885       host_compute_core, &fld, &shape_inference_graphs,
886       &has_outside_compilation));
887 
888   // Get rewritten XLA computation function.
889   std::unique_ptr<FunctionBody> xla_fbody;
890   TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
891                                       AttrSlice(), &fld, &xla_fbody));
892   auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
893 
894   // Check XlaHostCompute nodes.
895   Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
896   EXPECT_NE(host_compute_0, nullptr);
897   Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
898   EXPECT_NE(host_compute_1, nullptr);
899 
900   // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
901   std::vector<string> token_input_nodes;
902   TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
903                           "_xla_token_input_nodes", &token_input_nodes));
904 
905   std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
906   EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
907   token_input_nodes.clear();
908   std::vector<string> expected_token_input_nodes_1(
909       {"_xla_token_arg_node", "outside_compilation_0_host_compute"});
910   TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
911                           "_xla_token_input_nodes", &token_input_nodes));
912   EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
913 
914   // Check there is a control edge from host_compute_0 to host_compute_1.
915   bool has_control_edge = false;
916   for (const Edge *e : host_compute_1->in_edges()) {
917     if (e->IsControlEdge() && e->src() == host_compute_0) {
918       has_control_edge = true;
919       break;
920     }
921   }
922   EXPECT_TRUE(has_control_edge);
923 }
924 
TEST_F(ExtractOutsideCompilationForFunctionTest,OutsideCompilationClusterControlDependency)925 TEST_F(ExtractOutsideCompilationForFunctionTest,
926        OutsideCompilationClusterControlDependency) {
927   // Build the XLA computation func.
928   // "const0"
929   // "identity0" = "const0" (outside compilation cluster "0")
930   // "identity1" = "const0" "^identity0" (outside compilation cluster "1",
931   //                                      control dependent on cluster "0")
932   // "identity2" = "identity1"
933   FunctionDefLibrary fdl;
934   {
935     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
936     Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
937     Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
938     Output identity1 = ops::Identity(
939         s.WithOpName("identity1").WithControlDependencies(identity0), const0);
940     Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
941     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
942     TF_CHECK_OK(s.ToGraph(g.get()));
943     std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
944               << std::endl;
945     auto node_name_image = g->BuildNodeNameIndex();
946     node_name_image["identity0"]->AddAttr("_oc", "0");
947     node_name_image["identity1"]->AddAttr("_oc", "1");
948 
949     PartialTensorShape shape({2});
950     node_name_image["identity1"]->AddAttr(
951         kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
952 
953     FunctionDef *xla_fdef = fdl.add_function();
954     TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
955   }
956   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
957 
958   protobuf::Map<string, tensorflow::AttrValue> attrs;
959   std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
960   std::vector<string> shape_inference_graphs;
961   bool has_outside_compilation;
962   NameAttrList name_attrs;
963   name_attrs.set_name("cluster");
964   *name_attrs.mutable_attr() = attrs;
965   TF_CHECK_OK(ExtractOutsideCompilationTest(
966       "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
967       host_compute_core, &fld, &shape_inference_graphs,
968       &has_outside_compilation));
969 
970   // Get rewritten XLA computation function.
971   std::unique_ptr<FunctionBody> xla_fbody;
972   TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
973                                       AttrSlice(), &fld, &xla_fbody));
974   auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
975 
976   // Check XlaHostCompute nodes.
977   Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
978   EXPECT_NE(host_compute_0, nullptr);
979   Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
980   EXPECT_NE(host_compute_1, nullptr);
981 
982   // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
983   std::vector<string> token_input_nodes;
984   TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
985                           "_xla_token_input_nodes", &token_input_nodes));
986 
987   std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
988   EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
989   token_input_nodes.clear();
990   std::vector<string> expected_token_input_nodes_1(
991       {"_xla_token_arg_node", "outside_compilation_0_host_compute"});
992   TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
993                           "_xla_token_input_nodes", &token_input_nodes));
994   EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
995 
996   // Check there is a control edge from host_compute_0 to host_compute_1.
997   bool has_control_edge = false;
998   for (const Edge *e : host_compute_1->in_edges()) {
999     if (e->IsControlEdge() && e->src() == host_compute_0) {
1000       has_control_edge = true;
1001       break;
1002     }
1003   }
1004   EXPECT_TRUE(has_control_edge);
1005 }
1006 }  // namespace tensorflow
1007