1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/cc/framework/ops.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/jit/encapsulate_util.h"
26 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
27 #include "tensorflow/compiler/jit/test_util.h"
28 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
29 #include "tensorflow/core/common_runtime/device_factory.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/framework/function_testlib.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/graph/graph_def_builder.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/public/session_options.h"
39 #include "tensorflow/core/public/version.h"
40 #include "tensorflow/core/util/equal_graph_def.h"
41
42 namespace tensorflow {
43 namespace {
44
// Name of the attribute that the Sequencer() helper below attaches to its NoOp
// node; its value is the name of the associated call node.
constexpr char kXlaHostTransferSequencerAttr[] = "_xla_host_transfer_sequencer";
47
AddGraphDefToFunctionLibrary(const GraphDefBuilder & graphdef_builder,const string & name_suffix,FunctionDefLibrary * library)48 Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder,
49 const string& name_suffix,
50 FunctionDefLibrary* library) {
51 GraphDef graphdef;
52 TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef));
53 std::unique_ptr<Graph> graph =
54 std::unique_ptr<Graph>(new Graph(OpRegistry::Global()));
55 GraphConstructorOptions opts;
56 opts.allow_internal_ops = true;
57 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get()));
58 FunctionDef* fdef = library->add_function();
59 TF_RETURN_IF_ERROR(GraphToFunctionDef(
60 *graph,
61 absl::StrCat("_outside_compilation_shape_inference_", name_suffix),
62 fdef));
63 return OkStatus();
64 }
65
66 template <class Tkey, class Tvalue>
EqualProtoMap(const::tensorflow::protobuf::Map<Tkey,Tvalue> & a,const::tensorflow::protobuf::Map<Tkey,Tvalue> & b,const std::function<string (const Tkey &)> & key_to_string,const std::function<string (const Tvalue &)> & value_to_string,const std::function<bool (const Tkey &,const Tvalue &,const Tvalue &)> & compare,const string & map_name,string * diff)67 bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
68 const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
69 const std::function<string(const Tkey&)>& key_to_string,
70 const std::function<string(const Tvalue&)>& value_to_string,
71 const std::function<bool(const Tkey&, const Tvalue&,
72 const Tvalue&)>& compare,
73 const string& map_name, string* diff) {
74 for (const auto& elt_a : a) {
75 const auto iter = b.find(elt_a.first);
76 if (iter == b.end()) {
77 if (diff) {
78 *diff = absl::StrCat(map_name, " expected: contains element with key '",
79 key_to_string(elt_a.first),
80 "' got: map has no such element");
81 }
82 return false;
83 }
84 if (!compare(elt_a.first, elt_a.second, iter->second)) {
85 if (diff) {
86 *diff = absl::StrCat(map_name, " expected: element with key '",
87 key_to_string(elt_a.first), "' has value '",
88 value_to_string(elt_a.second), "' got: '",
89 value_to_string(iter->second), "'");
90 }
91 return false;
92 }
93 }
94 for (const auto& elt_b : b) {
95 const auto iter = a.find(elt_b.first);
96 if (iter == a.end()) {
97 if (diff) {
98 *diff = absl::StrCat(map_name, " got: contains element with key '",
99 key_to_string(elt_b.first),
100 "' expected: map has no such element");
101 }
102 return false;
103 }
104 }
105 return true;
106 }
107
EqualFunctionNodeDef(const NodeDef & a,const NodeDef & b,const string & diff_preamble,string * diff)108 bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
109 const string& diff_preamble, string* diff) {
110 if (a.op() != b.op()) {
111 if (diff) {
112 *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
113 ", expected op '", a.op(), "' got '", b.op());
114 }
115 return false;
116 }
117 if (a.device() != b.device()) {
118 if (diff) {
119 *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
120 ", expected device '", a.device(), "' got '",
121 b.device());
122 }
123 return false;
124 }
125 if (a.input_size() != b.input_size()) {
126 if (diff) {
127 *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
128 ", expected ", a.input_size(), " inputs got ",
129 b.input_size(), " expected:\n", a.DebugString(),
130 "\ngot:\n", b.DebugString());
131 }
132 return false;
133 }
134 std::unordered_set<string> control_input_a;
135 std::unordered_set<string> control_input_b;
136 for (int i = 0; i < a.input_size(); ++i) {
137 if (absl::StartsWith(a.input(i), "^")) {
138 if (!absl::StartsWith(b.input(i), "^")) {
139 if (diff) {
140 *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
141 " input ", i, ", expected control input ",
142 a.input(i), " got ", b.input(i), " expected:\n",
143 a.DebugString(), "\ngot:\n", b.DebugString());
144 }
145 return false;
146 }
147 control_input_a.insert(a.input(i));
148 control_input_b.insert(b.input(i));
149 } else if (a.input(i) != b.input(i)) {
150 if (diff) {
151 *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
152 " input ", i, ", expected ", a.input(i), " got ",
153 b.input(i), " expected:\n", a.DebugString(),
154 "\ngot:\n", b.DebugString());
155 }
156 return false;
157 }
158 }
159 if (control_input_a != control_input_b) {
160 if (diff) {
161 *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
162 " control inputs differ expected:\n",
163 a.DebugString(), "\ngot:\n", b.DebugString());
164 }
165 return false;
166 }
167 return EqualProtoMap<string, AttrValue>(
168 a.attr(), b.attr(), [](const string& s) { return s; },
169 [](const AttrValue& v) { return v.DebugString(); },
170 [](const string& key, const AttrValue& av, const AttrValue& bv) {
171 if (key == "ancestors") {
172 // The ancestors are added from a set so the order is unpredictable;
173 // just compare set equality not list equality.
174 std::unordered_set<string> a_set(av.list().s().begin(),
175 av.list().s().end());
176 std::unordered_set<string> b_set(bv.list().s().begin(),
177 bv.list().s().end());
178 return a_set == b_set;
179 } else {
180 return av.DebugString() == bv.DebugString();
181 }
182 },
183 absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff);
184 }
185
EqualFunctionDef(const FunctionDef & a,const FunctionDef & b,string * diff)186 bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
187 string* diff) {
188 if (a.signature().DebugString() != b.signature().DebugString()) {
189 if (diff) {
190 *diff =
191 absl::StrCat("Signature mismatch for function ", a.signature().name(),
192 ", expected:\n", a.signature().DebugString(), "\ngot:\n",
193 b.signature().DebugString());
194 }
195 return false;
196 }
197 if (!EqualProtoMap<string, AttrValue>(
198 a.attr(), b.attr(), [](const string& s) { return s; },
199 [](const AttrValue& v) { return v.DebugString(); },
200 [](const string& key, const AttrValue& av, const AttrValue& bv) {
201 return av.DebugString() == bv.DebugString();
202 },
203 absl::StrCat("attr mismatch for function ", a.signature().name()),
204 diff)) {
205 return false;
206 }
207 if (!EqualProtoMap<string, string>(
208 a.ret(), b.ret(), [](const string& s) { return s; },
209 [](const string& s) { return s; },
210 [](const string& key, const string& av, const string& bv) {
211 return av == bv;
212 },
213 absl::StrCat("ret mismatch for function ", a.signature().name()),
214 diff)) {
215 return false;
216 }
217 for (int i = 0; i < a.node_def_size(); ++i) {
218 bool found = false;
219 for (int j = 0; j < b.node_def_size(); ++j) {
220 if (a.node_def(i).name() == b.node_def(j).name()) {
221 if (!EqualFunctionNodeDef(
222 a.node_def(i), b.node_def(j),
223 absl::StrCat("Function ", a.signature().name()), diff)) {
224 return false;
225 }
226 found = true;
227 break;
228 }
229 }
230 if (!found) {
231 if (diff) {
232 *diff = absl::StrCat("Function ", a.signature().name(),
233 ", expected: has node '", a.node_def(i).name(),
234 "' got: no node of that name");
235 }
236 return false;
237 }
238 }
239 for (int i = 0; i < b.node_def_size(); ++i) {
240 bool found = false;
241 for (int j = 0; j < a.node_def_size(); ++j) {
242 if (b.node_def(i).name() == a.node_def(j).name()) {
243 found = true;
244 break;
245 }
246 }
247 if (!found) {
248 if (diff) {
249 *diff = absl::StrCat("Function ", a.signature().name(),
250 ", got: has node '", b.node_def(i).name(),
251 "' expected: no node of that name");
252 }
253 return false;
254 }
255 }
256 return true;
257 }
258
EqualFunctionDefLibrary(const FunctionDefLibrary & expected,const FunctionDefLibrary & actual,string * diff)259 bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
260 const FunctionDefLibrary& actual, string* diff) {
261 std::unordered_map<string, const FunctionDef*> actual_index;
262 for (const FunctionDef& function : actual.function()) {
263 actual_index[function.signature().name()] = &function;
264 }
265
266 for (const FunctionDef& expected_function : expected.function()) {
267 auto it = actual_index.find(expected_function.signature().name());
268 if (it == actual_index.end()) {
269 if (diff) {
270 *diff = absl::StrCat("Did not find expected function '",
271 expected_function.signature().name(), "'");
272 }
273 return false;
274 }
275 if (!EqualFunctionDef(expected_function, *it->second, diff)) return false;
276 actual_index.erase(it);
277 }
278
279 if (!actual_index.empty()) {
280 if (diff != nullptr) {
281 *diff =
282 absl::StrCat("Found unexpected function '",
283 actual_index.begin()->second->signature().name(), "'");
284 }
285 return false;
286 }
287
288 return true;
289 }
290
// Expects FunctionDefLibrary `expected` to equal `actual` under
// EqualFunctionDefLibrary. On failure it reports the first difference followed
// by the DebugString of `actual`. The arguments are parenthesized so arbitrary
// expressions (e.g. `*lib`) work. The local has a name unlikely to shadow a
// caller variable referenced in the arguments.
#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual)                    \
  do {                                                                       \
    string tf_expect_fdl_eq_diff;                                            \
    EXPECT_TRUE(EqualFunctionDefLibrary((expected), (actual),                \
                                        &tf_expect_fdl_eq_diff))             \
        << tf_expect_fdl_eq_diff << "\nActual: " << (actual).DebugString(); \
  } while (false)
