xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/cc/framework/ops.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/jit/encapsulate_util.h"
26 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
27 #include "tensorflow/compiler/jit/test_util.h"
28 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
29 #include "tensorflow/core/common_runtime/device_factory.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/framework/function_testlib.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/graph/graph_def_builder.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/public/session_options.h"
39 #include "tensorflow/core/public/version.h"
40 #include "tensorflow/core/util/equal_graph_def.h"
41 
42 namespace tensorflow {
43 namespace {
44 
// Name of the attribute that Sequencer() attaches to the NoOp it builds; the
// attribute's value is the name of the associated call node.
constexpr char kXlaHostTransferSequencerAttr[] =
    "_xla_host_transfer_sequencer";
47 
AddGraphDefToFunctionLibrary(const GraphDefBuilder & graphdef_builder,const string & name_suffix,FunctionDefLibrary * library)48 Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder,
49                                     const string& name_suffix,
50                                     FunctionDefLibrary* library) {
51   GraphDef graphdef;
52   TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef));
53   std::unique_ptr<Graph> graph =
54       std::unique_ptr<Graph>(new Graph(OpRegistry::Global()));
55   GraphConstructorOptions opts;
56   opts.allow_internal_ops = true;
57   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get()));
58   FunctionDef* fdef = library->add_function();
59   TF_RETURN_IF_ERROR(GraphToFunctionDef(
60       *graph,
61       absl::StrCat("_outside_compilation_shape_inference_", name_suffix),
62       fdef));
63   return OkStatus();
64 }
65 
66 template <class Tkey, class Tvalue>
EqualProtoMap(const::tensorflow::protobuf::Map<Tkey,Tvalue> & a,const::tensorflow::protobuf::Map<Tkey,Tvalue> & b,const std::function<string (const Tkey &)> & key_to_string,const std::function<string (const Tvalue &)> & value_to_string,const std::function<bool (const Tkey &,const Tvalue &,const Tvalue &)> & compare,const string & map_name,string * diff)67 bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
68                    const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
69                    const std::function<string(const Tkey&)>& key_to_string,
70                    const std::function<string(const Tvalue&)>& value_to_string,
71                    const std::function<bool(const Tkey&, const Tvalue&,
72                                             const Tvalue&)>& compare,
73                    const string& map_name, string* diff) {
74   for (const auto& elt_a : a) {
75     const auto iter = b.find(elt_a.first);
76     if (iter == b.end()) {
77       if (diff) {
78         *diff = absl::StrCat(map_name, " expected: contains element with key '",
79                              key_to_string(elt_a.first),
80                              "' got: map has no such element");
81       }
82       return false;
83     }
84     if (!compare(elt_a.first, elt_a.second, iter->second)) {
85       if (diff) {
86         *diff = absl::StrCat(map_name, " expected: element with key '",
87                              key_to_string(elt_a.first), "' has value '",
88                              value_to_string(elt_a.second), "' got: '",
89                              value_to_string(iter->second), "'");
90       }
91       return false;
92     }
93   }
94   for (const auto& elt_b : b) {
95     const auto iter = a.find(elt_b.first);
96     if (iter == a.end()) {
97       if (diff) {
98         *diff = absl::StrCat(map_name, " got: contains element with key '",
99                              key_to_string(elt_b.first),
100                              "' expected: map has no such element");
101       }
102       return false;
103     }
104   }
105   return true;
106 }
107 
EqualFunctionNodeDef(const NodeDef & a,const NodeDef & b,const string & diff_preamble,string * diff)108 bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
109                           const string& diff_preamble, string* diff) {
110   if (a.op() != b.op()) {
111     if (diff) {
112       *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
113                            ", expected op '", a.op(), "' got '", b.op());
114     }
115     return false;
116   }
117   if (a.device() != b.device()) {
118     if (diff) {
119       *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
120                            ", expected device '", a.device(), "' got '",
121                            b.device());
122     }
123     return false;
124   }
125   if (a.input_size() != b.input_size()) {
126     if (diff) {
127       *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
128                            ", expected ", a.input_size(), " inputs got ",
129                            b.input_size(), " expected:\n", a.DebugString(),
130                            "\ngot:\n", b.DebugString());
131     }
132     return false;
133   }
134   std::unordered_set<string> control_input_a;
135   std::unordered_set<string> control_input_b;
136   for (int i = 0; i < a.input_size(); ++i) {
137     if (absl::StartsWith(a.input(i), "^")) {
138       if (!absl::StartsWith(b.input(i), "^")) {
139         if (diff) {
140           *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
141                                " input ", i, ", expected control input ",
142                                a.input(i), " got ", b.input(i), " expected:\n",
143                                a.DebugString(), "\ngot:\n", b.DebugString());
144         }
145         return false;
146       }
147       control_input_a.insert(a.input(i));
148       control_input_b.insert(b.input(i));
149     } else if (a.input(i) != b.input(i)) {
150       if (diff) {
151         *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
152                              " input ", i, ", expected ", a.input(i), " got ",
153                              b.input(i), " expected:\n", a.DebugString(),
154                              "\ngot:\n", b.DebugString());
155       }
156       return false;
157     }
158   }
159   if (control_input_a != control_input_b) {
160     if (diff) {
161       *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
162                            " control inputs differ expected:\n",
163                            a.DebugString(), "\ngot:\n", b.DebugString());
164     }
165     return false;
166   }
167   return EqualProtoMap<string, AttrValue>(
168       a.attr(), b.attr(), [](const string& s) { return s; },
169       [](const AttrValue& v) { return v.DebugString(); },
170       [](const string& key, const AttrValue& av, const AttrValue& bv) {
171         if (key == "ancestors") {
172           // The ancestors are added from a set so the order is unpredictable;
173           // just compare set equality not list equality.
174           std::unordered_set<string> a_set(av.list().s().begin(),
175                                            av.list().s().end());
176           std::unordered_set<string> b_set(bv.list().s().begin(),
177                                            bv.list().s().end());
178           return a_set == b_set;
179         } else {
180           return av.DebugString() == bv.DebugString();
181         }
182       },
183       absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff);
184 }
185 
EqualFunctionDef(const FunctionDef & a,const FunctionDef & b,string * diff)186 bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
187                       string* diff) {
188   if (a.signature().DebugString() != b.signature().DebugString()) {
189     if (diff) {
190       *diff =
191           absl::StrCat("Signature mismatch for function ", a.signature().name(),
192                        ", expected:\n", a.signature().DebugString(), "\ngot:\n",
193                        b.signature().DebugString());
194     }
195     return false;
196   }
197   if (!EqualProtoMap<string, AttrValue>(
198           a.attr(), b.attr(), [](const string& s) { return s; },
199           [](const AttrValue& v) { return v.DebugString(); },
200           [](const string& key, const AttrValue& av, const AttrValue& bv) {
201             return av.DebugString() == bv.DebugString();
202           },
203           absl::StrCat("attr mismatch for function ", a.signature().name()),
204           diff)) {
205     return false;
206   }
207   if (!EqualProtoMap<string, string>(
208           a.ret(), b.ret(), [](const string& s) { return s; },
209           [](const string& s) { return s; },
210           [](const string& key, const string& av, const string& bv) {
211             return av == bv;
212           },
213           absl::StrCat("ret mismatch for function ", a.signature().name()),
214           diff)) {
215     return false;
216   }
217   for (int i = 0; i < a.node_def_size(); ++i) {
218     bool found = false;
219     for (int j = 0; j < b.node_def_size(); ++j) {
220       if (a.node_def(i).name() == b.node_def(j).name()) {
221         if (!EqualFunctionNodeDef(
222                 a.node_def(i), b.node_def(j),
223                 absl::StrCat("Function ", a.signature().name()), diff)) {
224           return false;
225         }
226         found = true;
227         break;
228       }
229     }
230     if (!found) {
231       if (diff) {
232         *diff = absl::StrCat("Function ", a.signature().name(),
233                              ", expected: has node '", a.node_def(i).name(),
234                              "' got: no node of that name");
235       }
236       return false;
237     }
238   }
239   for (int i = 0; i < b.node_def_size(); ++i) {
240     bool found = false;
241     for (int j = 0; j < a.node_def_size(); ++j) {
242       if (b.node_def(i).name() == a.node_def(j).name()) {
243         found = true;
244         break;
245       }
246     }
247     if (!found) {
248       if (diff) {
249         *diff = absl::StrCat("Function ", a.signature().name(),
250                              ", got: has node '", b.node_def(i).name(),
251                              "' expected: no node of that name");
252       }
253       return false;
254     }
255   }
256   return true;
257 }
258 
EqualFunctionDefLibrary(const FunctionDefLibrary & expected,const FunctionDefLibrary & actual,string * diff)259 bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
260                              const FunctionDefLibrary& actual, string* diff) {
261   std::unordered_map<string, const FunctionDef*> actual_index;
262   for (const FunctionDef& function : actual.function()) {
263     actual_index[function.signature().name()] = &function;
264   }
265 
266   for (const FunctionDef& expected_function : expected.function()) {
267     auto it = actual_index.find(expected_function.signature().name());
268     if (it == actual_index.end()) {
269       if (diff) {
270         *diff = absl::StrCat("Did not find expected function '",
271                              expected_function.signature().name(), "'");
272       }
273       return false;
274     }
275     if (!EqualFunctionDef(expected_function, *it->second, diff)) return false;
276     actual_index.erase(it);
277   }
278 
279   if (!actual_index.empty()) {
280     if (diff != nullptr) {
281       *diff =
282           absl::StrCat("Found unexpected function '",
283                        actual_index.begin()->second->signature().name(), "'");
284     }
285     return false;
286   }
287 
288   return true;
289 }
290 
// Expects FunctionDefLibrary protos `expected` and `actual` to compare equal
// under EqualFunctionDefLibrary. On failure it logs the first difference found,
// followed by the full actual library.
//
// Each argument is evaluated exactly once and parenthesized. The macro's
// locals have distinctive names, so they cannot shadow a caller variable
// (e.g. one named `diff`) passed as an argument.
#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual)                \
  do {                                                                   \
    const FunctionDefLibrary& tf_fdeflib_expected_ = (expected);         \
    const FunctionDefLibrary& tf_fdeflib_actual_ = (actual);             \
    string tf_fdeflib_diff_;                                             \
    EXPECT_TRUE(EqualFunctionDefLibrary(tf_fdeflib_expected_,            \
                                        tf_fdeflib_actual_,              \
                                        &tf_fdeflib_diff_))              \
        << tf_fdeflib_diff_                                              \
        << "\nActual: " << tf_fdeflib_actual_.DebugString();             \
  } while (false)