297
298 REGISTER_OP("InputTest")
299 .Output("o: float")
__anon4e24c0570b02(::tensorflow::shape_inference::InferenceContext* c) 300 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
301 c->set_output(0, c->UnknownShape());
302 return OkStatus();
303 });
304
305 REGISTER_OP("InputTestShaped")
306 .Output("o: float")
__anon4e24c0570c02(::tensorflow::shape_inference::InferenceContext* c) 307 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
308 c->set_output(0, c->Vector(2));
309 return OkStatus();
310 });
311
312 REGISTER_OP("UnaryTest")
313 .Input("a: float")
314 .Output("o: float")
__anon4e24c0570d02(::tensorflow::shape_inference::InferenceContext* c) 315 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
316 ::tensorflow::shape_inference::ShapeHandle o;
317 TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
318 c->set_output(0, o);
319 return OkStatus();
320 });
321 REGISTER_OP("BinaryTest")
322 .Input("a: float")
323 .Input("b: float")
324 .Output("o: float")
__anon4e24c0570e02(::tensorflow::shape_inference::InferenceContext* c) 325 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
326 ::tensorflow::shape_inference::ShapeHandle o;
327 TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
328 c->set_output(0, o);
329 return OkStatus();
330 });
331 REGISTER_OP("BinaryTest2")
332 .Input("a: float")
333 .Input("b: float")
334 .Output("o: float")
335 .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
336
337 REGISTER_OP("AddNLikeTest")
338 .Input("inputs: N * T")
339 .Output("sum: T")
340 .Attr("N: int >= 1")
341 .Attr("T: numbertype")
342 .SetIsCommutative()
343 .SetIsAggregate();
344
Sequencer(const GraphDefBuilder::Options & opts,const string & call_node_name)345 Node* Sequencer(const GraphDefBuilder::Options& opts,
346 const string& call_node_name) {
347 if (opts.HaveError()) return nullptr;
348 NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp",
349 opts.op_registry());
350 return opts.WithAttr(kXlaHostTransferSequencerAttr, call_node_name)
351 .FinalizeBuilder(&node_builder);
352 }
353
Input(const GraphDefBuilder::Options & opts)354 Node* Input(const GraphDefBuilder::Options& opts) {
355 return ops::SourceOp("InputTest", opts);
356 }
357
InputShaped(const GraphDefBuilder::Options & opts)358 Node* InputShaped(const GraphDefBuilder::Options& opts) {
359 return ops::SourceOp("InputTestShaped", opts);
360 }
361
KnownShapeBase(DataType dtype,absl::Span<const int> shape,const GraphDefBuilder::Options & opts)362 Node* KnownShapeBase(DataType dtype, absl::Span<const int> shape,
363 const GraphDefBuilder::Options& opts) {
364 if (opts.HaveError()) return nullptr;
365 NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
366 opts.op_registry());
367 TensorProto value;
368 value.set_dtype(dtype);
369 for (int dim : shape) {
370 value.mutable_tensor_shape()->add_dim()->set_size(dim);
371 }
372 return opts.WithAttr("value", value)
373 .WithAttr("dtype", dtype)
374 .FinalizeBuilder(&node_builder);
375 }
376
KnownShape(absl::Span<const int> shape,const GraphDefBuilder::Options & opts)377 Node* KnownShape(absl::Span<const int> shape,
378 const GraphDefBuilder::Options& opts) {
379 return KnownShapeBase(DT_FLOAT, shape, opts);
380 }
381
KeyPlaceholderShape(const GraphDefBuilder::Options & opts)382 Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) {
383 return KnownShapeBase(DT_STRING, {2}, opts);
384 }
385
KeyPlaceholder(const string & call_node,const GraphDefBuilder::Options & opts)386 Node* KeyPlaceholder(const string& call_node,
387 const GraphDefBuilder::Options& opts) {
388 if (opts.HaveError()) return nullptr;
389 NodeBuilder node_builder(absl::StrCat(call_node, "_key_placeholder"),
390 "Placeholder", opts.op_registry());
391 TensorShapeProto shape;
392 shape.add_dim()->set_size(2);
393 return opts.WithAttr("shape", shape)
394 .WithAttr("dtype", DT_STRING)
395 .WithAttr("_host_compute_call_node", call_node)
396 .FinalizeBuilder(&node_builder);
397 }
398
RecvAtHost(ops::NodeOut key_input,const string & cluster,const string & new_func_name,const string & oc_cluster,absl::Span<const DataType> dtypes,const GraphDefBuilder::Options & opts)399 Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
400 const string& new_func_name, const string& oc_cluster,
401 absl::Span<const DataType> dtypes,
402 const GraphDefBuilder::Options& opts) {
403 if (opts.HaveError()) return nullptr;
404 string key = absl::StrCat("host_compute_channel_", cluster, "_",
405 new_func_name, "_", oc_cluster);
406 string name = absl::StrCat("outside_compilation_", cluster, "_",
407 new_func_name, "_", oc_cluster, "_recv");
408 NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"),
409 "_XlaRecvAtHost", opts.op_registry());
410 node_builder.Input(std::move(key_input));
411 return opts.WithAttr("Toutputs", dtypes)
412 .WithAttr("key", key)
413 .WithAttr("device_ordinal", 0)
414 .WithAttr("_encapsulate", cluster)
415 .WithAttr("_outside", oc_cluster)
416 .FinalizeBuilder(&node_builder);
417 }
418
SendFromHost(ops::NodeOut key_input,const string & cluster,const string & new_func_name,const string & oc_cluster,const std::vector<ops::NodeOut> & inputs,const GraphDefBuilder::Options & opts)419 Node* SendFromHost(ops::NodeOut key_input, const string& cluster,
420 const string& new_func_name, const string& oc_cluster,
421 const std::vector<ops::NodeOut>& inputs,
422 const GraphDefBuilder::Options& opts) {
423 if (opts.HaveError()) return nullptr;
424 string key = absl::StrCat("host_compute_channel_", cluster, "_",
425 new_func_name, "_", oc_cluster);
426 string name = absl::StrCat("outside_compilation_", cluster, "_",
427 new_func_name, "_", oc_cluster, "_send");
428 NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"),
429 "_XlaSendFromHost", opts.op_registry());
430 node_builder.Input(inputs);
431 node_builder.Input(std::move(key_input));
432 std::vector<DataType> dtypes;
433 for (const auto& node : inputs) {
434 dtypes.push_back(node.dt);
435 }
436 return opts.WithAttr("Tinputs", dtypes)
437 .WithAttr("key", key)
438 .WithAttr("device_ordinal", 0)
439 .WithAttr("_encapsulate", cluster)
440 .WithAttr("_outside", oc_cluster)
441 .FinalizeBuilder(&node_builder);
442 }
443
Unary(ops::NodeOut a,const GraphDefBuilder::Options & opts)444 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
445 return ops::UnaryOp("UnaryTest", std::move(a), opts);
446 }
447
Binary(ops::NodeOut a,ops::NodeOut b,const GraphDefBuilder::Options & opts)448 Node* Binary(ops::NodeOut a, ops::NodeOut b,
449 const GraphDefBuilder::Options& opts) {
450 return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
451 }
452
BinaryUnknownShape(ops::NodeOut a,ops::NodeOut b,const GraphDefBuilder::Options & opts)453 Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
454 const GraphDefBuilder::Options& opts) {
455 return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
456 }
457
AddNLike(const std::vector<ops::NodeOut> & inputs,const GraphDefBuilder::Options & opts)458 Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
459 const GraphDefBuilder::Options& opts) {
460 if (opts.HaveError()) return nullptr;
461 NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest",
462 opts.op_registry());
463 node_builder.Input(inputs);
464 return opts.FinalizeBuilder(&node_builder);
465 }
466
ArgOp(int index,DataType type,const GraphDefBuilder::Options & opts)467 Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) {
468 return ops::SourceOp("_Arg",
469 opts.WithAttr("T", type).WithAttr("index", index));
470 }
471
RetOp(int index,ops::NodeOut a,const GraphDefBuilder::Options & opts)472 Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
473 if (opts.HaveError()) return nullptr;
474 NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
475 opts.op_registry());
476 node_builder.Input(std::move(a)).Attr("index", index);
477 return opts.FinalizeBuilder(&node_builder);
478 }
479
Encapsulate(GraphDef * graphdef,FunctionDefLibrary * library,const std::vector<string> & encapsulated_functions)480 Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
481 const std::vector<string>& encapsulated_functions) {
482 Status s;
483 // Convert the GraphDef to a Graph
484 std::unique_ptr<FunctionLibraryDefinition> lib_def(
485 new FunctionLibraryDefinition(OpRegistry::Global(), *library));
486 GraphConstructorOptions options;
487 options.allow_internal_ops = true;
488 std::unique_ptr<Graph> graph(new Graph(lib_def.get()));
489 s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
490 if (!s.ok()) return s;
491
492 s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get());
493 if (!s.ok()) return s;
494
495 // Create FunctionLibraryRuntime.
496 SessionOptions session_options;
497 std::vector<std::unique_ptr<Device>> devices;
498 TF_CHECK_OK(DeviceFactory::AddDevices(
499 session_options, "/job:localhost/replica:0/task:0", &devices));
500 OptimizerOptions opts;
501 auto device_mgr = std::make_unique<StaticDeviceMgr>(std::move(devices));
502 auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
503 device_mgr.get(), Env::Default(), /*config=*/nullptr,
504 TF_GRAPH_DEF_VERSION, lib_def.get(), opts,
505 /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
506 auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
507
508 std::unique_ptr<Graph> graph_out;
509 s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
510 /*rewrite_subgraph_fn=*/{},
511 /*reuse_existing_functions=*/false,
512 &graph_out, lib_def.get());
513 if (!s.ok()) return s;
514
515 std::unordered_map<string, XlaClusterInfo> clusters;
516 for (const auto& func : encapsulated_functions) {
517 Node* xla_computation_node;
518 for (Node* n : graph_out->nodes()) {
519 if (n->name() == func) {
520 xla_computation_node = n;
521 }
522 }
523 if (!xla_computation_node) {
524 return errors::Internal("Cannot find node ", func);
525 }
526 NameAttrList func_name_attrs;
527 func_name_attrs.set_name(func);
528 clusters.emplace(func,
529 XlaClusterInfo{func, func_name_attrs, xla_computation_node,
530 std::map<string, int>{}});
531 }
532 bool modified;
533 s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
534 graph_out.get(), flr, lib_def.get(), &modified);
535 if (!s.ok()) return s;
536
537 GraphDef graphdef_out;
538 graph_out->ToGraphDef(&graphdef_out);
539 graphdef->Swap(&graphdef_out);
540
541 *library = lib_def->ToProto();
542 // Remove "_xla_inferred_shapes" attr. They are added by
543 // `PerformStaticShapeInferenceBeforeEncapsulation`.
544 for (FunctionDef& fdef : *library->mutable_function()) {
545 for (NodeDef& node_def : *fdef.mutable_node_def()) {
546 node_def.mutable_attr()->erase("_xla_inferred_shapes");
547 }
548 }
549
550 return s;
551 }
552
Encapsulate(GraphDef * graphdef,FunctionDefLibrary * library)553 Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
554 std::vector<string> encapsulated_functions;
555 return Encapsulate(graphdef, library, encapsulated_functions);
556 }
557
558 // If there are no marked nodes, funcification should be a no-op.
TEST(EncapsulateSubgraphsTest,NoFunctions)559 TEST(EncapsulateSubgraphsTest, NoFunctions) {
560 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
561
562 Node* a = Input(builder.opts().WithName("A"));
563 Node* b = Input(builder.opts().WithName("B"));
564 Node* c = Unary(a, builder.opts().WithName("C"));
565 Binary(b, c, builder.opts().WithName("D"));
566
567 GraphDef graphdef_in;
568 FunctionDefLibrary library_in;
569 TF_EXPECT_OK(builder.ToGraphDef(&graphdef_in));
570 *library_in.add_function() = test::function::XTimesTwo();
571
572 GraphDef graphdef_out = graphdef_in;
573 FunctionDefLibrary library_out = library_in;
574 TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
575
576 // If there are no marked nodes, funcification should be a no-op.
577 TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
578 TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
579 }
580
581 // Test with one function to transform.
TEST(EncapsulateSubgraphsTest,OneFunction)582 TEST(EncapsulateSubgraphsTest, OneFunction) {
583 FunctionDefLibrary library;
584 GraphDef graphdef;
585
586 {
587 *library.add_function() = test::function::XTimesTwo();
588
589 GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
590 Node* a = Input(b1.opts().WithName("A"));
591 Node* b = Input(b1.opts().WithName("B"));
592 // Give nodes 'c' and 'd' names that collide after lowercasing.
593 Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
594 Node* d = Binary(b, c,
595 b1.opts().WithName("c").WithControlInput(c).WithAttr(
596 "_encapsulate", "F1"));
597 Binary(a, d, b1.opts().WithName("E"));
598 TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
599 }
600
601 TF_EXPECT_OK(Encapsulate(&graphdef, &library));
602
603 FunctionDefLibrary library_expected;
604 GraphDef graphdef_expected;
605
606 *library_expected.add_function() = test::function::XTimesTwo();
607 *library_expected.add_function() = FunctionDefHelper::Create(
608 "F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {},
609 {
610 {{"C"}, "UnaryTest", {"a_0_arg"}},
611 {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
612 },
613 {{"c_0_retval", "c:o:0"}});
614
615 {
616 std::unique_ptr<FunctionLibraryDefinition> lib_def(
617 new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
618 GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
619 Node* a = Input(b2.opts().WithName("A"));
620 Node* b = Input(b2.opts().WithName("B"));
621
622 NodeBuilder node_builder("F1", "F1", lib_def.get());
623 node_builder.Input(a).Input(b);
624 Node* call = b2.opts().FinalizeBuilder(&node_builder);
625
626 Binary(a, call, b2.opts().WithName("E"));
627 TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
628 }
629
630 TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
631 TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
632 }
633
634 // Test with two functions to transform.
TEST(EncapsulateSubgraphsTest,TwoFunctions)635 TEST(EncapsulateSubgraphsTest, TwoFunctions) {
636 FunctionDefLibrary library;
637 GraphDef graphdef;
638
639 {
640 *library.add_function() = test::function::XTimesTwo();
641
642 GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
643 Node* a = Input(b1.opts().WithName("A"));
644 Node* b = Input(b1.opts().WithName("B"));
645 Node* control = Input(b1.opts().WithName("Control"));
646 Node* c =
647 Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
648 "_encapsulate", "F1"));
649 Node* d = Binary(b, c,
650 b1.opts().WithName("D").WithControlInput(control).WithAttr(
651 "_encapsulate", "F2"));
652 Binary(a, d, b1.opts().WithName("E"));
653 TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
654 }
655
656 TF_EXPECT_OK(Encapsulate(&graphdef, &library));
657
658 FunctionDefLibrary library_expected;
659 GraphDef graphdef_expected;
660
661 *library_expected.add_function() = test::function::XTimesTwo();
662 *library_expected.add_function() = FunctionDefHelper::Create(
663 "F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {},
664 {
665 {{"C"}, "UnaryTest", {"a_0_arg"}},
666 },
667 {{"c_0_retval", "C:o:0"}});
668 *library_expected.add_function() = FunctionDefHelper::Create(
669 "F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {},
670 {
671 {{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}},
672 },
673 {{"d_0_retval", "D:o:0"}});
674
675 {
676 std::unique_ptr<FunctionLibraryDefinition> lib_def(
677 new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
678 GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
679 Node* a = Input(b2.opts().WithName("A"));
680 Node* b = Input(b2.opts().WithName("B"));
681 Node* control = Input(b2.opts().WithName("Control"));
682
683 NodeBuilder nb("F1", "F1", lib_def.get());
684 nb.Input(a).ControlInput(control);
685 Node* call1 = b2.opts().FinalizeBuilder(&nb);
686
687 NodeBuilder nb2("F2", "F2", lib_def.get());
688 nb2.Input(b).Input(call1).ControlInput(control);
689 Node* call2 = b2.opts().FinalizeBuilder(&nb2);
690
691 Binary(a, call2, b2.opts().WithName("E"));
692 TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
693 }
694
695 // If there are no marked nodes, funcification should be a no-op.
696 TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
697 TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
698 }
699
700 // Returns a vector of node names in 'graph', sorted by name.
GraphNodes(const Graph & graph)701 std::vector<string> GraphNodes(const Graph& graph) {
702 std::vector<string> nodes;
703 for (const auto& node : graph.nodes()) {
704 if (!node->IsSource() && !node->IsSink()) {
705 nodes.push_back(node->name());
706 }
707 }
708 std::sort(nodes.begin(), nodes.end());
709 return nodes;
710 }
711
712 // Returns a sorted vector of (src, dst) edges in 'graph'.
GraphEdges(const Graph & graph)713 std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
714 std::vector<std::pair<string, string>> edges;
715 for (const Edge* edge : graph.edges()) {
716 if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
717 edges.emplace_back(
718 absl::StrCat(edge->src()->name(), ":", edge->src_output()),
719 absl::StrCat(edge->dst()->name(), ":", edge->dst_input()));
720 }
721 std::sort(edges.begin(), edges.end());
722 return edges;
723 }
724
// add1 (cluster1) is consumed twice by add2 (cluster2) and once by mul. After
// encapsulation, cluster1 -> cluster2 must be a single edge rather than one per
// use. Likewise x -> cluster1 must be one edge even though add1 reads x twice.
TEST(EncapsulateSubgraphsTest, InputDeduplication) {
  Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
  auto sum1 = ops::Add(root.WithOpName("add1"), x, x);
  sum1.node()->AddAttr("_cluster", "cluster1");
  auto sum2 = ops::Add(root.WithOpName("add2"), sum1, sum1);
  sum2.node()->AddAttr("_cluster", "cluster2");
  auto product = ops::Mul(root.WithOpName("mul"), sum1, sum2);

  Graph input_graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&input_graph));

  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  std::unique_ptr<Graph> output_graph;
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_cluster", input_graph,
      /*rewrite_subgraph_fn=*/{},
      /*reuse_existing_functions=*/false, &output_graph, &library));

  const std::vector<string> want_nodes = {"cluster1", "cluster2", "mul", "x"};
  EXPECT_EQ(want_nodes, GraphNodes(*output_graph));

  const std::vector<std::pair<string, string>> want_edges = {
      {"cluster1:0", "cluster2:0"},
      {"cluster1:0", "mul:0"},
      {"cluster2:0", "mul:1"},
      {"x:0", "cluster1:0"}};
  EXPECT_EQ(want_edges, GraphEdges(*output_graph));
}
755
FindNodeByName(const Graph & graph,const string & name)756 const Node* FindNodeByName(const Graph& graph, const string& name) {
757 for (const Node* node : graph.nodes()) {
758 if (node->name() == name) return node;
759 }
760 return nullptr;
761 }
762
HasGuaranteeConstAttr(const Node & n)763 bool HasGuaranteeConstAttr(const Node& n) {
764 bool is_guaranteed_constant = false;
765 if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant",
766 &is_guaranteed_constant)
767 .ok()) {
768 return false;
769 }
770 return is_guaranteed_constant;
771 }
772
// add1 (the only clustered node) reads two GuaranteeConst outputs. Inside the
// rewrite callback, _Arg nodes whose names start with "const" must carry the
// guaranteed-constant attr and all other nodes must not; exactly two such args
// are expected.
TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
  Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto ph_x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
  auto ph_x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
  auto gc_x2 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), ph_x2);
  auto gc_x1 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), ph_x1);
  auto sum = ops::Add(root.WithOpName("add1"), gc_x1, gc_x2);
  sum.node()->AddAttr("_encapsulate", "encapsulate1");

  Graph graph_before(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph_before));

  std::unique_ptr<Graph> graph_after;
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  int guaranteed_consts = 0;
  // Counts "const*" _Arg nodes and checks which nodes carry the attr.
  auto check_subgraph =
      [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
                           std::unique_ptr<Graph>* graph_ptr,
                           std::vector<int>* input_permutation,
                           std::vector<int>* output_permutation,
                           NodeDef* call_def) {
        const Graph* subgraph = graph_ptr->get();
        for (const Node* n : subgraph->nodes()) {
          const bool is_const_arg = n->type_string() == "_Arg" &&
                                    absl::StartsWith(n->name(), "const");
          if (is_const_arg) {
            ++guaranteed_consts;
            EXPECT_TRUE(HasGuaranteeConstAttr(*n));
          } else {
            EXPECT_FALSE(HasGuaranteeConstAttr(*n));
          }
        }
        return OkStatus();
      };
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_encapsulate", graph_before,
      /*rewrite_subgraph_fn=*/check_subgraph,
      /*reuse_existing_functions=*/false, &graph_after, &library));
  EXPECT_EQ(2, guaranteed_consts);
}
815
// Encapsulates only mul1. One of its operands (const_guarantee_add1) is built
// purely from GuaranteeConst outputs. The other (add2) also reads placeholder
// x2 directly. Checks that exactly one "const..." _Arg is seen by the rewrite
// callback and that guaranteed-constant marking matches the name prefix.
TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
  Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
  auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
  auto const_guarantee_x1 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
  auto const_guarantee_x2 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
  auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"),
                                       const_guarantee_x1, const_guarantee_x2);
  auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2);
  auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2);
  mul1.node()->AddAttr("_encapsulate", "encapsulate1");

  Graph input_graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&input_graph));

  // Rewrite callback: an _Arg whose name starts with "const" must be marked as
  // a guaranteed constant; every other node must not be. Matching _Args are
  // counted so the test can check how many were seen.
  int const_arg_count = 0;
  auto check_guaranteed_const_args =
      [&const_arg_count](const std::vector<OutputTensor>& arg_source_tensors,
                         std::unique_ptr<Graph>* graph_ptr,
                         std::vector<int>* input_permutation,
                         std::vector<int>* output_permutation,
                         NodeDef* call_def) {
        for (const Node* node : (*graph_ptr)->nodes()) {
          const bool expect_guaranteed_const =
              node->type_string() == "_Arg" &&
              absl::StartsWith(node->name(), "const");
          if (expect_guaranteed_const) ++const_arg_count;
          EXPECT_EQ(expect_guaranteed_const, HasGuaranteeConstAttr(*node));
        }
        return OkStatus();
      };

  std::unique_ptr<Graph> graph_after;
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_encapsulate", input_graph, check_guaranteed_const_args,
      /*reuse_existing_functions=*/false, &graph_after, &library));
  // Only the argument fed by const_guarantee_add1 counts as a runtime
  // constant. add2 mixes a guaranteed-constant input with a non-constant one,
  // so it is non-constant overall.
  EXPECT_EQ(1, const_arg_count);
}
862
// Test with one function to transform and one outside_compilation cluster.
//
// Nodes C, c, E and F are tagged for function F1. E is additionally tagged as
// outside_compilation cluster O1. The expected result has these properties:
//   - Inside F1, E is replaced by the XlaHostCompute node
//     "outside_compilation_O1_host_compute".
//   - In the rewritten graph, E stays on the host, fed by a RecvAtHost node and
//     feeding a SendFromHost node.
//   - An F1_sequencer node depends on both of those transfers.
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph. XTimesTwo is a pre-existing library function that
  // is unrelated to F1; the expected library below contains it as well.
  {
    *library.add_function() = test::function::XTimesTwo();

    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    // Name nodes 'C' and 'c' (the variables c and d) so that their names
    // collide after lowercasing.
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d = Binary(b, c,
                     b1.opts().WithName("c").WithControlInput(c).WithAttr(
                         "_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference graph for O1. E's two inputs arrive through a
  // RecvAtHost node, and its output leaves through a SendFromHost node. The
  // graph is added to the expected library under "F1_F1_O1".
  // NOTE(review): presumably AddGraphDefToFunctionLibrary turns this into the
  // function named by `shape_inference_graph` below. The helper is outside
  // this view, so confirm there.
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape.opts());
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape.opts()
                         .WithName("E")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape, "F1_F1_O1", &library_expected));
  }

  // Expected function F1. E is replaced by outside_compilation_O1_host_compute,
  // which takes C and c as inputs. Its single output feeds F, which also keeps
  // a control dependency on it. The static `shapes` list is empty, and a
  // shape-inference graph is referenced instead.
  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  *library_expected.add_function() = test::function::XTimesTwo();
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "c:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"c"}},
      },
      {{"f_0_retval_retval", "F:o:0"}});

  // Expected rewritten graph. E runs on the host between RecvAtHost and
  // SendFromHost. F1_sequencer depends on both transfers. The call to F1 has
  // control inputs from the sequencer and from B. G's former control input
  // from E becomes a control input from the call.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     b2.opts()
                         .WithName("E")
                         .WithControlInputs({recv})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* send =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));

    Node* s = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
        "F1");

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call =
        b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);

    Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
990
// Test with one function to transform and two outside_compilation clusters.
//
// Function F1 contains C, D, F and I. E is in outside_compilation cluster O1.
// G and H are in cluster O2, and both consume E's output. The expected F1
// holds one XlaHostCompute node per cluster. O2's host compute lists O1's host
// compute both as a token input and as a control dependency.
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts()
                         .WithName("G")
                         .WithControlInputs({e, f})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(d, e,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference graph for O1, registered under "F1_F1_O1".
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape1.opts()
                         .WithName("E")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  // Expected shape-inference graph for O2, registered under "F1_F1_O2". G and
  // H consume E, so this graph also reproduces O1's RecvAtHost and E in
  // addition to O2's own RecvAtHost/SendFromHost pair.
  {
    GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape2.opts());
    Node* recv1 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     shape2.opts()
                         .WithName("E")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* recv2 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT, DT_FLOAT},
        shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* g = Binary(e, ops::NodeOut(recv2, 0),
                     shape2.opts()
                         .WithName("G")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(ops::NodeOut(recv2, 1), e,
                     shape2.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g, h},
                 shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape2, "F1_F1_O2", &library_expected));
  }

  // Expected function F1:
  //   - O1's host compute takes C and D and returns E's value, which feeds F.
  //   - O2's host compute takes F and D and returns two values. Output 0 is
  //     returned as g_0_retval_retval; output 1 feeds I.
  NameAttrList shape_inference_graph1, shape_inference_graph2;
  shape_inference_graph1.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  shape_inference_graph2.set_name(
      "_outside_compilation_shape_inference_F1_F1_O2");
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
          {{"I"},
           "UnaryTest",
           {"outside_compilation_O2_host_compute:outputs:1"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O2_host_compute"},
           "XlaHostCompute",
           {"F:o:0", "D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O2"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph2},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O2"},
            // O2's host compute is chained after O1's host compute.
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node",
                                       "outside_compilation_O1_host_compute"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O2_host_compute"}},
           {"F", "outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph1},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"D"}},
      },
      {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"},
       {"i_0_retval_retval", "I:o:0"}});

  // Expected rewritten graph. E, G and H run on the host with one
  // RecvAtHost/SendFromHost pair per cluster. F1_sequencer depends on all four
  // transfers. J now consumes both outputs of the call to F1.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts()
                         .WithName("E")
                         .WithControlInputs({recv1})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* send1 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));

    Node* recv2 = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* g = Binary(e, ops::NodeOut(recv2, 0),
                     b2.opts()
                         .WithName("G")
                         .WithControlInputs({recv2, e})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(ops::NodeOut(recv2, 1), e,
                     b2.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* send2 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g, h},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));

    Node* s = Sequencer(b2.opts()
                            .WithName("F1_sequencer")
                            .WithControlInputs({recv1, send1, recv2, send2}),
                        "F1");

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call =
        b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);

    Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1),
           b2.opts().WithName("J"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }
  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1201