297 
298 REGISTER_OP("InputTest")
299     .Output("o: float")
__anon4e24c0570b02(::tensorflow::shape_inference::InferenceContext* c) 300     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
301       c->set_output(0, c->UnknownShape());
302       return OkStatus();
303     });
304 
305 REGISTER_OP("InputTestShaped")
306     .Output("o: float")
__anon4e24c0570c02(::tensorflow::shape_inference::InferenceContext* c) 307     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
308       c->set_output(0, c->Vector(2));
309       return OkStatus();
310     });
311 
312 REGISTER_OP("UnaryTest")
313     .Input("a: float")
314     .Output("o: float")
__anon4e24c0570d02(::tensorflow::shape_inference::InferenceContext* c) 315     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
316       ::tensorflow::shape_inference::ShapeHandle o;
317       TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
318       c->set_output(0, o);
319       return OkStatus();
320     });
321 REGISTER_OP("BinaryTest")
322     .Input("a: float")
323     .Input("b: float")
324     .Output("o: float")
__anon4e24c0570e02(::tensorflow::shape_inference::InferenceContext* c) 325     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
326       ::tensorflow::shape_inference::ShapeHandle o;
327       TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
328       c->set_output(0, o);
329       return OkStatus();
330     });
331 REGISTER_OP("BinaryTest2")
332     .Input("a: float")
333     .Input("b: float")
334     .Output("o: float")
335     .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
336 
337 REGISTER_OP("AddNLikeTest")
338     .Input("inputs: N * T")
339     .Output("sum: T")
340     .Attr("N: int >= 1")
341     .Attr("T: numbertype")
342     .SetIsCommutative()
343     .SetIsAggregate();
344 
Sequencer(const GraphDefBuilder::Options & opts,const string & call_node_name)345 Node* Sequencer(const GraphDefBuilder::Options& opts,
346                 const string& call_node_name) {
347   if (opts.HaveError()) return nullptr;
348   NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp",
349                            opts.op_registry());
350   return opts.WithAttr(kXlaHostTransferSequencerAttr, call_node_name)
351       .FinalizeBuilder(&node_builder);
352 }
353 
Input(const GraphDefBuilder::Options & opts)354 Node* Input(const GraphDefBuilder::Options& opts) {
355   return ops::SourceOp("InputTest", opts);
356 }
357 
InputShaped(const GraphDefBuilder::Options & opts)358 Node* InputShaped(const GraphDefBuilder::Options& opts) {
359   return ops::SourceOp("InputTestShaped", opts);
360 }
361 
KnownShapeBase(DataType dtype,absl::Span<const int> shape,const GraphDefBuilder::Options & opts)362 Node* KnownShapeBase(DataType dtype, absl::Span<const int> shape,
363                      const GraphDefBuilder::Options& opts) {
364   if (opts.HaveError()) return nullptr;
365   NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
366                            opts.op_registry());
367   TensorProto value;
368   value.set_dtype(dtype);
369   for (int dim : shape) {
370     value.mutable_tensor_shape()->add_dim()->set_size(dim);
371   }
372   return opts.WithAttr("value", value)
373       .WithAttr("dtype", dtype)
374       .FinalizeBuilder(&node_builder);
375 }
376 
KnownShape(absl::Span<const int> shape,const GraphDefBuilder::Options & opts)377 Node* KnownShape(absl::Span<const int> shape,
378                  const GraphDefBuilder::Options& opts) {
379   return KnownShapeBase(DT_FLOAT, shape, opts);
380 }
381 
KeyPlaceholderShape(const GraphDefBuilder::Options & opts)382 Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) {
383   return KnownShapeBase(DT_STRING, {2}, opts);
384 }
385 
KeyPlaceholder(const string & call_node,const GraphDefBuilder::Options & opts)386 Node* KeyPlaceholder(const string& call_node,
387                      const GraphDefBuilder::Options& opts) {
388   if (opts.HaveError()) return nullptr;
389   NodeBuilder node_builder(absl::StrCat(call_node, "_key_placeholder"),
390                            "Placeholder", opts.op_registry());
391   TensorShapeProto shape;
392   shape.add_dim()->set_size(2);
393   return opts.WithAttr("shape", shape)
394       .WithAttr("dtype", DT_STRING)
395       .WithAttr("_host_compute_call_node", call_node)
396       .FinalizeBuilder(&node_builder);
397 }
398 
RecvAtHost(ops::NodeOut key_input,const string & cluster,const string & new_func_name,const string & oc_cluster,absl::Span<const DataType> dtypes,const GraphDefBuilder::Options & opts)399 Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
400                  const string& new_func_name, const string& oc_cluster,
401                  absl::Span<const DataType> dtypes,
402                  const GraphDefBuilder::Options& opts) {
403   if (opts.HaveError()) return nullptr;
404   string key = absl::StrCat("host_compute_channel_", cluster, "_",
405                             new_func_name, "_", oc_cluster);
406   string name = absl::StrCat("outside_compilation_", cluster, "_",
407                              new_func_name, "_", oc_cluster, "_recv");
408   NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"),
409                            "_XlaRecvAtHost", opts.op_registry());
410   node_builder.Input(std::move(key_input));
411   return opts.WithAttr("Toutputs", dtypes)
412       .WithAttr("key", key)
413       .WithAttr("device_ordinal", 0)
414       .WithAttr("_encapsulate", cluster)
415       .WithAttr("_outside", oc_cluster)
416       .FinalizeBuilder(&node_builder);
417 }
418 
SendFromHost(ops::NodeOut key_input,const string & cluster,const string & new_func_name,const string & oc_cluster,const std::vector<ops::NodeOut> & inputs,const GraphDefBuilder::Options & opts)419 Node* SendFromHost(ops::NodeOut key_input, const string& cluster,
420                    const string& new_func_name, const string& oc_cluster,
421                    const std::vector<ops::NodeOut>& inputs,
422                    const GraphDefBuilder::Options& opts) {
423   if (opts.HaveError()) return nullptr;
424   string key = absl::StrCat("host_compute_channel_", cluster, "_",
425                             new_func_name, "_", oc_cluster);
426   string name = absl::StrCat("outside_compilation_", cluster, "_",
427                              new_func_name, "_", oc_cluster, "_send");
428   NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"),
429                            "_XlaSendFromHost", opts.op_registry());
430   node_builder.Input(inputs);
431   node_builder.Input(std::move(key_input));
432   std::vector<DataType> dtypes;
433   for (const auto& node : inputs) {
434     dtypes.push_back(node.dt);
435   }
436   return opts.WithAttr("Tinputs", dtypes)
437       .WithAttr("key", key)
438       .WithAttr("device_ordinal", 0)
439       .WithAttr("_encapsulate", cluster)
440       .WithAttr("_outside", oc_cluster)
441       .FinalizeBuilder(&node_builder);
442 }
443 
Unary(ops::NodeOut a,const GraphDefBuilder::Options & opts)444 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
445   return ops::UnaryOp("UnaryTest", std::move(a), opts);
446 }
447 
Binary(ops::NodeOut a,ops::NodeOut b,const GraphDefBuilder::Options & opts)448 Node* Binary(ops::NodeOut a, ops::NodeOut b,
449              const GraphDefBuilder::Options& opts) {
450   return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
451 }
452 
BinaryUnknownShape(ops::NodeOut a,ops::NodeOut b,const GraphDefBuilder::Options & opts)453 Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
454                          const GraphDefBuilder::Options& opts) {
455   return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
456 }
457 
AddNLike(const std::vector<ops::NodeOut> & inputs,const GraphDefBuilder::Options & opts)458 Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
459                const GraphDefBuilder::Options& opts) {
460   if (opts.HaveError()) return nullptr;
461   NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest",
462                            opts.op_registry());
463   node_builder.Input(inputs);
464   return opts.FinalizeBuilder(&node_builder);
465 }
466 
ArgOp(int index,DataType type,const GraphDefBuilder::Options & opts)467 Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) {
468   return ops::SourceOp("_Arg",
469                        opts.WithAttr("T", type).WithAttr("index", index));
470 }
471 
RetOp(int index,ops::NodeOut a,const GraphDefBuilder::Options & opts)472 Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
473   if (opts.HaveError()) return nullptr;
474   NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
475                            opts.op_registry());
476   node_builder.Input(std::move(a)).Attr("index", index);
477   return opts.FinalizeBuilder(&node_builder);
478 }
479 
Encapsulate(GraphDef * graphdef,FunctionDefLibrary * library,const std::vector<string> & encapsulated_functions)480 Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
481                    const std::vector<string>& encapsulated_functions) {
482   Status s;
483   // Convert the GraphDef to a Graph
484   std::unique_ptr<FunctionLibraryDefinition> lib_def(
485       new FunctionLibraryDefinition(OpRegistry::Global(), *library));
486   GraphConstructorOptions options;
487   options.allow_internal_ops = true;
488   std::unique_ptr<Graph> graph(new Graph(lib_def.get()));
489   s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
490   if (!s.ok()) return s;
491 
492   s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get());
493   if (!s.ok()) return s;
494 
495   // Create FunctionLibraryRuntime.
496   SessionOptions session_options;
497   std::vector<std::unique_ptr<Device>> devices;
498   TF_CHECK_OK(DeviceFactory::AddDevices(
499       session_options, "/job:localhost/replica:0/task:0", &devices));
500   OptimizerOptions opts;
501   auto device_mgr = std::make_unique<StaticDeviceMgr>(std::move(devices));
502   auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
503       device_mgr.get(), Env::Default(), /*config=*/nullptr,
504       TF_GRAPH_DEF_VERSION, lib_def.get(), opts,
505       /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
506   auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
507 
508   std::unique_ptr<Graph> graph_out;
509   s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
510                                       /*rewrite_subgraph_fn=*/{},
511                                       /*reuse_existing_functions=*/false,
512                                       &graph_out, lib_def.get());
513   if (!s.ok()) return s;
514 
515   std::unordered_map<string, XlaClusterInfo> clusters;
516   for (const auto& func : encapsulated_functions) {
517     Node* xla_computation_node;
518     for (Node* n : graph_out->nodes()) {
519       if (n->name() == func) {
520         xla_computation_node = n;
521       }
522     }
523     if (!xla_computation_node) {
524       return errors::Internal("Cannot find node ", func);
525     }
526     NameAttrList func_name_attrs;
527     func_name_attrs.set_name(func);
528     clusters.emplace(func,
529                      XlaClusterInfo{func, func_name_attrs, xla_computation_node,
530                                     std::map<string, int>{}});
531   }
532   bool modified;
533   s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
534                                 graph_out.get(), flr, lib_def.get(), &modified);
535   if (!s.ok()) return s;
536 
537   GraphDef graphdef_out;
538   graph_out->ToGraphDef(&graphdef_out);
539   graphdef->Swap(&graphdef_out);
540 
541   *library = lib_def->ToProto();
542   // Remove "_xla_inferred_shapes" attr. They are added by
543   // `PerformStaticShapeInferenceBeforeEncapsulation`.
544   for (FunctionDef& fdef : *library->mutable_function()) {
545     for (NodeDef& node_def : *fdef.mutable_node_def()) {
546       node_def.mutable_attr()->erase("_xla_inferred_shapes");
547     }
548   }
549 
550   return s;
551 }
552 
Encapsulate(GraphDef * graphdef,FunctionDefLibrary * library)553 Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
554   std::vector<string> encapsulated_functions;
555   return Encapsulate(graphdef, library, encapsulated_functions);
556 }
557 
// If no nodes are marked for encapsulation, the pass should be a no-op.
TEST(EncapsulateSubgraphsTest, NoFunctions) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);

  Node* a = Input(builder.opts().WithName("A"));
  Node* b = Input(builder.opts().WithName("B"));
  Node* c = Unary(a, builder.opts().WithName("C"));
  Binary(b, c, builder.opts().WithName("D"));

  GraphDef graphdef_in;
  FunctionDefLibrary library_in;
  // Setup and the pass itself must succeed; comparing outputs after a failure
  // would only produce misleading follow-on failures.
  TF_ASSERT_OK(builder.ToGraphDef(&graphdef_in));
  *library_in.add_function() = test::function::XTimesTwo();

  GraphDef graphdef_out = graphdef_in;
  FunctionDefLibrary library_out = library_in;
  TF_ASSERT_OK(Encapsulate(&graphdef_out, &library_out));

  // With no nodes carrying "_encapsulate", graph and library are unchanged.
  TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
}
579 }
580 
581 // Test with one function to transform.
TEST(EncapsulateSubgraphsTest, OneFunction) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  {
    *library.add_function() = test::function::XTimesTwo();

    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    // Give nodes 'c' and 'd' names ("C" and "c") that collide after
    // lowercasing.
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d = Binary(b, c,
                     b1.opts().WithName("c").WithControlInput(c).WithAttr(
                         "_encapsulate", "F1"));
    Binary(a, d, b1.opts().WithName("E"));
    TF_ASSERT_OK(b1.ToGraphDef(&graphdef));
  }

  // Stop here on failure rather than comparing against a stale graph.
  TF_ASSERT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  *library_expected.add_function() = test::function::XTimesTwo();
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
      },
      {{"c_0_retval", "c:o:0"}});

  {
    auto lib_def = std::make_unique<FunctionLibraryDefinition>(
        OpRegistry::Global(), library_expected);
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call = b2.opts().FinalizeBuilder(&node_builder);

    Binary(a, call, b2.opts().WithName("E"));
    TF_ASSERT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
632 }
633 
634 // Test with two functions to transform.
TEST(EncapsulateSubgraphsTest,TwoFunctions)635 TEST(EncapsulateSubgraphsTest, TwoFunctions) {
636   FunctionDefLibrary library;
637   GraphDef graphdef;
638 
639   {
640     *library.add_function() = test::function::XTimesTwo();
641 
642     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
643     Node* a = Input(b1.opts().WithName("A"));
644     Node* b = Input(b1.opts().WithName("B"));
645     Node* control = Input(b1.opts().WithName("Control"));
646     Node* c =
647         Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
648                      "_encapsulate", "F1"));
649     Node* d = Binary(b, c,
650                      b1.opts().WithName("D").WithControlInput(control).WithAttr(
651                          "_encapsulate", "F2"));
652     Binary(a, d, b1.opts().WithName("E"));
653     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
654   }
655 
656   TF_EXPECT_OK(Encapsulate(&graphdef, &library));
657 
658   FunctionDefLibrary library_expected;
659   GraphDef graphdef_expected;
660 
661   *library_expected.add_function() = test::function::XTimesTwo();
662   *library_expected.add_function() = FunctionDefHelper::Create(
663       "F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {},
664       {
665           {{"C"}, "UnaryTest", {"a_0_arg"}},
666       },
667       {{"c_0_retval", "C:o:0"}});
668   *library_expected.add_function() = FunctionDefHelper::Create(
669       "F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {},
670       {
671           {{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}},
672       },
673       {{"d_0_retval", "D:o:0"}});
674 
675   {
676     std::unique_ptr<FunctionLibraryDefinition> lib_def(
677         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
678     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
679     Node* a = Input(b2.opts().WithName("A"));
680     Node* b = Input(b2.opts().WithName("B"));
681     Node* control = Input(b2.opts().WithName("Control"));
682 
683     NodeBuilder nb("F1", "F1", lib_def.get());
684     nb.Input(a).ControlInput(control);
685     Node* call1 = b2.opts().FinalizeBuilder(&nb);
686 
687     NodeBuilder nb2("F2", "F2", lib_def.get());
688     nb2.Input(b).Input(call1).ControlInput(control);
689     Node* call2 = b2.opts().FinalizeBuilder(&nb2);
690 
691     Binary(a, call2, b2.opts().WithName("E"));
692     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
693   }
694 
695   // If there are no marked nodes, funcification should be a no-op.
696   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
697   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
698 }
699 
700 // Returns a vector of node names in 'graph', sorted by name.
GraphNodes(const Graph & graph)701 std::vector<string> GraphNodes(const Graph& graph) {
702   std::vector<string> nodes;
703   for (const auto& node : graph.nodes()) {
704     if (!node->IsSource() && !node->IsSink()) {
705       nodes.push_back(node->name());
706     }
707   }
708   std::sort(nodes.begin(), nodes.end());
709   return nodes;
710 }
711 
712 // Returns a sorted vector of (src, dst) edges in 'graph'.
GraphEdges(const Graph & graph)713 std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
714   std::vector<std::pair<string, string>> edges;
715   for (const Edge* edge : graph.edges()) {
716     if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
717     edges.emplace_back(
718         absl::StrCat(edge->src()->name(), ":", edge->src_output()),
719         absl::StrCat(edge->dst()->name(), ":", edge->dst_input()));
720   }
721   std::sort(edges.begin(), edges.end());
722   return edges;
723 }
724 
// Two chained clusters each read a single tensor on both of their inputs
// (add1 = x + x, add2 = add1 + add1). After encapsulation, each cluster call
// node should have exactly one incoming data edge per distinct producer
// tensor, not one per use.
TEST(EncapsulateSubgraphsTest, InputDeduplication) {
  Scope scope = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT);

  // add1 belongs to cluster1 and consumes 'x' on both inputs.
  auto add1 = ops::Add(scope.WithOpName("add1"), x, x);
  add1.node()->AddAttr("_cluster", "cluster1");

  // add2 belongs to cluster2 and consumes add1 on both inputs.
  auto add2 = ops::Add(scope.WithOpName("add2"), add1, add1);
  add2.node()->AddAttr("_cluster", "cluster2");

  // 'mul' is not clustered and reads from both clusters.
  auto mul = ops::Mul(scope.WithOpName("mul"), add1, add2);

  Graph input_graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&input_graph));

  std::unique_ptr<Graph> encapsulated;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_cluster", input_graph,
      /*rewrite_subgraph_fn=*/{},
      /*reuse_existing_functions=*/false, &encapsulated, &flib_def));

  // Each cluster collapses into a single node named after the cluster.
  EXPECT_EQ((std::vector<string>{"cluster1", "cluster2", "mul", "x"}),
            GraphNodes(*encapsulated));

  // One edge per distinct producer tensor feeding each cluster, even though
  // each clustered op reads its input twice.
  const std::vector<std::pair<string, string>> want_edges = {
      {"cluster1:0", "cluster2:0"},
      {"cluster1:0", "mul:0"},
      {"cluster2:0", "mul:1"},
      {"x:0", "cluster1:0"}};
  EXPECT_EQ(want_edges, GraphEdges(*encapsulated));
}
755 
FindNodeByName(const Graph & graph,const string & name)756 const Node* FindNodeByName(const Graph& graph, const string& name) {
757   for (const Node* node : graph.nodes()) {
758     if (node->name() == name) return node;
759   }
760   return nullptr;
761 }
762 
HasGuaranteeConstAttr(const Node & n)763 bool HasGuaranteeConstAttr(const Node& n) {
764   bool is_guaranteed_constant = false;
765   if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant",
766                    &is_guaranteed_constant)
767            .ok()) {
768     return false;
769   }
770   return is_guaranteed_constant;
771 }
772 
// 'add1' is the only encapsulated node, and both of its inputs come from
// GuaranteeConst ops. The rewrite callback inspects the extracted subgraph:
// every _Arg whose name starts with "const" must carry the guaranteed-constant
// attribute, and no other node may. Exactly two such _Args are expected.
TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
  Scope scope = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto x1 = ops::Placeholder(scope.WithOpName("x1"), DT_FLOAT);
  auto x2 = ops::Placeholder(scope.WithOpName("x2"), DT_FLOAT);
  auto const_guarantee_x2 =
      ops::GuaranteeConst(scope.WithOpName("const_guarantee_x2"), x2);
  auto const_guarantee_x1 =
      ops::GuaranteeConst(scope.WithOpName("const_guarantee_x1"), x1);
  auto add1 = ops::Add(scope.WithOpName("add1"), const_guarantee_x1,
                       const_guarantee_x2);
  add1.node()->AddAttr("_encapsulate", "encapsulate1");

  Graph input_graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&input_graph));

  // The _Arg names are presumably derived from the feeding node names
  // ("const_guarantee_*"), hence the "const" prefix check.
  int const_args_seen = 0;
  auto check_guarantee_const_attrs =
      [&const_args_seen](const std::vector<OutputTensor>& arg_source_tensors,
                         std::unique_ptr<Graph>* graph_ptr,
                         std::vector<int>* input_permutation,
                         std::vector<int>* output_permutation,
                         NodeDef* call_def) {
        for (const Node* node : (*graph_ptr)->nodes()) {
          const bool is_const_arg = node->type_string() == "_Arg" &&
                                    absl::StartsWith(node->name(), "const");
          if (!is_const_arg) {
            EXPECT_FALSE(HasGuaranteeConstAttr(*node));
            continue;
          }
          ++const_args_seen;
          EXPECT_TRUE(HasGuaranteeConstAttr(*node));
        }
        return OkStatus();
      };

  std::unique_ptr<Graph> output_graph;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_encapsulate", input_graph,
      /*rewrite_subgraph_fn=*/check_guarantee_const_attrs,
      /*reuse_existing_functions=*/false, &output_graph, &flib_def));
  EXPECT_EQ(2, const_args_seen);
}
815 
// 'mul1' is the only encapsulated node. It has two inputs:
//   * const_guarantee_add1, whose two inputs are both GuaranteeConst ops;
//   * add2, which mixes a GuaranteeConst input with the plain placeholder x2.
// Only the _Arg fed by const_guarantee_add1 is expected to be marked as a
// guaranteed constant. add2's _Arg is not, because one of its inputs is not
// constant.
TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
  Scope scope = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto x1 = ops::Placeholder(scope.WithOpName("x1"), DT_FLOAT);
  auto x2 = ops::Placeholder(scope.WithOpName("x2"), DT_FLOAT);
  auto const_guarantee_x1 =
      ops::GuaranteeConst(scope.WithOpName("const_guarantee_x1"), x1);
  auto const_guarantee_x2 =
      ops::GuaranteeConst(scope.WithOpName("const_guarantee_x2"), x2);
  auto const_guarantee_add1 =
      ops::Add(scope.WithOpName("const_guarantee_add1"), const_guarantee_x1,
               const_guarantee_x2);
  auto add2 = ops::Add(scope.WithOpName("add2"), const_guarantee_x1, x2);
  auto mul1 = ops::Mul(scope.WithOpName("mul1"), const_guarantee_add1, add2);
  mul1.node()->AddAttr("_encapsulate", "encapsulate1");

  Graph input_graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&input_graph));

  // Counts _Arg nodes whose names start with "const" and checks that exactly
  // those nodes carry the guaranteed-constant attribute.
  int const_args_seen = 0;
  auto check_guarantee_const_attrs =
      [&const_args_seen](const std::vector<OutputTensor>& arg_source_tensors,
                         std::unique_ptr<Graph>* graph_ptr,
                         std::vector<int>* input_permutation,
                         std::vector<int>* output_permutation,
                         NodeDef* call_def) {
        for (const Node* node : (*graph_ptr)->nodes()) {
          const bool is_const_arg = node->type_string() == "_Arg" &&
                                    absl::StartsWith(node->name(), "const");
          if (!is_const_arg) {
            EXPECT_FALSE(HasGuaranteeConstAttr(*node));
            continue;
          }
          ++const_args_seen;
          EXPECT_TRUE(HasGuaranteeConstAttr(*node));
        }
        return OkStatus();
      };

  std::unique_ptr<Graph> output_graph;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_encapsulate", input_graph,
      /*rewrite_subgraph_fn=*/check_guarantee_const_attrs,
      /*reuse_existing_functions=*/false, &output_graph, &flib_def));
  // Only const_guarantee_add1 yields a guaranteed-constant argument.
  EXPECT_EQ(1, const_args_seen);
}
862 
// Test with one function to transform and one outside_compilation cluster.
//
// Input graph:
//   * C, "c" and F are tagged _encapsulate=F1.
//   * E is tagged both _encapsulate=F1 and _outside=O1.
//   * G is not encapsulated.
//
// The expectations built below are:
//   * Inside F1, E is replaced by an XlaHostCompute node
//     ("outside_compilation_O1_host_compute"). It takes C and c as inputs, and
//     its output 0 feeds F in place of E.
//   * In the outer graph, E runs between a RecvAtHost and a SendFromHost keyed
//     by F1_key_placeholder. The F1 call node has a control input on
//     F1_sequencer, which depends on both transfers.
//   * A function "F1_F1_O1" holding O1's host computation is added to the
//     expected library.
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    *library.add_function() = test::function::XTimesTwo();

    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    // Give nodes 'c' and 'd' names ("C" and "c") that collide after
    // lowercasing.
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d = Binary(b, c,
                     b1.opts().WithName("c").WithControlInput(c).WithAttr(
                         "_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected host-side computation for O1 (RecvAtHost -> E -> SendFromHost),
  // added to the expected library as "F1_F1_O1".
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape.opts());
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape.opts()
                         .WithName("E")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape, "F1_F1_O1", &library_expected));
  }

  // Name the XlaHostCompute node uses for its shape_inference_graph attr.
  // NOTE(review): presumably this maps to the "F1_F1_O1" function added
  // above; confirm against AddGraphDefToFunctionLibrary.
  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  *library_expected.add_function() = test::function::XTimesTwo();
  // Expected body of F1:
  //   * E's control input on d (named "c") becomes a control input of the
  //     host compute node.
  //   * F's control input on E becomes a control input on the host compute
  //     node.
  //   * The 'shapes' attr is empty here (A and B come from Input rather than
  //     InputShaped), which is presumably why a shape-inference graph is
  //     referenced.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "c:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"c"}},
      },
      {{"f_0_retval_retval", "F:o:0"}});

  // Expected outer graph:
  //   * E runs on the host between recv and send.
  //   * B, a control input of E in the input graph, becomes a control input of
  //     the F1 call node.
  //   * G's control dependency on E becomes a control dependency on the call.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     b2.opts()
                         .WithName("E")
                         .WithControlInputs({recv})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* send =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));

    // F1_sequencer has control inputs on both host transfers; the call below
    // has a control input on it.
    Node* s = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
        "F1");

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call =
        b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);

    Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