// Test with two functions to transform, each with one outside_compilation
// cluster.
//
// F1 contains C, D and F, with E in its cluster O1. F2 contains G and I, with
// H in its cluster O1. F2's nodes read E, F and D, so the expected F1 returns
// all three and the expected call to F2 consumes all three outputs of the call
// to F1. A and B are InputShaped here. The expected XlaHostCompute nodes carry
// static `shapes` and an empty shape_inference_graph, and no separate
// shape-inference functions are added to the expected library.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = InputShaped(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
                         "_encapsulate", "F2"));
    Node* h = Binary(d, g,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F2")
                         .WithAttr("_outside", "O1"));
    Node* i =
        Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1", "F2"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Single-dimension shape of size 2 expected in the host computes' `shapes`
  // attr. NOTE(review): this presumably matches the shape InputShaped gives A
  // and B. InputShaped is defined outside this view, so confirm there.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected F1. E becomes a host compute fed by C and D. The function returns
  // E (via the host compute), F and D.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "f_0_retval_retval:float",
       "d_0_retval_retval:float"},
      {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"D"}},
      },
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"d_0_retval_retval", "D:o:0"},
       {"f_0_retval_retval", "F:o:0"}});

  // Expected F2. It receives E, F and D as arguments. H becomes a host compute
  // fed by D and G, and its output feeds I.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"},
      {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
      {
          {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
          {{"I"},
           "BinaryTest",
           {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"d_0_arg", "G:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F2_F2_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}});

  // Expected rewritten graph. Each function has its own key placeholder, host
  // transfer pair and sequencer. The call to F2 takes outputs 0..2 of the call
  // to F1 and also has a control dependency on it.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = InputShaped(b2.opts().WithName("B"));

    Node* key_constant1 =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(
        ops::NodeOut(key_constant1, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts()
                         .WithName("E")
                         .WithControlInputs({recv1})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* send1 =
        SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));
    Node* s1 = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
        "F1");

    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 =
        b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);

    Node* key_constant2 =
        KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
    Node* recv2 = RecvAtHost(
        ops::NodeOut(key_constant2, 0), "F2", "F2", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* h = Binary(recv2, ops::NodeOut(recv2, 1),
                     b2.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F2")
                         .WithAttr("_outside", "O1"));
    Node* send2 =
        SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1", {h},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));

    Node* s2 = Sequencer(
        b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
        "F2");
    NodeBuilder node_builder2("F2", "F2", lib_def.get());
    node_builder2.Input(call1)
        .Input(ops::NodeOut(call1, 1))
        .Input(ops::NodeOut(call1, 2));
    Node* call2 = b2.opts()
                      .WithControlInputs({s2, call1})
                      .FinalizeBuilder(&node_builder2);
    Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1377
// Test with two functions to transform, each with one outside_compilation
// cluster.
//
// NOTE(review): this test was previously described as having "the dependency
// between them purely from an outside_compilation edge". In the graph built
// below, however, no node of F1 consumes a node of F2 or vice versa:
//   - G (F2) reads A and B directly.
//   - F (F1) and I (F2) only meet at J, which is outside both clusters.
// Accordingly, the expected call to F2 has no dependency on the call to F1.
// Confirm the intended scenario before relying on the test name.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph. F1 = {C, D, E(O1), F}; F2 = {G, H(O1), I}.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = InputShaped(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g =
        Binary(a, b, b1.opts().WithName("G").WithAttr("_encapsulate", "F2"));
    Node* h = Unary(g, b1.opts()
                           .WithName("H")
                           .WithAttr("_encapsulate", "F2")
                           .WithAttr("_outside", "O1"));
    Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
    Binary(f, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1", "F2"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;
  // Single-dimension shape of size 2 expected in the host computes' `shapes`
  // attr. NOTE(review): this presumably matches InputShaped's output shape;
  // confirm in that helper.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected F1. E becomes a host compute fed by C and D, whose output feeds F.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"D"}},
      },
      {{"f_0_retval_retval", "F:o:0"}});

  // Expected F2. It takes A and B directly. H becomes a single-input host
  // compute fed by G, and its output feeds I.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {},
      {
          {{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}},
          {{"I"},
           "UnaryTest",
           {"outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"G:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F2_F2_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"i_0_retval_retval", "I:o:0"}});

  // Expected rewritten graph. Both calls take A and B. call2's only control
  // input is its sequencer, and J combines the two calls' outputs.
  // NOTE(review): unlike the sibling tests, the RecvAtHost/SendFromHost nodes
  // here are built without WithAttr(kXlaHasHostTransferAttrName, true).
  // Presumably this passes because the graph comparison ignores internal
  // "_"-prefixed attrs. Confirm, or add the attr for consistency.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = InputShaped(b2.opts().WithName("B"));

    Node* key_constant1 =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1",
                             {DT_FLOAT, DT_FLOAT}, b2.opts());
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts()
                         .WithName("E")
                         .WithControlInputs({recv1})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1",
                               {e}, b2.opts().WithControlInput(e));
    Node* s1 = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
        "F1");

    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 =
        b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);

    Node* key_constant2 =
        KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
    Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1",
                             {DT_FLOAT}, b2.opts());
    Node* h = Unary(recv2, b2.opts()
                               .WithName("H")
                               .WithAttr("_encapsulate", "F2")
                               .WithAttr("_outside", "O1"));
    Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1",
                               {h}, b2.opts());

    Node* s2 = Sequencer(
        b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
        "F2");
    NodeBuilder node_builder2("F2", "F2", lib_def.get());
    node_builder2.Input(a).Input(b);
    Node* call2 =
        b2.opts().WithControlInputs({s2}).FinalizeBuilder(&node_builder2);
    Binary(call1, call2, b2.opts().WithName("J"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1535
// Test with one outside_compilation cluster that has no inputs from the
// compiled subgraph.
//
// E (cluster O1) reads only A, which is not in F1. The expected F1 takes A as
// argument a_0_arg and forwards it to the host compute. The host compute's
// output feeds F as a data input, with no extra control edge.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f =
        Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Unary(f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Single-dimension shape of size 2 expected in the host compute's `shapes`
  // attr. NOTE(review): presumably matches InputShaped's output shape for A;
  // confirm in that helper.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected F1. The host compute's only input is the function argument for A.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"a_0_arg"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"f_0_retval_retval", "F:o:0"}});

  // Expected rewritten graph. On the host, E reads the single value received
  // by RecvAtHost, and its result is sent back with SendFromHost.
  // NOTE(review): unlike some sibling tests, these RecvAtHost/SendFromHost
  // nodes are built without WithAttr(kXlaHasHostTransferAttrName, true).
  // Presumably this passes because the graph comparison ignores internal
  // "_"-prefixed attrs. Confirm, or add the attr for consistency.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                             {DT_FLOAT}, b2.opts());
    Node* e = Unary(recv1, b2.opts()
                               .WithName("E")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O1"));
    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                               {e}, b2.opts());
    Node* s1 = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}),
        "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 =
        b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);

    Unary(call1, b2.opts().WithName("G"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1630
1631 // Test with one outside_compilation cluster that has no data inputs but has a
1632 // control input from the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph. C, D and F are marked for cluster F1. E is marked for
  // outside_compilation cluster O1: its only data input is A (outside F1) and
  // its only edge from the compiled subgraph is a control edge from D. F
  // consumes E inside F1, and G consumes F outside F1.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithControlInput(d)
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f =
        Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Unary(f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // The XlaHostCompute node is expected to carry a static output shape of [2]
  // and an empty shape_inference_graph. Presumably the shape is known
  // statically because A is built with InputShaped -- confirm against that
  // helper.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected F1 body: E is replaced by an XlaHostCompute node that takes
  // a_0_arg as input and whose single float output feeds F. It lists D as a
  // control dependency (the last element of its node spec).
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"a_0_arg"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes",
             absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"D"}},
      },
      {{"f_0_retval_retval", "F:o:0"}});

  // Expected host-side graph. E runs on the host between a RecvAtHost and a
  // SendFromHost for cluster O1, both fed by the F1 key placeholder. E is also
  // expected to carry a control edge from recv1 (presumably standing in for the
  // original control edge from D, which now lives inside F1). The F1 call node
  // is control-dependent on F1_sequencer, which in turn depends on both
  // host transfers.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                             {DT_FLOAT}, b2.opts());
    Node* e = Unary(recv1, b2.opts()
                               .WithName("E")
                               .WithControlInput(recv1)
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O1"));
    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                               {e}, b2.opts());
    Node* s1 = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
        "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 =
        b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);

    Unary(call1, b2.opts().WithName("G"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1728
// Test with one outside_compilation cluster that has no outputs to the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph. C, D and F are marked for cluster F1. E is marked for
  // outside_compilation cluster O1 and reads D, but its output is consumed
  // only by G, which lies outside F1; no node inside F1 depends on E.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(d, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Binary(e, f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference function for O1: RecvAtHost -> E -> SendFromHost,
  // added to the library under "F1_F1_O1". The XlaHostCompute node below
  // refers to it as "_outside_compilation_shape_inference_F1_F1_O1";
  // presumably AddGraphDefToFunctionLibrary applies that prefix -- confirm in
  // the helper.
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
                                                .WithName("E")
                                                .WithAttr("_encapsulate", "F1")
                                                .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  // Expected F1 body: the XlaHostCompute node is fed by D, and its output
  // leaves F1 as the first result, e_0_retval_retval. No static "shapes" are
  // expected; the shape-inference function is referenced instead (presumably
  // because A is built with Input rather than InputShaped).
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"}, "UnaryTest", {"D:o:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"f_0_retval_retval", "F:o:0"}});

  // Expected host-side graph. E runs on the host between RecvAtHost and
  // SendFromHost for O1. G reads both outputs of the F1 call node: output 0 is
  // E's value and output 1 is F's.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                             {DT_FLOAT}, b2.opts());
    Node* e = Unary(recv1, b2.opts()
                               .WithName("E")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O1"));
    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                               {e}, b2.opts());
    Node* s1 = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
        "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 =
        b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);

    Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1837