990 
// Test with one function to transform and two outside_compilation clusters.
//
// Input graph:
//   * C, D, F and I are compiled into F1.
//   * E runs outside compilation in cluster O1.
//   * G and H run outside compilation in cluster O2. Both consume E, and G
//     also has control inputs on E and F.
//   * J is not encapsulated and consumes G and I.
//
// Expectations:
//   * F1 contains one XlaHostCompute node per cluster.
//   * The O2 host compute node has a control input and a token input on the
//     O1 host compute node.
//   * On the host, recv/send pairs for O1 and O2 all feed a single
//     F1_sequencer.
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts()
                         .WithName("G")
                         .WithControlInputs({e, f})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(d, e,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected host-side computation for O1 (RecvAtHost -> E -> SendFromHost),
  // added to the library as "F1_F1_O1".
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape1.opts()
                         .WithName("E")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  // Expected host-side computation for O2, added as "F1_F1_O2".
  //   * G and H read E directly on the host and receive F and D from the
  //     device via recv2 (outputs 0 and 1).
  //   * Because G and H consume E, this graph also contains O1's
  //     recv1 -> E.
  {
    GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape2.opts());
    Node* recv1 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     shape2.opts()
                         .WithName("E")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* recv2 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT, DT_FLOAT},
        shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* g = Binary(e, ops::NodeOut(recv2, 0),
                     shape2.opts()
                         .WithName("G")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(ops::NodeOut(recv2, 1), e,
                     shape2.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g, h},
                 shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape2, "F1_F1_O2", &library_expected));
  }

  // Names the two XlaHostCompute nodes use for their shape_inference_graph
  // attrs.
  NameAttrList shape_inference_graph1, shape_inference_graph2;
  shape_inference_graph1.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  shape_inference_graph2.set_name(
      "_outside_compilation_shape_inference_F1_F1_O2");
  // Expected body of F1:
  //   * The O2 host compute takes F and D and returns G (output 0, used
  //     directly as the first retval) and H (output 1, consumed by I).
  //   * The O2 host compute lists the O1 host compute both as a control input
  //     and in _xla_token_input_nodes.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
          {{"I"},
           "UnaryTest",
           {"outside_compilation_O2_host_compute:outputs:1"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O2_host_compute"},
           "XlaHostCompute",
           {"F:o:0", "D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O2"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph2},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O2"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node",
                                       "outside_compilation_O1_host_compute"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O2_host_compute"}},
           {"F", "outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph1},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"D"}},
      },
      {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"},
       {"i_0_retval_retval", "I:o:0"}});

  // Expected outer graph:
  //   * Both clusters' host computations hang off a single
  //     F1_key_placeholder.
  //   * All four transfers feed F1_sequencer.
  //   * J consumes the call's two outputs (G and I).
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts()
                         .WithName("E")
                         .WithControlInputs({recv1})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* send1 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));

    Node* recv2 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* g = Binary(e, ops::NodeOut(recv2, 0),
                     b2.opts()
                         .WithName("G")
                         .WithControlInputs({recv2, e})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(ops::NodeOut(recv2, 1), e,
                     b2.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* send2 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g, h},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));

    Node* s = Sequencer(b2.opts()
                            .WithName("F1_sequencer")
                            .WithControlInputs({recv1, send1, recv2, send2}),
                        "F1");

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call =
        b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);

    Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1),
           b2.opts().WithName("J"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }
  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1201 