1838 // Test with one outside_compilation cluster that has no data outputs but has a
1839 // control output to the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph. C, D and F are marked for cluster F1. E is marked for
  // outside_compilation cluster O1 and reads D. Its data output goes only to G
  // outside F1, but F (inside F1) has a control edge from E.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(d, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
                           "_encapsulate", "F1"));
    Binary(e, f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference function for O1 (RecvAtHost -> E ->
  // SendFromHost), added to the library under "F1_F1_O1".
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
                                                .WithName("E")
                                                .WithAttr("_encapsulate", "F1")
                                                .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  // Expected F1 body: the original control edge E -> F becomes a control
  // dependency of F on the XlaHostCompute node (the fifth element of F's node
  // spec). The host-compute output also leaves F1 as e_0_retval_retval.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "UnaryTest",
           {"D:o:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"f_0_retval_retval", "F:o:0"}});

  // Expected host-side graph. Same shape as the no-outputs case, except that
  // the SendFromHost node also carries a control edge from E.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                             {DT_FLOAT}, b2.opts());
    Node* e = Unary(recv1, b2.opts()
                               .WithName("E")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O1"));
    Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
                               {e}, b2.opts().WithControlInput(e));
    Node* s1 = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
        "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 =
        b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);

    Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
1951
1952 // Test with two outside_compilation clusters that interact outside the compiled
1953 // subgraph, where the ancestor has no HostCompute Op.
TEST(EncapsulateSubgraphsTest,
     OutsideCompilationClusterDependencyNoSrcCluster) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph. C, D, F and H are marked for cluster F1. E (cluster O1) reads
  // A, which is outside F1. G (cluster O2) reads F and has a control edge from
  // E, so O2 depends on O1. H reads G, and I (outside F1) reads E and H.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Node* g = Unary(f, b1.opts()
                           .WithName("G")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O2")
                           .WithControlInput(e));
    Node* h = Unary(g, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
    Binary(e, h, b1.opts().WithName("I"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference function for O1 (RecvAtHost -> E ->
  // SendFromHost), registered as "F1_F1_O1".
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
                                                .WithName("E")
                                                .WithAttr("_encapsulate", "F1")
                                                .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  // Expected shape-inference function for O2 (RecvAtHost -> G ->
  // SendFromHost), registered as "F1_F1_O2".
  {
    GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape2.opts());
    Node* recv2 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
                   shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts()
                                                .WithName("G")
                                                .WithAttr("_encapsulate", "F1")
                                                .WithAttr("_outside", "O2"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g},
                 shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape2, "F1_F1_O2", &library_expected));
  }

  NameAttrList shape_inference_graph1;
  shape_inference_graph1.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  NameAttrList shape_inference_graph2;
  shape_inference_graph2.set_name(
      "_outside_compilation_shape_inference_F1_F1_O2");
  // Expected F1 body: one XlaHostCompute node per outside_compilation cluster.
  // The O1 -> O2 dependency shows up on the O2 host-compute node in two ways:
  // as a control dependency on the O1 host-compute node, and as an entry in
  // its _xla_token_input_nodes attr.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"}, "UnaryTest", {"D:o:0"}},
          {{"H"},
           "UnaryTest",
           {"outside_compilation_O2_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"a_0_arg"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph1},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
          {{"outside_compilation_O2_host_compute"},
           "XlaHostCompute",
           {"F:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O2"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph2},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O2"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node",
                                       "outside_compilation_O1_host_compute"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O2_host_compute"}},
           {"outside_compilation_O1_host_compute"}},
      },
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"h_0_retval_retval", "H:o:0"}});

  // Expected host-side graph. Both clusters run on the host, and the original
  // control edge E -> G is kept there. F1_sequencer depends on all four host
  // transfers. Here the F1 call node gets its control input through
  // NodeBuilder::ControlInput rather than Options::WithControlInput.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));
    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));

    Node* e = Unary(recv1, b2.opts()
                               .WithName("E")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O1"));
    Node* send1 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* recv2 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* g = Unary(recv2, b2.opts()
                               .WithName("G")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O2")
                               .WithControlInput(e));
    Node* send2 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* s1 = Sequencer(b2.opts()
                             .WithName("F1_sequencer")
                             .WithControlInputs({recv1, send1, recv2, send2}),
                         "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b).ControlInput(s1);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);

    Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
2122
2123 // Test with two outside_compilation clusters that interact outside the compiled
2124 // subgraph, where the successor has no HostCompute Op.
TEST(EncapsulateSubgraphsTest,
     OutsideCompilationClusterDependencyNoDstCluster) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph. C, D, F and H are marked for cluster F1. E (cluster O1) reads
  // D, and its output feeds F. G (cluster O2) reads A and has a control edge
  // from E; G's output is not consumed. I (outside F1) reads E and H.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(d, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    /*Node* g =*/Unary(a, b1.opts()
                              .WithName("G")
                              .WithAttr("_encapsulate", "F1")
                              .WithAttr("_outside", "O2")
                              .WithControlInput(e));
    Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
    Binary(e, h, b1.opts().WithName("I"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference function for O1 (RecvAtHost -> E ->
  // SendFromHost), registered as "F1_F1_O1". No shape-inference function is
  // expected for O2: its XlaHostCompute node below has no outputs.
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
                                                .WithName("E")
                                                .WithAttr("_encapsulate", "F1")
                                                .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  // Expected F1 body. The O2 host-compute node takes a_0_arg, has an empty
  // Toutputs and an empty shape_inference_graph, and depends on the O1
  // host-compute node in two ways: as a control dependency and through its
  // _xla_token_input_nodes attr.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "UnaryTest",
           {"outside_compilation_O1_host_compute:outputs:0"}},
          {{"H"}, "UnaryTest", {"F:o:0"}},
          {{"outside_compilation_O2_host_compute"},
           "XlaHostCompute",
           {"a_0_arg"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O2"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", NameAttrList()},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O2"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node",
                                       "outside_compilation_O1_host_compute"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O2_host_compute"}},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"h_0_retval_retval", "H:o:0"}});

  // Expected host-side graph. O1 gets a RecvAtHost/SendFromHost pair. O2 only
  // gets a RecvAtHost, since G has no consumers, so F1_sequencer depends on
  // three transfers. The E -> G control edge is kept on the host.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(recv1, b2.opts()
                               .WithName("E")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O1"));
    Node* send1 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* recv2 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    /*Node* g =*/Unary(recv2, b2.opts()
                                  .WithName("G")
                                  .WithAttr("_encapsulate", "F1")
                                  .WithAttr("_outside", "O2")
                                  .WithControlInput(e));
    Node* s1 = Sequencer(b2.opts()
                             .WithName("F1_sequencer")
                             .WithControlInputs({recv1, recv2, send1}),
                         "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b).ControlInput(s1);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);

    Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
2271
2272 // Test with two outside_compilation clusters that interact outside the compiled
2273 // subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph. C, D, F and H are marked for cluster F1. There are three
  // outside_compilation clusters:
  //   O1: E reads D, and its output feeds F.
  //   O2: G reads D and has a control edge from E. Its output is unused.
  //   O3: I reads D and E and has a control edge from G. Its output is unused.
  // J (outside F1) reads E and H.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(d, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Node* g = Unary(d, b1.opts()
                           .WithName("G")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O2")
                           .WithControlInput(e));
    Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
    /*Node* i =*/Binary(d, e,
                        b1.opts()
                            .WithName("I")
                            .WithAttr("_encapsulate", "F1")
                            .WithAttr("_outside", "O3")
                            .WithControlInput(g));
    Binary(e, h, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference function for O1 (RecvAtHost -> E ->
  // SendFromHost), registered as "F1_F1_O1". O2 and O3 have no outputs, and
  // no shape-inference functions are expected for them.
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
                                                .WithName("E")
                                                .WithAttr("_encapsulate", "F1")
                                                .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  // Expected F1 body. All three XlaHostCompute nodes take D as input. The
  // dependencies between clusters are reflected in both the control
  // dependencies and _xla_token_input_nodes:
  //   O2 depends on O1.
  //   O3 depends on O1 and O2.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
      {{{"C"}, "UnaryTest", {"a_0_arg"}},
       {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
       {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}},
       {{"H"}, "UnaryTest", {"F:o:0"}},
       {{"outside_compilation_O1_host_compute"},
        "XlaHostCompute",
        {"D:o:0"},
        {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
         {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
         {"ancestors", absl::Span<const string>({})},
         {"key", "host_compute_channel_F1_F1_O1"},
         {"send_key", ""},
         {"recv_key", ""},
         {"shape_inference_graph", shape_inference_graph},
         {"tpu_core", 0},
         {"cost_estimate_ns", 1000000},
         {"shapes", absl::Span<const TensorShapeProto>({})},
         {"_outside_compilation_subgraph", "O1"},
         {"_xla_token_input_nodes",
          absl::Span<const string>({"_xla_token_arg_node"})},
         {"_xla_original_oc_node_name",
          "outside_compilation_O1_host_compute"}}},
       {{"outside_compilation_O2_host_compute"},
        "XlaHostCompute",
        {"D:o:0"},
        {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
         {"Toutputs", absl::Span<const DataType>({})},
         {"ancestors", absl::Span<const string>({})},
         {"key", "host_compute_channel_F1_F1_O2"},
         {"send_key", ""},
         {"recv_key", ""},
         {"shape_inference_graph", NameAttrList()},
         {"tpu_core", 0},
         {"cost_estimate_ns", 1000000},
         {"shapes", absl::Span<const TensorShapeProto>({})},
         {"_outside_compilation_subgraph", "O2"},
         {"_xla_token_input_nodes",
          absl::Span<const string>(
              {"_xla_token_arg_node", "outside_compilation_O1_host_compute"})},
         {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}},
        {"outside_compilation_O1_host_compute"}},
       {{"outside_compilation_O3_host_compute"},
        "XlaHostCompute",
        {"D:o:0"},
        {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
         {"Toutputs", absl::Span<const DataType>({})},
         {"ancestors", absl::Span<const string>({})},
         {"key", "host_compute_channel_F1_F1_O3"},
         {"send_key", ""},
         {"recv_key", ""},
         {"shape_inference_graph", NameAttrList()},
         {"tpu_core", 0},
         {"cost_estimate_ns", 1000000},
         {"shapes", absl::Span<const TensorShapeProto>({})},
         {"_outside_compilation_subgraph", "O3"},
         {"_xla_token_input_nodes",
          absl::Span<const string>({"_xla_token_arg_node",
                                    "outside_compilation_O1_host_compute",
                                    "outside_compilation_O2_host_compute"})},
         {"_xla_original_oc_node_name", "outside_compilation_O3_host_compute"}},
        {"outside_compilation_O1_host_compute",
         "outside_compilation_O2_host_compute"}}},
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"h_0_retval_retval", "H:o:0"}});

  // Expected host-side graph. Each cluster receives D through its own
  // RecvAtHost, and only O1 sends a value back. On the host, I reads recv3
  // (D's value) and E directly, and the control edges E -> G and G -> I are
  // kept. F1_sequencer depends on the four host transfers.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv1 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(recv1, b2.opts()
                               .WithName("E")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O1"));
    Node* send1 =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* recv2 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* g = Unary(recv2, b2.opts()
                               .WithName("G")
                               .WithAttr("_encapsulate", "F1")
                               .WithAttr("_outside", "O2")
                               .WithControlInput(e));
    Node* recv3 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O3", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    /*Node* i =*/Binary(recv3, e,
                        b2.opts()
                            .WithName("I")
                            .WithAttr("_encapsulate", "F1")
                            .WithAttr("_outside", "O3")
                            .WithControlInput(g));
    Node* s1 = Sequencer(b2.opts()
                             .WithName("F1_sequencer")
                             .WithControlInputs({recv1, send1, recv2, recv3}),
                         "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b).ControlInput(s1);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);

    Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
2450
// Test with one outside_compilation cluster that neither reads from nor
// writes to the compiled subgraph: E consumes A and feeds G, both of which
// are outside F1.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph: C, D and F are encapsulated into F1. E is marked as
  // outside_compilation cluster O1 of F1, but its only data input (A) and its
  // only consumer (G) are both outside F1.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Binary(e, f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected host-side graph for cluster O1: RecvAtHost -> E -> SendFromHost.
  // It is added to the expected library by AddGraphDefToFunctionLibrary under
  // "F1_F1_O1". NOTE(review): presumably this is the function named by
  // shape_inference_graph below; confirm against that helper.
  // Statement order matters here: unnamed nodes are auto-named by
  // GraphDefBuilder in creation order.
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape1.opts());
    Node* recv2 =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
                                                .WithName("E")
                                                .WithAttr("_encapsulate", "F1")
                                                .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
  }

  // Expected body of F1.
  // - The compiled ops C, D and F stay in F1.
  // - E is replaced by an XlaHostCompute node that receives A's value (as
  //   a_0_arg) and whose output becomes F1's first return value.
  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"}, "UnaryTest", {"D:o:0"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"a_0_arg"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}}},
      },
      {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
       {"f_0_retval_retval", "F:o:0"}});

  // Expected outer graph.
  // - E now takes its input from a RecvAtHost and sends its result back via
  //   SendFromHost.
  // - Both transfers are tied to the F1 call through the F1_sequencer
  //   control node.
  // - G consumes the call's two outputs (e_0_retval, f_0_retval), in order.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv =
        RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
                   b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = Unary(recv, b2.opts()
                              .WithName("E")
                              .WithAttr("_encapsulate", "F1")
                              .WithAttr("_outside", "O1"));
    Node* send =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* s = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
        "F1");
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b).ControlInput(s);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);

    Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
2560
2561 // Test for shape inference of outside compilation.
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph.
  // - Only "c" is compiled into F1 (F is also in F1).
  // - E is outside_compilation cluster O1 with an unknown output shape
  //   (BinaryUnknownShape) and control edges from B and from "c".
  // - F and G depend on E both through data and control edges.
  {
    *library.add_function() = test::function::XTimesTwo();

    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    // Nodes `c` and `d` are deliberately named "C" and "c". The names are
    // distinct but collide after lowercasing: C's output enters F1 as
    // c_0_arg while the compiled node inside F1 is called "c".
    Node* c = Unary(a, b1.opts().WithName("C"));
    Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
                           "_encapsulate", "F1"));
    Node* e = BinaryUnknownShape(c, d,
                                 b1.opts()
                                     .WithName("E")
                                     .WithControlInputs({b, d})
                                     .WithAttr("_encapsulate", "F1")
                                     .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected host-side graph for cluster O1: a two-output RecvAtHost feeds E,
  // whose result goes back through SendFromHost. Statement order matters:
  // unnamed nodes are auto-named by GraphDefBuilder in creation order.
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape.opts());
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
                                 shape.opts()
                                     .WithName("E")
                                     .WithAttr("_encapsulate", "F1")
                                     .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape, "F1_F1_O1", &library_expected));
  }

  // Expected body of F1.
  // - "c" stays compiled.
  // - E becomes an XlaHostCompute fed by C's value (c_0_arg) and by "c".
  // - The original control edges become control dependencies on "c" (for
  //   the host compute) and on the host compute (for F).
  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  *library_expected.add_function() = test::function::XTimesTwo();
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {},
      {
          {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
          {{"F"},
           "BinaryTest",
           {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"c_0_arg", "c:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"send_key", ""},
            {"recv_key", ""},
            {"shape_inference_graph", shape_inference_graph},
            {"tpu_core", 0},
            {"cost_estimate_ns", 1000000},
            // "shapes" is a list of shapes. Spell the empty list with
            // TensorShapeProto, like the other XlaHostCompute expectations
            // in this file, rather than DataType.
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})},
            {"_xla_original_oc_node_name",
             "outside_compilation_O1_host_compute"}},
           {"c"}},
      },
      {{"f_0_retval_retval", "F:o:0"}});

  // Expected outer graph.
  // - C stays outside F1.
  // - The control edges that entered the cluster from B and C are expected
  //   on the F1 call node itself, alongside the sequencer.
  // - G's control edge from E now comes from the call.
  {
    auto lib_def = std::make_unique<FunctionLibraryDefinition>(
        OpRegistry::Global(), library_expected);
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));
    Node* c = Unary(a, b2.opts().WithName("C"));

    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
                                 b2.opts()
                                     .WithName("E")
                                     .WithControlInputs({recv})
                                     .WithAttr("_encapsulate", "F1")
                                     .WithAttr("_outside", "O1"));
    Node* send =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));

    Node* s = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
        "F1");

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(b).Input(c);
    Node* call =
        b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder);

    Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