// Test with two functions to transform, each with one outside_compilation
// cluster.
//
// Input graph:
//   * F1 holds C, D, F and the outside-compiled E (cluster O1).
//   * F2 holds G, I and the outside-compiled H (cluster O1 of F2).
//   * G consumes E and F from F1, and H consumes D, so F1 must export E, F
//     and D as results that become F2's arguments.
//   * A and B come from InputShaped. Unlike the Input-based tests, the host
//     compute nodes carry a static 'shapes' attr and an empty
//     shape_inference_graph, and no shape-inference functions are added to
//     the expected library.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = InputShaped(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
                         "_encapsulate", "F2"));
    Node* h = Binary(d, g,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F2")
                         .WithAttr("_outside", "O1"));
    Node* i =
        Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1", "F2"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Static shape expected in each host compute node's 'shapes' attr: a single
  // dimension of size 2. NOTE(review): presumably this is the shape that
  // InputShaped produces; confirm against that helper.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected body of F1:
  //   * E is replaced by an XlaHostCompute node fed by C and D.
  //   * F1 returns E (host compute output 0), F and D for use by F2.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "f_0_retval_retval:float",
       "d_0_retval_retval:float"},
      {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"D"}},
      },
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"d_0_retval_retval", "D:o:0"},
       {"f_0_retval_retval", "F:o:0"}});

  // Expected body of F2:
  //   * It takes E, F and D as arguments.
  //   * H is replaced by an XlaHostCompute node fed by D and G, whose output
  //     feeds I.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"},
      {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
      {
          {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
          {{"I"},
           "BinaryTest",
           {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"d_0_arg", "G:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F2_F2_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}});

  // Expected outer graph:
  //   * Each function gets its own key placeholder, host transfers and
  //     sequencer.
  //   * The F2 call consumes F1's three outputs and has a control input on the
  //     F1 call as well as on F2_sequencer.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = InputShaped(b2.opts().WithName("B"));

    // Host side of F1/O1 plus the F1 call node.
    Node* key_constant1 =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(
        ops::NodeOut(key_constant1, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts()
                         .WithName("E")
                         .WithControlInputs({recv1})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* send1 =
        SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));
    Node* s1 = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
        "F1");

    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 =
        b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);

    // Host side of F2/O1 plus the F2 call node.
    Node* key_constant2 =
        KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
    Node* recv2 = RecvAtHost(
        ops::NodeOut(key_constant2, 0), "F2", "F2", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* h = Binary(recv2, ops::NodeOut(recv2, 1),
                     b2.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F2")
                         .WithAttr("_outside", "O1"));
    Node* send2 =
        SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1", {h},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));

    Node* s2 = Sequencer(
        b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
        "F2");
    NodeBuilder node_builder2("F2", "F2", lib_def.get());
    node_builder2.Input(call1)
        .Input(ops::NodeOut(call1, 1))
        .Input(ops::NodeOut(call1, 2));
    Node* call2 = b2.opts()
                      .WithControlInputs({s2, call1})
                      .FinalizeBuilder(&node_builder2);
    Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1377 
1378 // Test with two functions to transform, each with one outside_compilation
1379 // cluster, with the dependency between them purely from an outside_compilation
1380 // edge.
TEST(EncapsulateSubgraphsTest,TwoFunctionsTwoOutsideDependencyFromOutside)1381 TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
1382   FunctionDefLibrary library;
1383   GraphDef graphdef;
1384 
1385   {
1386     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
1387     Node* a = InputShaped(b1.opts().WithName("A"));
1388     Node* b = InputShaped(b1.opts().WithName("B"));
1389     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
1390     Node* d =
1391         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
1392     Node* e = Binary(c, d,
1393                      b1.opts()
1394                          .WithName("E")
1395                          .WithControlInputs({b, d})
1396                          .WithAttr("_encapsulate", "F1")
1397                          .WithAttr("_outside", "O1"));
1398     Node* f = Binary(c, e,
1399                      b1.opts().WithName("F").WithControlInput(e).WithAttr(
1400                          "_encapsulate", "F1"));
1401     Node* g =
1402         Binary(a, b, b1.opts().WithName("G").WithAttr("_encapsulate", "F2"));
1403     Node* h = Unary(g, b1.opts()
1404                            .WithName("H")
1405                            .WithAttr("_encapsulate", "F2")
1406                            .WithAttr("_outside", "O1"));
1407     Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
1408     Binary(f, i, b1.opts().WithName("J"));
1409     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
1410   }
1411 
1412   std::vector<string> encapsulated_functions{"F1", "F2"};
1413   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
1414 
1415   FunctionDefLibrary library_expected;
1416   GraphDef graphdef_expected;
1417   TensorShapeProto shape_proto_expected;
1418   shape_proto_expected.add_dim()->set_size(2);
1419 
1420   *library_expected.add_function() = FunctionDefHelper::Create(
1421       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
1422       {
1423           {{"C"}, "UnaryTest", {"a_0_arg"}},
1424           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
1425           {{"F"},
1426            "BinaryTest",
1427            {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
1428            {},
1429            {"outside_compilation_O1_host_compute"}},
1430           {{"outside_compilation_O1_host_compute"},
1431            "XlaHostCompute",
1432            {"C:o:0", "D:o:0"},
1433            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
1434             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
1435             {"ancestors", absl::Span<const string>({})},
1436             {"key", "host_compute_channel_F1_F1_O1"},
1437             {"send_key", ""},
1438             {"recv_key", ""},
1439             {"shape_inference_graph", NameAttrList()},
1440             {"tpu_core", 0},
1441             {"cost_estimate_ns", 1000000},
1442             {"shapes",
1443              absl::Span<const TensorShapeProto>({shape_proto_expected})},
1444             {"_outside_compilation_subgraph", "O1"},
1445             {"_xla_token_input_nodes",
1446              absl::Span<const string>({"_xla_token_arg_node"})},
1447             {"_xla_original_oc_node_name",
1448              "outside_compilation_O1_host_compute"}},
1449            {"D"}},
1450       },
1451       {{"f_0_retval_retval", "F:o:0"}});
1452 
1453   *library_expected.add_function() = FunctionDefHelper::Create(
1454       "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {},
1455       {
1456           {{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}},
1457           {{"I"},
1458            "UnaryTest",
1459            {"outside_compilation_O1_host_compute:outputs:0"}},
1460           {{"outside_compilation_O1_host_compute"},
1461            "XlaHostCompute",
1462            {"G:o:0"},
1463            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
1464             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
1465             {"ancestors", absl::Span<const string>({})},
1466             {"key", "host_compute_channel_F2_F2_O1"},
1467             {"send_key", ""},
1468             {"recv_key", ""},
1469             {"shape_inference_graph", NameAttrList()},
1470             {"tpu_core", 0},
1471             {"cost_estimate_ns", 1000000},
1472             {"shapes",
1473              absl::Span<const TensorShapeProto>({shape_proto_expected})},
1474             {"_outside_compilation_subgraph", "O1"},
1475             {"_xla_token_input_nodes",
1476              absl::Span<const string>({"_xla_token_arg_node"})},
1477             {"_xla_original_oc_node_name",
1478              "outside_compilation_O1_host_compute"}}},
1479       },
1480       {{"i_0_retval_retval", "I:o:0"}});
1481 
1482   {
1483     std::unique_ptr<FunctionLibraryDefinition> lib_def(
1484         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
1485     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
1486     Node* a = InputShaped(b2.opts().WithName("A"));
1487     Node* b = InputShaped(b2.opts().WithName("B"));
1488 
1489     Node* key_constant1 =
1490         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
1491     Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1",
1492                              {DT_FLOAT, DT_FLOAT}, b2.opts());
1493     Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
1494                      b2.opts()
1495                          .WithName("E")
1496                          .WithControlInputs({recv1})
1497                          .WithAttr("_encapsulate", "F1")
1498                          .WithAttr("_outside", "O1"));
1499     Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1",
1500                                {e}, b2.opts().WithControlInput(e));
1501     Node* s1 = Sequencer(
1502         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
1503         "F1");
1504 
1505     NodeBuilder node_builder1("F1", "F1", lib_def.get());
1506     node_builder1.Input(a).Input(b);
1507     Node* call1 =
1508         b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);
1509 
1510     Node* key_constant2 =
1511         KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
1512     Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1",
1513                              {DT_FLOAT}, b2.opts());
1514     Node* h = Unary(recv2, b2.opts()
1515                                .WithName("H")
1516                                .WithAttr("_encapsulate", "F2")
1517                                .WithAttr("_outside", "O1"));
1518     Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1",
1519                                {h}, b2.opts());
1520 
1521     Node* s2 = Sequencer(
1522         b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
1523         "F2");
1524     NodeBuilder node_builder2("F2", "F2", lib_def.get());
1525     node_builder2.Input(a).Input(b);
1526     Node* call2 =
1527         b2.opts().WithControlInputs({s2}).FinalizeBuilder(&node_builder2);
1528     Binary(call1, call2, b2.opts().WithName("J"));
1529     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
1530   }
1531 
1532   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
1533   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
1534 }
1535 
1536 // Test with one outside_compilation cluster that has no inputs from the
1537 // compiled subgraph.
TEST(EncapsulateSubgraphsTest,OutsideCompilationNoInputs)1538 TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
1539   FunctionDefLibrary library;
1540   GraphDef graphdef;
1541 
1542   {
1543     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
1544     Node* a = InputShaped(b1.opts().WithName("A"));
1545     Node* b = Input(b1.opts().WithName("B"));
1546     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
1547     Node* d =
1548         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
1549     Node* e = Unary(a, b1.opts()
1550                            .WithName("E")
1551                            .WithAttr("_encapsulate", "F1")
1552                            .WithAttr("_outside", "O1"));
1553     Node* f =
1554         Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
1555     Unary(f, b1.opts().WithName("G"));
1556     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
1557   }
1558 
1559   std::vector<string> encapsulated_functions{"F1"};
1560   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
1561 
1562   FunctionDefLibrary library_expected;
1563   GraphDef graphdef_expected;
1564 
1565   TensorShapeProto shape_proto_expected;
1566   shape_proto_expected.add_dim()->set_size(2);
1567 
1568   *library_expected.add_function() = FunctionDefHelper::Create(
1569       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
1570       {
1571           {{"C"}, "UnaryTest", {"a_0_arg"}},
1572           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
1573           {{"F"},
1574            "BinaryTest",
1575            {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
1576           {{"outside_compilation_O1_host_compute"},
1577            "XlaHostCompute",
1578            {"a_0_arg"},
1579            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
1580             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
1581             {"ancestors", absl::Span<const string>({})},
1582             {"key", "host_compute_channel_F1_F1_O1"},
1583             {"send_key", ""},
1584             {"recv_key", ""},
1585             {"shape_inference_graph", NameAttrList()},
1586             {"tpu_core", 0},
1587             {"cost_estimate_ns", 1000000},
1588             {"shapes",
1589              absl::Span<const TensorShapeProto>({shape_proto_expected})},
1590             {"_outside_compilation_subgraph", "O1"},
1591             {"_xla_token_input_nodes",
1592              absl::Span<const string>({"_xla_token_arg_node"})},
1593             {"_xla_original_oc_node_name",
1594              "outside_compilation_O1_host_compute"}}},
1595       },
1596       {{"f_0_retval_retval", "F:o:0"}});
1597 
1598   {
1599     std::unique_ptr<FunctionLibraryDefinition> lib_def(
1600         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
1601     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
1602     Node* a = InputShaped(b2.opts().WithName("A"));
1603     Node* b = Input(b2.opts().WithName("B"));
1604 
1605     Node* key_constant =
1606         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
1607     Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1608                              {DT_FLOAT}, b2.opts());
1609     Node* e = Unary(recv1, b2.opts()
1610                                .WithName("E")
1611                                .WithAttr("_encapsulate", "F1")
1612                                .WithAttr("_outside", "O1"));
1613     Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1614                                {e}, b2.opts());
1615     Node* s1 = Sequencer(
1616         b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}),
1617         "F1");
1618     NodeBuilder node_builder1("F1", "F1", lib_def.get());
1619     node_builder1.Input(a).Input(b);
1620     Node* call1 =
1621         b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
1622 
1623     Unary(call1, b2.opts().WithName("G"));
1624     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
1625   }
1626 
1627   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
1628   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
1629 }
1630 
1631 // Test with one outside_compilation cluster that has no data inputs but has a
1632 // control input from the compiled subgraph.
TEST(EncapsulateSubgraphsTest,OutsideCompilationControlInput)1633 TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
1634   FunctionDefLibrary library;
1635   GraphDef graphdef;
1636 
1637   {
1638     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
1639     Node* a = InputShaped(b1.opts().WithName("A"));
1640     Node* b = Input(b1.opts().WithName("B"));
1641     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
1642     Node* d =
1643         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
1644     Node* e = Unary(a, b1.opts()
1645                            .WithName("E")
1646                            .WithControlInput(d)
1647                            .WithAttr("_encapsulate", "F1")
1648                            .WithAttr("_outside", "O1"));
1649     Node* f =
1650         Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
1651     Unary(f, b1.opts().WithName("G"));
1652     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
1653   }
1654 
1655   std::vector<string> encapsulated_functions{"F1"};
1656   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
1657 
1658   FunctionDefLibrary library_expected;
1659   GraphDef graphdef_expected;
1660 
1661   TensorShapeProto shape_proto_expected;
1662   shape_proto_expected.add_dim()->set_size(2);
1663 
1664   *library_expected.add_function() = FunctionDefHelper::Create(
1665       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
1666       {
1667           {{"C"}, "UnaryTest", {"a_0_arg"}},
1668           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
1669           {{"F"},
1670            "BinaryTest",
1671            {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
1672           {{"outside_compilation_O1_host_compute"},
1673            "XlaHostCompute",
1674            {"a_0_arg"},
1675            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
1676             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
1677             {"ancestors", absl::Span<const string>({})},
1678             {"key", "host_compute_channel_F1_F1_O1"},
1679             {"send_key", ""},
1680             {"recv_key", ""},
1681             {"shape_inference_graph", NameAttrList()},
1682             {"tpu_core", 0},
1683             {"cost_estimate_ns", 1000000},
1684             {"shapes",
1685              absl::Span<const TensorShapeProto>({shape_proto_expected})},
1686             {"_outside_compilation_subgraph", "O1"},
1687             {"_xla_token_input_nodes",
1688              absl::Span<const string>({"_xla_token_arg_node"})},
1689             {"_xla_original_oc_node_name",
1690              "outside_compilation_O1_host_compute"}},
1691            {"D"}},
1692       },
1693       {{"f_0_retval_retval", "F:o:0"}});
1694 
1695   {
1696     std::unique_ptr<FunctionLibraryDefinition> lib_def(
1697         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
1698     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
1699     Node* a = InputShaped(b2.opts().WithName("A"));
1700     Node* b = Input(b2.opts().WithName("B"));
1701 
1702     Node* key_constant =
1703         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
1704     Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1705                              {DT_FLOAT}, b2.opts());
1706     Node* e = Unary(recv1, b2.opts()
1707                                .WithName("E")
1708                                .WithControlInput(recv1)
1709                                .WithAttr("_encapsulate", "F1")
1710                                .WithAttr("_outside", "O1"));
1711     Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1712                                {e}, b2.opts());
1713     Node* s1 = Sequencer(
1714         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
1715         "F1");
1716     NodeBuilder node_builder1("F1", "F1", lib_def.get());
1717     node_builder1.Input(a).Input(b);
1718     Node* call1 =
1719         b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
1720 
1721     Unary(call1, b2.opts().WithName("G"));
1722     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
1723   }
1724 
1725   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
1726   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
1727 }
1728 
1729 // Test with one outside_compilation cluster that has no outputs from the
1730 // compiled subgraph.
TEST(EncapsulateSubgraphsTest,OutsideCompilationNoOutputs)1731 TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
1732   FunctionDefLibrary library;
1733   GraphDef graphdef;
1734 
1735   {
1736     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
1737     Node* a = Input(b1.opts().WithName("A"));
1738     Node* b = Input(b1.opts().WithName("B"));
1739     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
1740     Node* d =
1741         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
1742     Node* e = Unary(d, b1.opts()
1743                            .WithName("E")
1744                            .WithAttr("_encapsulate", "F1")
1745                            .WithAttr("_outside", "O1"));
1746     Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
1747     Binary(e, f, b1.opts().WithName("G"));
1748     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
1749   }
1750 
1751   std::vector<string> encapsulated_functions{"F1"};
1752   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
1753 
1754   FunctionDefLibrary library_expected;
1755   GraphDef graphdef_expected;
1756 
1757   {
1758     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
1759     Node* key_constant = KeyPlaceholder("F1", shape1.opts());
1760     Node* recv1 =
1761         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
1762                    shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
1763     Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
1764                                                 .WithName("E")
1765                                                 .WithAttr("_encapsulate", "F1")
1766                                                 .WithAttr("_outside", "O1"));
1767     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
1768                  shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
1769     TF_EXPECT_OK(
1770         AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
1771   }
1772 
1773   NameAttrList shape_inference_graph;
1774   shape_inference_graph.set_name(
1775       "_outside_compilation_shape_inference_F1_F1_O1");
1776   *library_expected.add_function() = FunctionDefHelper::Create(
1777       "F1", {"a_0_arg:float", "b_0_arg:float"},
1778       {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
1779       {
1780           {{"C"}, "UnaryTest", {"a_0_arg"}},
1781           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
1782           {{"F"}, "UnaryTest", {"D:o:0"}},
1783           {{"outside_compilation_O1_host_compute"},
1784            "XlaHostCompute",
1785            {"D:o:0"},
1786            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
1787             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
1788             {"ancestors", absl::Span<const string>({})},
1789             {"key", "host_compute_channel_F1_F1_O1"},
1790             {"send_key", ""},
1791             {"recv_key", ""},
1792             {"shape_inference_graph", shape_inference_graph},
1793             {"tpu_core", 0},
1794             {"cost_estimate_ns", 1000000},
1795             {"shapes", absl::Span<const TensorShapeProto>({})},
1796             {"_outside_compilation_subgraph", "O1"},
1797             {"_xla_token_input_nodes",
1798              absl::Span<const string>({"_xla_token_arg_node"})},
1799             {"_xla_original_oc_node_name",
1800              "outside_compilation_O1_host_compute"}}},
1801       },
1802       {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
1803        {"f_0_retval_retval", "F:o:0"}});
1804 
1805   {
1806     std::unique_ptr<FunctionLibraryDefinition> lib_def(
1807         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
1808     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
1809     Node* a = Input(b2.opts().WithName("A"));
1810     Node* b = Input(b2.opts().WithName("B"));
1811 
1812     Node* key_constant =
1813         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
1814     Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1815                              {DT_FLOAT}, b2.opts());
1816     Node* e = Unary(recv1, b2.opts()
1817                                .WithName("E")
1818                                .WithAttr("_encapsulate", "F1")
1819                                .WithAttr("_outside", "O1"));
1820     Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1821                                {e}, b2.opts());
1822     Node* s1 = Sequencer(
1823         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
1824         "F1");
1825     NodeBuilder node_builder1("F1", "F1", lib_def.get());
1826     node_builder1.Input(a).Input(b);
1827     Node* call1 =
1828         b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
1829 
1830     Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
1831     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
1832   }
1833 
1834   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
1835   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
1836 }
1837 
1838 // Test with one outside_compilation cluster that has no data outputs but has a
1839 // control output to the compiled subgraph.
TEST(EncapsulateSubgraphsTest,OutsideCompilationControlOutput)1840 TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
1841   FunctionDefLibrary library;
1842   GraphDef graphdef;
1843 
1844   {
1845     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
1846     Node* a = Input(b1.opts().WithName("A"));
1847     Node* b = Input(b1.opts().WithName("B"));
1848     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
1849     Node* d =
1850         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
1851     Node* e = Unary(d, b1.opts()
1852                            .WithName("E")
1853                            .WithAttr("_encapsulate", "F1")
1854                            .WithAttr("_outside", "O1"));
1855     Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
1856                            "_encapsulate", "F1"));
1857     Binary(e, f, b1.opts().WithName("G"));
1858     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
1859   }
1860 
1861   std::vector<string> encapsulated_functions{"F1"};
1862   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
1863 
1864   FunctionDefLibrary library_expected;
1865   GraphDef graphdef_expected;
1866 
1867   {
1868     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
1869     Node* key_constant = KeyPlaceholder("F1", shape1.opts());
1870     Node* recv1 =
1871         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
1872                    shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
1873     Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
1874                                                 .WithName("E")
1875                                                 .WithAttr("_encapsulate", "F1")
1876                                                 .WithAttr("_outside", "O1"));
1877     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
1878                  shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
1879     TF_EXPECT_OK(
1880         AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
1881   }
1882 
1883   NameAttrList shape_inference_graph;
1884   shape_inference_graph.set_name(
1885       "_outside_compilation_shape_inference_F1_F1_O1");
1886   *library_expected.add_function() = FunctionDefHelper::Create(
1887       "F1", {"a_0_arg:float", "b_0_arg:float"},
1888       {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
1889       {
1890           {{"C"}, "UnaryTest", {"a_0_arg"}},
1891           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
1892           {{"F"},
1893            "UnaryTest",
1894            {"D:o:0"},
1895            {},
1896            {"outside_compilation_O1_host_compute"}},
1897           {{"outside_compilation_O1_host_compute"},
1898            "XlaHostCompute",
1899            {"D:o:0"},
1900            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
1901             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
1902             {"ancestors", absl::Span<const string>({})},
1903             {"key", "host_compute_channel_F1_F1_O1"},
1904             {"send_key", ""},
1905             {"recv_key", ""},
1906             {"shape_inference_graph", shape_inference_graph},
1907             {"tpu_core", 0},
1908             {"cost_estimate_ns", 1000000},
1909             {"shapes", absl::Span<const TensorShapeProto>({})},
1910             {"_outside_compilation_subgraph", "O1"},
1911             {"_xla_token_input_nodes",
1912              absl::Span<const string>({"_xla_token_arg_node"})},
1913             {"_xla_original_oc_node_name",
1914              "outside_compilation_O1_host_compute"}}},
1915       },
1916       {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
1917        {"f_0_retval_retval", "F:o:0"}});
1918 
1919   {
1920     std::unique_ptr<FunctionLibraryDefinition> lib_def(
1921         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
1922     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
1923     Node* a = Input(b2.opts().WithName("A"));
1924     Node* b = Input(b2.opts().WithName("B"));
1925 
1926     Node* key_constant =
1927         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
1928     Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1929                              {DT_FLOAT}, b2.opts());
1930     Node* e = Unary(recv1, b2.opts()
1931                                .WithName("E")
1932                                .WithAttr("_encapsulate", "F1")
1933                                .WithAttr("_outside", "O1"));
1934     Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
1935                                {e}, b2.opts().WithControlInput(e));
1936     Node* s1 = Sequencer(
1937         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
1938         "F1");
1939     NodeBuilder node_builder1("F1", "F1", lib_def.get());
1940     node_builder1.Input(a).Input(b);
1941     Node* call1 =
1942         b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
1943 
1944     Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
1945     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
1946   }
1947 
1948   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
1949   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
1950 }
1951 
1952 // Test with two outside_compilation clusters that interact outside the compiled
1953 // subgraph, where the ancestor has no HostCompute Op.
TEST(EncapsulateSubgraphsTest,OutsideCompilationClusterDependencyNoSrcCluster)1954 TEST(EncapsulateSubgraphsTest,
1955      OutsideCompilationClusterDependencyNoSrcCluster) {
1956   FunctionDefLibrary library;
1957   GraphDef graphdef;
1958 
1959   {
1960     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
1961     Node* a = Input(b1.opts().WithName("A"));
1962     Node* b = Input(b1.opts().WithName("B"));
1963     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
1964     Node* d =
1965         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
1966     Node* e = Unary(a, b1.opts()
1967                            .WithName("E")
1968                            .WithAttr("_encapsulate", "F1")
1969                            .WithAttr("_outside", "O1"));
1970     Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
1971     Node* g = Unary(f, b1.opts()
1972                            .WithName("G")
1973                            .WithAttr("_encapsulate", "F1")
1974                            .WithAttr("_outside", "O2")
1975                            .WithControlInput(e));
1976     Node* h = Unary(g, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
1977     Binary(e, h, b1.opts().WithName("I"));
1978     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
1979   }
1980 
1981   std::vector<string> encapsulated_functions{"F1"};
1982   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
1983 
1984   FunctionDefLibrary library_expected;
1985   GraphDef graphdef_expected;
1986 
1987   {
1988     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
1989     Node* key_constant = KeyPlaceholder("F1", shape1.opts());
1990     Node* recv1 =
1991         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
1992                    shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
1993     Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
1994                                                 .WithName("E")
1995                                                 .WithAttr("_encapsulate", "F1")
1996                                                 .WithAttr("_outside", "O1"));
1997     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
1998                  shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
1999     TF_EXPECT_OK(
2000         AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
2001   }
2002 
2003   {
2004     GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
2005     Node* key_constant = KeyPlaceholder("F1", shape2.opts());
2006     Node* recv2 =
2007         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
2008                    shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2009     Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts()
2010                                                 .WithName("G")
2011                                                 .WithAttr("_encapsulate", "F1")
2012                                                 .WithAttr("_outside", "O2"));
2013     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g},
2014                  shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2015     TF_EXPECT_OK(
2016         AddGraphDefToFunctionLibrary(shape2, "F1_F1_O2", &library_expected));
2017   }
2018 
2019   NameAttrList shape_inference_graph1;
2020   shape_inference_graph1.set_name(
2021       "_outside_compilation_shape_inference_F1_F1_O1");
2022   NameAttrList shape_inference_graph2;
2023   shape_inference_graph2.set_name(
2024       "_outside_compilation_shape_inference_F1_F1_O2");
2025   *library_expected.add_function() = FunctionDefHelper::Create(
2026       "F1", {"a_0_arg:float", "b_0_arg:float"},
2027       {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
2028       {
2029           {{"C"}, "UnaryTest", {"a_0_arg"}},
2030           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
2031           {{"F"}, "UnaryTest", {"D:o:0"}},
2032           {{"H"},
2033            "UnaryTest",
2034            {"outside_compilation_O2_host_compute:outputs:0"}},
2035           {{"outside_compilation_O1_host_compute"},
2036            "XlaHostCompute",
2037            {"a_0_arg"},
2038            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2039             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
2040             {"ancestors", absl::Span<const string>({})},
2041             {"key", "host_compute_channel_F1_F1_O1"},
2042             {"send_key", ""},
2043             {"recv_key", ""},
2044             {"shape_inference_graph", shape_inference_graph1},
2045             {"tpu_core", 0},
2046             {"cost_estimate_ns", 1000000},
2047             {"shapes", absl::Span<const TensorShapeProto>({})},
2048             {"_outside_compilation_subgraph", "O1"},
2049             {"_xla_token_input_nodes",
2050              absl::Span<const string>({"_xla_token_arg_node"})},
2051             {"_xla_original_oc_node_name",
2052              "outside_compilation_O1_host_compute"}}},
2053           {{"outside_compilation_O2_host_compute"},
2054            "XlaHostCompute",
2055            {"F:o:0"},
2056            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2057             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
2058             {"ancestors", absl::Span<const string>({})},
2059             {"key", "host_compute_channel_F1_F1_O2"},
2060             {"send_key", ""},
2061             {"recv_key", ""},
2062             {"shape_inference_graph", shape_inference_graph2},
2063             {"tpu_core", 0},
2064             {"cost_estimate_ns", 1000000},
2065             {"shapes", absl::Span<const TensorShapeProto>({})},
2066             {"_outside_compilation_subgraph", "O2"},
2067             {"_xla_token_input_nodes",
2068              absl::Span<const string>({"_xla_token_arg_node",
2069                                        "outside_compilation_O1_host_compute"})},
2070             {"_xla_original_oc_node_name",
2071              "outside_compilation_O2_host_compute"}},
2072            {"outside_compilation_O1_host_compute"}},
2073       },
2074       {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
2075        {"h_0_retval_retval", "H:o:0"}});
2076 
2077   {
2078     std::unique_ptr<FunctionLibraryDefinition> lib_def(
2079         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
2080     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
2081     Node* a = Input(b2.opts().WithName("A"));
2082     Node* b = Input(b2.opts().WithName("B"));
2083     Node* key_constant =
2084         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
2085     Node* recv1 =
2086         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
2087                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2088 
2089     Node* e = Unary(recv1, b2.opts()
2090                                .WithName("E")
2091                                .WithAttr("_encapsulate", "F1")
2092                                .WithAttr("_outside", "O1"));
2093     Node* send1 =
2094         SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2095                      b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2096     Node* recv2 =
2097         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
2098                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2099     Node* g = Unary(recv2, b2.opts()
2100                                .WithName("G")
2101                                .WithAttr("_encapsulate", "F1")
2102                                .WithAttr("_outside", "O2")
2103                                .WithControlInput(e));
2104     Node* send2 =
2105         SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g},
2106                      b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2107     Node* s1 = Sequencer(b2.opts()
2108                              .WithName("F1_sequencer")
2109                              .WithControlInputs({recv1, send1, recv2, send2}),
2110                          "F1");
2111     NodeBuilder node_builder1("F1", "F1", lib_def.get());
2112     node_builder1.Input(a).Input(b).ControlInput(s1);
2113     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
2114 
2115     Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
2116     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
2117   }
2118 
2119   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
2120   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
2121 }
2122 
2123 // Test with two outside_compilation clusters that interact outside the compiled
2124 // subgraph, where the successor has no HostCompute Op.
TEST(EncapsulateSubgraphsTest,OutsideCompilationClusterDependencyNoDstCluster)2125 TEST(EncapsulateSubgraphsTest,
2126      OutsideCompilationClusterDependencyNoDstCluster) {
2127   FunctionDefLibrary library;
2128   GraphDef graphdef;
2129 
2130   {
2131     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
2132     Node* a = Input(b1.opts().WithName("A"));
2133     Node* b = Input(b1.opts().WithName("B"));
2134     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
2135     Node* d =
2136         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
2137     Node* e = Unary(d, b1.opts()
2138                            .WithName("E")
2139                            .WithAttr("_encapsulate", "F1")
2140                            .WithAttr("_outside", "O1"));
2141     Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
2142     /*Node* g =*/Unary(a, b1.opts()
2143                               .WithName("G")
2144                               .WithAttr("_encapsulate", "F1")
2145                               .WithAttr("_outside", "O2")
2146                               .WithControlInput(e));
2147     Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
2148     Binary(e, h, b1.opts().WithName("I"));
2149     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
2150   }
2151 
2152   std::vector<string> encapsulated_functions{"F1"};
2153   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
2154 
2155   FunctionDefLibrary library_expected;
2156   GraphDef graphdef_expected;
2157 
2158   {
2159     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
2160     Node* key_constant = KeyPlaceholder("F1", shape1.opts());
2161     Node* recv2 =
2162         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
2163                    shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2164     Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
2165                                                 .WithName("E")
2166                                                 .WithAttr("_encapsulate", "F1")
2167                                                 .WithAttr("_outside", "O1"));
2168     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2169                  shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2170     TF_EXPECT_OK(
2171         AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
2172   }
2173 
2174   NameAttrList shape_inference_graph;
2175   shape_inference_graph.set_name(
2176       "_outside_compilation_shape_inference_F1_F1_O1");
2177   *library_expected.add_function() = FunctionDefHelper::Create(
2178       "F1", {"a_0_arg:float", "b_0_arg:float"},
2179       {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
2180       {
2181           {{"C"}, "UnaryTest", {"a_0_arg"}},
2182           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
2183           {{"F"},
2184            "UnaryTest",
2185            {"outside_compilation_O1_host_compute:outputs:0"}},
2186           {{"H"}, "UnaryTest", {"F:o:0"}},
2187           {{"outside_compilation_O2_host_compute"},
2188            "XlaHostCompute",
2189            {"a_0_arg"},
2190            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2191             {"Toutputs", absl::Span<const DataType>({})},
2192             {"ancestors", absl::Span<const string>({})},
2193             {"key", "host_compute_channel_F1_F1_O2"},
2194             {"send_key", ""},
2195             {"recv_key", ""},
2196             {"shape_inference_graph", NameAttrList()},
2197             {"tpu_core", 0},
2198             {"cost_estimate_ns", 1000000},
2199             {"shapes", absl::Span<const TensorShapeProto>({})},
2200             {"_outside_compilation_subgraph", "O2"},
2201             {"_xla_token_input_nodes",
2202              absl::Span<const string>({"_xla_token_arg_node",
2203                                        "outside_compilation_O1_host_compute"})},
2204             {"_xla_original_oc_node_name",
2205              "outside_compilation_O2_host_compute"}},
2206            {"outside_compilation_O1_host_compute"}},
2207           {{"outside_compilation_O1_host_compute"},
2208            "XlaHostCompute",
2209            {"D:o:0"},
2210            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2211             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
2212             {"ancestors", absl::Span<const string>({})},
2213             {"key", "host_compute_channel_F1_F1_O1"},
2214             {"send_key", ""},
2215             {"recv_key", ""},
2216             {"shape_inference_graph", shape_inference_graph},
2217             {"tpu_core", 0},
2218             {"cost_estimate_ns", 1000000},
2219             {"shapes", absl::Span<const TensorShapeProto>({})},
2220             {"_outside_compilation_subgraph", "O1"},
2221             {"_xla_token_input_nodes",
2222              absl::Span<const string>({"_xla_token_arg_node"})},
2223             {"_xla_original_oc_node_name",
2224              "outside_compilation_O1_host_compute"}}},
2225       },
2226       {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
2227        {"h_0_retval_retval", "H:o:0"}});
2228 
2229   {
2230     std::unique_ptr<FunctionLibraryDefinition> lib_def(
2231         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
2232     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
2233     Node* a = Input(b2.opts().WithName("A"));
2234     Node* b = Input(b2.opts().WithName("B"));
2235 
2236     Node* key_constant =
2237         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
2238     Node* recv1 =
2239         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
2240                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2241     Node* e = Unary(recv1, b2.opts()
2242                                .WithName("E")
2243                                .WithAttr("_encapsulate", "F1")
2244                                .WithAttr("_outside", "O1"));
2245     Node* send =
2246         SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2247                      b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2248     Node* recv2 =
2249         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
2250                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2251     /*Node* g =*/Unary(recv2, b2.opts()
2252                                   .WithName("G")
2253                                   .WithAttr("_encapsulate", "F1")
2254                                   .WithAttr("_outside", "O2")
2255                                   .WithControlInput(e));
2256     Node* s1 = Sequencer(b2.opts()
2257                              .WithName("F1_sequencer")
2258                              .WithControlInputs({recv1, recv2, send}),
2259                          "F1");
2260     NodeBuilder node_builder1("F1", "F1", lib_def.get());
2261     node_builder1.Input(a).Input(b).ControlInput(s1);
2262     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
2263 
2264     Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
2265     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
2266   }
2267 
2268   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
2269   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
2270 }
2271 
2272 // Test with two outside_compilation clusters that interact outside the compiled
2273 // subgraph.
TEST(EncapsulateSubgraphsTest,OutsideCompilationClusterDependency)2274 TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
2275   FunctionDefLibrary library;
2276   GraphDef graphdef;
2277 
2278   {
2279     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
2280     Node* a = Input(b1.opts().WithName("A"));
2281     Node* b = Input(b1.opts().WithName("B"));
2282     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
2283     Node* d =
2284         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
2285     Node* e = Unary(d, b1.opts()
2286                            .WithName("E")
2287                            .WithAttr("_encapsulate", "F1")
2288                            .WithAttr("_outside", "O1"));
2289     Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
2290     Node* g = Unary(d, b1.opts()
2291                            .WithName("G")
2292                            .WithAttr("_encapsulate", "F1")
2293                            .WithAttr("_outside", "O2")
2294                            .WithControlInput(e));
2295     Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
2296     /*Node* i =*/Binary(d, e,
2297                         b1.opts()
2298                             .WithName("I")
2299                             .WithAttr("_encapsulate", "F1")
2300                             .WithAttr("_outside", "O3")
2301                             .WithControlInput(g));
2302     Binary(e, h, b1.opts().WithName("J"));
2303     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
2304   }
2305 
2306   std::vector<string> encapsulated_functions{"F1"};
2307   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
2308 
2309   FunctionDefLibrary library_expected;
2310   GraphDef graphdef_expected;
2311 
2312   {
2313     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
2314     Node* key_constant = KeyPlaceholder("F1", shape1.opts());
2315     Node* recv2 =
2316         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
2317                    shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2318     Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
2319                                                 .WithName("E")
2320                                                 .WithAttr("_encapsulate", "F1")
2321                                                 .WithAttr("_outside", "O1"));
2322     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2323                  shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2324     TF_EXPECT_OK(
2325         AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
2326   }
2327 
2328   NameAttrList shape_inference_graph;
2329   shape_inference_graph.set_name(
2330       "_outside_compilation_shape_inference_F1_F1_O1");
2331   *library_expected.add_function() = FunctionDefHelper::Create(
2332       "F1", {"a_0_arg:float", "b_0_arg:float"},
2333       {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
2334       {{{"C"}, "UnaryTest", {"a_0_arg"}},
2335        {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
2336        {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}},
2337        {{"H"}, "UnaryTest", {"F:o:0"}},
2338        {{"outside_compilation_O1_host_compute"},
2339         "XlaHostCompute",
2340         {"D:o:0"},
2341         {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2342          {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
2343          {"ancestors", absl::Span<const string>({})},
2344          {"key", "host_compute_channel_F1_F1_O1"},
2345          {"send_key", ""},
2346          {"recv_key", ""},
2347          {"shape_inference_graph", shape_inference_graph},
2348          {"tpu_core", 0},
2349          {"cost_estimate_ns", 1000000},
2350          {"shapes", absl::Span<const TensorShapeProto>({})},
2351          {"_outside_compilation_subgraph", "O1"},
2352          {"_xla_token_input_nodes",
2353           absl::Span<const string>({"_xla_token_arg_node"})},
2354          {"_xla_original_oc_node_name",
2355           "outside_compilation_O1_host_compute"}}},
2356        {{"outside_compilation_O2_host_compute"},
2357         "XlaHostCompute",
2358         {"D:o:0"},
2359         {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2360          {"Toutputs", absl::Span<const DataType>({})},
2361          {"ancestors", absl::Span<const string>({})},
2362          {"key", "host_compute_channel_F1_F1_O2"},
2363          {"send_key", ""},
2364          {"recv_key", ""},
2365          {"shape_inference_graph", NameAttrList()},
2366          {"tpu_core", 0},
2367          {"cost_estimate_ns", 1000000},
2368          {"shapes", absl::Span<const TensorShapeProto>({})},
2369          {"_outside_compilation_subgraph", "O2"},
2370          {"_xla_token_input_nodes",
2371           absl::Span<const string>(
2372               {"_xla_token_arg_node", "outside_compilation_O1_host_compute"})},
2373          {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}},
2374         {"outside_compilation_O1_host_compute"}},
2375        {{"outside_compilation_O3_host_compute"},
2376         "XlaHostCompute",
2377         {"D:o:0"},
2378         {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2379          {"Toutputs", absl::Span<const DataType>({})},
2380          {"ancestors", absl::Span<const string>({})},
2381          {"key", "host_compute_channel_F1_F1_O3"},
2382          {"send_key", ""},
2383          {"recv_key", ""},
2384          {"shape_inference_graph", NameAttrList()},
2385          {"tpu_core", 0},
2386          {"cost_estimate_ns", 1000000},
2387          {"shapes", absl::Span<const TensorShapeProto>({})},
2388          {"_outside_compilation_subgraph", "O3"},
2389          {"_xla_token_input_nodes",
2390           absl::Span<const string>({"_xla_token_arg_node",
2391                                     "outside_compilation_O1_host_compute",
2392                                     "outside_compilation_O2_host_compute"})},
2393          {"_xla_original_oc_node_name", "outside_compilation_O3_host_compute"}},
2394         {"outside_compilation_O1_host_compute",
2395          "outside_compilation_O2_host_compute"}}},
2396       {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
2397        {"h_0_retval_retval", "H:o:0"}});
2398 
2399   {
2400     std::unique_ptr<FunctionLibraryDefinition> lib_def(
2401         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
2402     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
2403     Node* a = Input(b2.opts().WithName("A"));
2404     Node* b = Input(b2.opts().WithName("B"));
2405 
2406     Node* key_constant =
2407         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
2408     Node* recv1 =
2409         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
2410                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2411     Node* e = Unary(recv1, b2.opts()
2412                                .WithName("E")
2413                                .WithAttr("_encapsulate", "F1")
2414                                .WithAttr("_outside", "O1"));
2415     Node* send =
2416         SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2417                      b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2418     Node* recv2 =
2419         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
2420                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2421     Node* g = Unary(recv2, b2.opts()
2422                                .WithName("G")
2423                                .WithAttr("_encapsulate", "F1")
2424                                .WithAttr("_outside", "O2")
2425                                .WithControlInput(e));
2426     Node* recv3 =
2427         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O3", {DT_FLOAT},
2428                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2429     /*Node* i =*/Binary(recv3, e,
2430                         b2.opts()
2431                             .WithName("I")
2432                             .WithAttr("_encapsulate", "F1")
2433                             .WithAttr("_outside", "O3")
2434                             .WithControlInput(g));
2435     Node* s1 = Sequencer(b2.opts()
2436                              .WithName("F1_sequencer")
2437                              .WithControlInputs({recv1, send, recv2, recv3}),
2438                          "F1");
2439     NodeBuilder node_builder1("F1", "F1", lib_def.get());
2440     node_builder1.Input(a).Input(b).ControlInput(s1);
2441     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
2442 
2443     Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J"));
2444     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
2445   }
2446 
2447   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
2448   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
2449 }
2450 
2451 // Test with one outside_compilation cluster that has no outputs from the
2452 // compiled subgraph.
TEST(EncapsulateSubgraphsTest,OutsideCompilationNoInputsOrOutputs)2453 TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
2454   FunctionDefLibrary library;
2455   GraphDef graphdef;
2456 
2457   {
2458     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
2459     Node* a = Input(b1.opts().WithName("A"));
2460     Node* b = Input(b1.opts().WithName("B"));
2461     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
2462     Node* d =
2463         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
2464     Node* e = Unary(a, b1.opts()
2465                            .WithName("E")
2466                            .WithAttr("_encapsulate", "F1")
2467                            .WithAttr("_outside", "O1"));
2468     Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
2469     Binary(e, f, b1.opts().WithName("G"));
2470     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
2471   }
2472 
2473   std::vector<string> encapsulated_functions{"F1"};
2474   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
2475 
2476   FunctionDefLibrary library_expected;
2477   GraphDef graphdef_expected;
2478 
2479   {
2480     GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
2481     Node* key_constant = KeyPlaceholder("F1", shape1.opts());
2482     Node* recv2 =
2483         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
2484                    shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2485     Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
2486                                                 .WithName("E")
2487                                                 .WithAttr("_encapsulate", "F1")
2488                                                 .WithAttr("_outside", "O1"));
2489     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2490                  shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2491     TF_EXPECT_OK(
2492         AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
2493   }
2494 
2495   NameAttrList shape_inference_graph;
2496   shape_inference_graph.set_name(
2497       "_outside_compilation_shape_inference_F1_F1_O1");
2498   *library_expected.add_function() = FunctionDefHelper::Create(
2499       "F1", {"a_0_arg:float", "b_0_arg:float"},
2500       {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
2501       {
2502           {{"C"}, "UnaryTest", {"a_0_arg"}},
2503           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
2504           {{"F"}, "UnaryTest", {"D:o:0"}},
2505           {{"outside_compilation_O1_host_compute"},
2506            "XlaHostCompute",
2507            {"a_0_arg"},
2508            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
2509             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
2510             {"ancestors", absl::Span<const string>({})},
2511             {"key", "host_compute_channel_F1_F1_O1"},
2512             {"send_key", ""},
2513             {"recv_key", ""},
2514             {"shape_inference_graph", shape_inference_graph},
2515             {"tpu_core", 0},
2516             {"cost_estimate_ns", 1000000},
2517             {"shapes", absl::Span<const TensorShapeProto>({})},
2518             {"_outside_compilation_subgraph", "O1"},
2519             {"_xla_token_input_nodes",
2520              absl::Span<const string>({"_xla_token_arg_node"})},
2521             {"_xla_original_oc_node_name",
2522              "outside_compilation_O1_host_compute"}}},
2523       },
2524       {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
2525        {"f_0_retval_retval", "F:o:0"}});
2526 
2527   {
2528     std::unique_ptr<FunctionLibraryDefinition> lib_def(
2529         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
2530     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
2531     Node* a = Input(b2.opts().WithName("A"));
2532     Node* b = Input(b2.opts().WithName("B"));
2533 
2534     Node* key_constant =
2535         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
2536     Node* recv =
2537         RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
2538                    b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2539     Node* e = Unary(recv, b2.opts()
2540                               .WithName("E")
2541                               .WithAttr("_encapsulate", "F1")
2542                               .WithAttr("_outside", "O1"));
2543     Node* send =
2544         SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2545                      b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2546     Node* s = Sequencer(
2547         b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
2548         "F1");
2549     NodeBuilder node_builder1("F1", "F1", lib_def.get());
2550     node_builder1.Input(a).Input(b).ControlInput(s);
2551     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
2552 
2553     Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
2554     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
2555   }
2556 
2557   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
2558   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
2559 }
2560 
2561 // Test for shape inference of outside compilation.
TEST(EncapsulateSubgraphsTest,OutsideCompilationShapeInference)2562 TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
2563   FunctionDefLibrary library;
2564   GraphDef graphdef;
2565 
2566   {
2567     *library.add_function() = test::function::XTimesTwo();
2568 
2569     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
2570     Node* a = InputShaped(b1.opts().WithName("A"));
2571     Node* b = Input(b1.opts().WithName("B"));
2572     // Give nodes 'c' and 'd' names that collide after lowercasing.
2573     Node* c = Unary(a, b1.opts().WithName("C"));
2574     Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
2575                            "_encapsulate", "F1"));
2576     Node* e = BinaryUnknownShape(c, d,
2577                                  b1.opts()
2578                                      .WithName("E")
2579                                      .WithControlInputs({b, d})
2580                                      .WithAttr("_encapsulate", "F1")
2581                                      .WithAttr("_outside", "O1"));
2582     Node* f = Binary(c, e,
2583                      b1.opts().WithName("F").WithControlInput(e).WithAttr(
2584                          "_encapsulate", "F1"));
2585     Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
2586     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
2587   }
2588 
2589   std::vector<string> encapsulated_functions{"F1"};
2590   TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
2591 
2592   FunctionDefLibrary library_expected;
2593   GraphDef graphdef_expected;
2594 
2595   {
2596     GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
2597     Node* key_constant = KeyPlaceholder("F1", shape.opts());
2598     Node* recv = RecvAtHost(
2599         ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
2600         shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2601     Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
2602                                  shape.opts()
2603                                      .WithName("E")
2604                                      .WithAttr("_encapsulate", "F1")
2605                                      .WithAttr("_outside", "O1"));
2606     SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2607                  shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2608     TF_EXPECT_OK(
2609         AddGraphDefToFunctionLibrary(shape, "F1_F1_O1", &library_expected));
2610   }
2611 
2612   NameAttrList shape_inference_graph;
2613   shape_inference_graph.set_name(
2614       "_outside_compilation_shape_inference_F1_F1_O1");
2615   *library_expected.add_function() = test::function::XTimesTwo();
2616   *library_expected.add_function() = FunctionDefHelper::Create(
2617       "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {},
2618       {
2619           {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
2620           {{"F"},
2621            "BinaryTest",
2622            {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
2623            {},
2624            {"outside_compilation_O1_host_compute"}},
2625           {{"outside_compilation_O1_host_compute"},
2626            "XlaHostCompute",
2627            {"c_0_arg", "c:o:0"},
2628            {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
2629             {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
2630             {"ancestors", absl::Span<const string>({})},
2631             {"key", "host_compute_channel_F1_F1_O1"},
2632             {"send_key", ""},
2633             {"recv_key", ""},
2634             {"shape_inference_graph", shape_inference_graph},
2635             {"tpu_core", 0},
2636             {"cost_estimate_ns", 1000000},
2637             {"shapes", absl::Span<const DataType>({})},
2638             {"_outside_compilation_subgraph", "O1"},
2639             {"_xla_token_input_nodes",
2640              absl::Span<const string>({"_xla_token_arg_node"})},
2641             {"_xla_original_oc_node_name",
2642              "outside_compilation_O1_host_compute"}},
2643            {"c"}},
2644       },
2645       {{"f_0_retval_retval", "F:o:0"}});
2646 
2647   {
2648     std::unique_ptr<FunctionLibraryDefinition> lib_def(
2649         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
2650     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
2651     Node* a = InputShaped(b2.opts().WithName("A"));
2652     Node* b = Input(b2.opts().WithName("B"));
2653     Node* c = Unary(a, b2.opts().WithName("C"));
2654 
2655     Node* key_constant =
2656         KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
2657     Node* recv = RecvAtHost(
2658         ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
2659         b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
2660     Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
2661                                  b2.opts()
2662                                      .WithName("E")
2663                                      .WithControlInputs({recv})
2664                                      .WithAttr("_encapsulate", "F1")
2665                                      .WithAttr("_outside", "O1"));
2666     Node* send =
2667         SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
2668                      b2.opts().WithControlInput(e).WithAttr(
2669                          kXlaHasHostTransferAttrName, true));
2670 
2671     Node* s = Sequencer(
2672         b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
2673         "F1");
2674 
2675     NodeBuilder node_builder("F1", "F1", lib_def.get());
2676     node_builder.Input(b).Input(c);
2677     Node* call =
2678         b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder);
2679 
2680     Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
2681     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
2682   }
2683 
2684   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
2685   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
2686 }
2687 
CreateSubgraphTouchingRefVar(const Scope & s)2688 void CreateSubgraphTouchingRefVar(const Scope& s) {
2689   Output variable =
2690       ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
2691   Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
2692   Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
2693   Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
2694 
2695   Output constant =
2696       ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
2697   s.graph()->AddControlEdge(constant.node(), variable.node());
2698 }
2699 
// Runs EncapsulateSubgraphsPass over a graph whose nodes are all connected to
// an `ops::Variable` node, then checks that every node carries
// kXlaHasReferenceVarsAttr and that it is true on all nodes except the
// graph's implicit source and sink.
TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  CreateSubgraphTouchingRefVar(scope);

  std::unique_ptr<Graph> g = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(g.get()));

  // Drive the pass through the test wrapper, which builds the options
  // pointing at `g`.
  GraphOptimizationPassWrapper pass_wrapper;
  GraphOptimizationPassOptions pass_options =
      pass_wrapper.CreateGraphOptimizationPassOptions(&g);
  EncapsulateSubgraphsPass encapsulate_pass;
  TF_ASSERT_OK(encapsulate_pass.Run(pass_options));

  for (const Node* n : g->nodes()) {
    // The attribute must be present on every node, source/sink included;
    // TF_ASSERT_OK aborts the test if the lookup fails.
    bool has_ref_var = false;
    TF_ASSERT_OK(
        GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var));
    const bool is_source_or_sink = n->IsSource() || n->IsSink();
    EXPECT_TRUE(is_source_or_sink || has_ref_var)
        << "All nodes apart from source and sink can access reference variable";
  }
}
2722 
CreateSubgraphNotTouchingRefVar(const Scope & s)2723 void CreateSubgraphNotTouchingRefVar(const Scope& s) {
2724   Output constant =
2725       ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
2726   Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
2727   Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
2728 }
2729 
// Runs EncapsulateSubgraphsPass over a variable-free graph. Despite the test
// name, the attribute is still required on every node (the lookup is
// asserted); what is checked is that kXlaHasReferenceVarsAttr is false.
TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  CreateSubgraphNotTouchingRefVar(scope);

  std::unique_ptr<Graph> g = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(g.get()));

  // Drive the pass through the test wrapper, which builds the options
  // pointing at `g`.
  GraphOptimizationPassWrapper pass_wrapper;
  GraphOptimizationPassOptions pass_options =
      pass_wrapper.CreateGraphOptimizationPassOptions(&g);
  EncapsulateSubgraphsPass encapsulate_pass;
  TF_ASSERT_OK(encapsulate_pass.Run(pass_options));

  for (const Node* n : g->nodes()) {
    bool has_ref_var = false;
    TF_ASSERT_OK(
        GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var));
    EXPECT_FALSE(has_ref_var) << "The graph does not have reference variables";
  }
}
2751 
2752 }  // namespace
2753 }  // namespace tensorflow
2754