2687
CreateSubgraphTouchingRefVar(const Scope & s)2688 void CreateSubgraphTouchingRefVar(const Scope& s) {
2689 Output variable =
2690 ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
2691 Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
2692 Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
2693 Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
2694
2695 Output constant =
2696 ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
2697 s.graph()->AddControlEdge(constant.node(), variable.node());
2698 }
2699
// Runs EncapsulateSubgraphsPass on the ref-variable graph and checks two
// things:
// - Every node, source and sink included, carries kXlaHasReferenceVarsAttr.
// - The attribute is true on every node other than source and sink.
TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
  Scope root = Scope::NewRootScope().ExitOnError();
  CreateSubgraphTouchingRefVar(root);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  GraphOptimizationPassWrapper wrapper;
  GraphOptimizationPassOptions options =
      wrapper.CreateGraphOptimizationPassOptions(&graph);

  EncapsulateSubgraphsPass pass;
  TF_ASSERT_OK(pass.Run(options));

  for (const Node* n : graph->nodes()) {
    // A missing attribute is a hard failure even on source/sink; only the
    // value check below exempts those two nodes.
    bool marked = false;
    TF_ASSERT_OK(GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &marked));
    const bool is_endpoint = n->IsSink() || n->IsSource();
    EXPECT_TRUE(is_endpoint || marked)
        << "All nodes apart from source and sink can access reference variable";
  }
}
2722
CreateSubgraphNotTouchingRefVar(const Scope & s)2723 void CreateSubgraphNotTouchingRefVar(const Scope& s) {
2724 Output constant =
2725 ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
2726 Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
2727 Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
2728 }
2729
// Runs EncapsulateSubgraphsPass on a graph without reference variables.
// Despite the test name, kXlaHasReferenceVarsAttr is required to be present
// on every node; only its value must be false.
TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
  Scope root = Scope::NewRootScope().ExitOnError();
  CreateSubgraphNotTouchingRefVar(root);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  GraphOptimizationPassWrapper wrapper;
  GraphOptimizationPassOptions options =
      wrapper.CreateGraphOptimizationPassOptions(&graph);

  EncapsulateSubgraphsPass pass;
  TF_ASSERT_OK(pass.Run(options));

  for (const Node* n : graph->nodes()) {
    bool marked = false;
    TF_ASSERT_OK(GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &marked));
    EXPECT_FALSE(marked) << "The graph does not have reference variables";
  }
}
2751
2752 } // namespace
2753 } // namespace tensorflow
2754