1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
22 #include "tensorflow/compiler/jit/encapsulate_util.h"
23 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
24 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/graph_to_functiondef.h"
30 #include "tensorflow/core/framework/node_def_builder.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/gtl/cleanup.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/util/dump_graph.h"
38 #include "tensorflow/stream_executor/lib/statusor.h"
39
40 namespace tensorflow {
41
42 namespace {
43
44 // Control return mapping function for outside compilation host graphs.
45 // All nodes with kXlaHasHostTransfer attribute are control outputs.
HostGraphControlRetMapping(const Node * n)46 std::optional<string> HostGraphControlRetMapping(const Node* n) {
47 if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
48 return n->name();
49 }
50 return std::nullopt;
51 }
52
53 // Add a key placeholder node to the graph. The key placeholder node will be
54 // used as input for XlaRecvAtHost/XlaSendFromHost nodes.
AddHostComputeKeyPlaceholder(const string & xla_cluster_name,Graph * g)55 StatusOr<Node*> AddHostComputeKeyPlaceholder(const string& xla_cluster_name,
56 Graph* g) {
57 NodeDef key_def;
58 NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
59 "Placeholder");
60 builder.Attr("dtype", DT_STRING);
61 builder.Attr("shape", PartialTensorShape({2}));
62 builder.Attr("_host_compute_call_node", xla_cluster_name);
63 Status s = builder.Finalize(&key_def);
64 if (!s.ok()) return s;
65
66 Node* n = g->AddNode(key_def, &s);
67 if (!s.ok()) return s;
68 return n;
69 }
70
71 // Returns if the node is a XLA computation key placeholder.
IsKeyPlaceholderNode(const Node & n)72 bool IsKeyPlaceholderNode(const Node& n) {
73 return n.type_string() == "Placeholder" &&
74 absl::EndsWith(n.name(), "_key_placeholder");
75 }
76
77 // Returns nodes with given type.
GatherNodesWithType(const Graph & g,const string & type)78 std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
79 std::vector<Node*> result;
80 for (Node* n : g.nodes()) {
81 if (n->type_string() == type) {
82 result.push_back(n);
83 }
84 }
85 return result;
86 }
87
88 // Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
GetArgDataTypes(const std::vector<Node * > & arg_nodes,std::vector<DataType> * recv_at_host_dtypes)89 Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
90 std::vector<DataType>* recv_at_host_dtypes) {
91 recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
92 for (auto* n : arg_nodes) {
93 int index;
94 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
95 DataType dtype;
96 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
97 (*recv_at_host_dtypes)[index] = dtype;
98 }
99 for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) {
100 if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
101 return errors::Internal("Cannot get datatype for input ", i);
102 }
103 }
104 return OkStatus();
105 }
106
107 // Builds XlaRecvAtHost node.
BuildRecvAtHostNode(Graph * g,const string & oc_cluster_name,const std::vector<DataType> & recv_at_host_dtypes,Node * key_placeholder)108 StatusOr<Node*> BuildRecvAtHostNode(
109 Graph* g, const string& oc_cluster_name,
110 const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
111 NodeDefBuilder recv_at_host_builder(
112 absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
113 "_XlaRecvAtHost");
114 NodeDef recv_at_host_def;
115 recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
116 // The correct device_ordinal will be inserted during replication in a
117 // subsequent rewrite.
118 AttrValue device_ordinal_value;
119 device_ordinal_value.set_placeholder("_device_ordinal");
120 recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
121 recv_at_host_builder.Attr(
122 "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
123 recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
124 recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
125 TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
126 TF_ASSIGN_OR_RETURN(Node * recv_at_host_node, g->AddNode(recv_at_host_def));
127 return recv_at_host_node;
128 }
129
// Builds an XlaRecvAtHost node and replaces all _Arg nodes in `g` with it.
// Each _Arg's consumers are rewired to the RecvAtHost output slot given by
// the _Arg's "index" attr. Fills `recv_at_host_dtypes` from the _Arg nodes.
StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use out nodes for source node, instead of traversing all
  // nodes.
  std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
  TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * recv_at_host_node,
      BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
                          key_placeholder));
  for (auto* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Record out edges and remove `n` before adding those edges to RecvAtHost.
    // This is to avoid multiple producers.
    std::vector<OutEdgeInfo> out_edge_info;
    out_edge_info.reserve(n->out_edges().size());
    for (auto edge : n->out_edges()) {
      out_edge_info.push_back(
          {edge->dst(), edge->src_output(), edge->dst_input()});
    }
    g->RemoveNode(n);
    // Reconnect the recorded consumers to RecvAtHost output `index`
    // (control edges become plain control edges from RecvAtHost).
    for (const OutEdgeInfo& edge : out_edge_info) {
      if (edge.dst_input == Graph::kControlSlot) {
        g->AddControlEdge(recv_at_host_node, edge.dst);
      } else {
        g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
      }
    }

    // Rewrite dst nodes because their input changed.
    for (int i = 0, end = out_edge_info.size(); i < end; i++) {
      const OutEdgeInfo edge = out_edge_info[i];
      if (edge.dst_input == Graph::kControlSlot) {
        continue;
      }

      Node* dst = edge.dst;
      NodeDef new_def = dst->def();
      *new_def.mutable_input(edge.dst_input) =
          absl::StrCat(recv_at_host_node->name(), ":", index);
      TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));

      // Other edges might have `dst` as dst node as well. Update those edges
      // with `dst_replace`.
      for (int j = i + 1, end = out_edge_info.size(); j < end; j++) {
        if (out_edge_info[j].dst == dst) {
          out_edge_info[j].dst = dst_replace;
        }
      }
    }
  }
  // Finally connect the key placeholder to RecvAtHost's single input.
  g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
  return recv_at_host_node;
}
187
188 // Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
GetRetDataTypes(const std::vector<Node * > & ret_nodes,std::vector<DataType> * send_from_host_dtypes)189 Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
190 std::vector<DataType>* send_from_host_dtypes) {
191 send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
192 for (auto* n : ret_nodes) {
193 int index;
194 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
195 DataType dtype;
196 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
197 (*send_from_host_dtypes)[index] = dtype;
198 }
199 for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) {
200 if ((*send_from_host_dtypes)[i] == DT_INVALID) {
201 return errors::Internal("Cannot get datatype for output ", i);
202 }
203 }
204 return OkStatus();
205 }
206
// Builds an "_XlaSendFromHost" node that sends the values feeding the given
// host-side _Retval nodes back to the XLA computation, keyed by
// `key_placeholder`.
StatusOr<Node*> BuildSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    const std::vector<Node*>& ret_nodes,
    const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) {
  NodeDefBuilder send_from_host_builder(
      absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
      "_XlaSendFromHost");
  NodeDef send_from_host_def;
  send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
  // The correct device_ordinal will be inserted during replication in a
  // subsequent rewrite.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");
  send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
  send_from_host_builder.Attr(
      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
  send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
  // Each _Retval's "index" attr selects the SendFromHost input slot that its
  // data input will feed.
  std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size());
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    const int num_dtypes = send_from_host_dtypes.size();
    if (index < 0 || index >= num_dtypes) {
      return errors::Internal("Invalid _Retval index: ", index);
    }
    // NOTE(review): this assumes each _Retval has exactly one in edge and it
    // is a data edge; a control in edge here would pass kControlSlot to
    // output_type() — confirm _Retval nodes cannot have control inputs at
    // this point.
    for (auto edge : n->in_edges()) {
      inputs[index] =
          NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
                                  edge->src()->output_type(edge->src_output())};
    }
  }
  send_from_host_builder.Input(inputs);
  // The dynamic key from the key placeholder is the final input.
  send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
  TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
  TF_ASSIGN_OR_RETURN(Node * send_from_host_node,
                      g->AddNode(send_from_host_def));
  return send_from_host_node;
}
246
// Builds an XlaSendFromHost node and replaces all _Retval nodes in `g` with
// it. Each _Retval's producer is rewired into the SendFromHost input slot
// given by the _Retval's "index" attr. Fills `send_from_host_dtypes` from the
// _Retval nodes.
StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use in nodes for sink node, instead of traversing all
  // nodes.
  std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
  TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * send_from_host_node,
      BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
                            *send_from_host_dtypes, key_placeholder));
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Redirect each _Retval's in edges to SendFromHost, then drop the _Retval.
    for (auto edge : n->in_edges()) {
      if (edge->src_output() == Graph::kControlSlot) {
        g->AddControlEdge(edge->src(), send_from_host_node);
      } else {
        g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
      }
    }
    g->RemoveNode(n);
  }
  // The key placeholder feeds the input slot after all data inputs.
  g->AddEdge(key_placeholder, 0, send_from_host_node,
             send_from_host_dtypes->size());
  return send_from_host_node;
}
275
276 // Returns input shapes (excluding key placeholder) for `send_from_host_node`
277 // if they are all fully defined; std::nullopt otherwise.
GetInferredInputShapes(int num_inputs,Node * send_from_host_node)278 std::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
279 int num_inputs, Node* send_from_host_node) {
280 std::vector<PartialTensorShape> results(num_inputs);
281 for (int i = 0; i < num_inputs; i++) {
282 const Edge* e;
283 if (!send_from_host_node->input_edge(i, &e).ok()) {
284 return std::nullopt;
285 }
286
287 std::vector<PartialTensorShape> shapes;
288 if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
289 .ok()) {
290 return std::nullopt;
291 }
292
293 const PartialTensorShape shape = shapes[e->src_output()];
294 if (!shape.IsFullyDefined()) {
295 return std::nullopt;
296 }
297
298 results[e->dst_input()] = shape;
299 }
300 return results;
301 }
302
// Canonical name of the XlaHostCompute node generated for outside compilation
// cluster `original_oc_name`.
string host_compute_node_name(const string& original_oc_name) {
  return "outside_compilation_" + original_oc_name + "_host_compute";
}
307
// Builds an XlaHostCompute NodeDef from the outside compilation call node
// `call_node`. `host_compute_core` maps outside compilation cluster name to
// its assigned TPU core; `cluster_deps` maps a cluster name to the names of
// clusters it depends on.
StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
    const Node* call_node, const std::map<string, int>& host_compute_core,
    const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
  string original_oc_name;
  TF_RETURN_IF_ERROR(GetNodeAttr(
      call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
  NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
                                      "XlaHostCompute");
  // In XlaCompiler, if XlaHostCompute node is in a function call node and that
  // function is inlined, name of the XlaHostCompute node will be changed. So
  // we cannot rely on node name; use an attribute instead.
  host_compute_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
                            host_compute_builder.node_name());

  // Copy all attributes.
  for (const auto& attr : call_node->attrs()) {
    host_compute_builder.Attr(attr.first, attr.second);
  }

  // Populate tpu_core assignment (only present when a core was assigned to
  // this cluster).
  const auto iter = host_compute_core.find(original_oc_name);
  if (iter != host_compute_core.end()) {
    int core = iter->second;
    host_compute_builder.Attr("tpu_core", core);
  }

  // Set input tokens and other outside compilation clusters that current
  // cluster depends in `kXlaTokenArgNodeName`. This is needed because when
  // outside compilation subgraphs are encapsulated and moved to host graph,
  // control/data edges between them will only be reflected in host graph.
  // From XLA's perspective, two originally dependent clusters are no longer
  // connected, which makes them look like they can be scheduled for execution
  // in arbitrary order even though in fact they must be executed in order
  // according to their host-side graph dependency. This can cause deadlock.
  // Therefore, we hint XLA what the correct ordering of these clusters should
  // be to avoid deadlocks.
  std::vector<string> xla_token_input_nodes;
  xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
  auto cluster_deps_it = cluster_deps.find(original_oc_name);
  if (cluster_deps_it != cluster_deps.end()) {
    for (const auto& dep : cluster_deps_it->second) {
      xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
    }
  }
  host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);

  // Populate inputs: each data in edge of the call node lands in the input
  // slot given by its dst_input, typed by the call node's "Tinputs" attr.
  std::vector<DataType> input_dtypes;
  TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
  std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size());
  for (auto e : call_node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    const int input_dtypes_size = input_dtypes.size();
    if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) {
      return errors::Internal("Invalid dst_input: ", e->dst_input());
    }
    inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
        e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
  }
  host_compute_builder.Input(inputs);

  NodeDef new_def;
  TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
  return new_def;
}
377
378 // Replace outside compilation function call node with XlaHostCompute node.
ReplaceOutsideCompilationCallNode(Graph * g,Node * call_node,const std::map<string,int> & host_compute_core,const absl::flat_hash_map<string,std::vector<string>> & cluster_deps)379 TF_ATTRIBUTE_NOINLINE StatusOr<Node*> ReplaceOutsideCompilationCallNode(
380 Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
381 const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
382 // Build XlaHostCompute NodeDef.
383 TF_ASSIGN_OR_RETURN(
384 NodeDef node_def,
385 BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
386 TF_ASSIGN_OR_RETURN(Node * host_compute_node,
387 ReplaceNode(g, call_node, node_def));
388 VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
389
390 return host_compute_node;
391 }
392
393 // Resets "_device_ordinal" attr to placeholder value for related nodes
394 // (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes
395 // containing XlaRecvAtHost/XlaSendFromHost).
ResetDeviceOrdinalToPlaceholderValue(Graph * g)396 Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
397 AttrValue device_ordinal_value;
398 device_ordinal_value.set_placeholder("_device_ordinal");
399 for (Node* n : g->nodes()) {
400 if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
401 continue;
402 }
403
404 if (n->type_string() == "_XlaRecvAtHost" ||
405 n->type_string() == "_XlaSendFromHost") {
406 n->ClearAttr("device_ordinal");
407 n->AddAttr("device_ordinal", device_ordinal_value);
408 } else if (n->IsIfNode()) {
409 for (const string& attr_name :
410 std::vector<string>{"then_branch", "else_branch"}) {
411 NameAttrList branch_func;
412 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
413 (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
414 n->ClearAttr(attr_name);
415 n->AddAttr(attr_name, branch_func);
416 }
417 } else if (n->IsWhileNode()) {
418 for (const string& attr_name : std::vector<string>{"cond", "body"}) {
419 NameAttrList branch_func;
420 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
421 (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
422 n->ClearAttr(attr_name);
423 n->AddAttr(attr_name, branch_func);
424 }
425 } else if (HasNodeAttr(n->def(), "_device_ordinal")) {
426 // Function call node containing outside compilation.
427 n->ClearAttr("_device_ordinal");
428 n->AddAttr("_device_ordinal", device_ordinal_value);
429 } else {
430 return errors::Internal("Unknown node marked with ",
431 kXlaHasHostTransferAttrName, ": ",
432 n->DebugString());
433 }
434 }
435 return OkStatus();
436 }
437
438 // Cheap check to tell whether FunctionDef contains a lifted argument.
HasLiftedArgs(const FunctionDef & function_def)439 bool HasLiftedArgs(const FunctionDef& function_def) {
440 return absl::c_any_of(function_def.node_def(), [](const NodeDef& node_def) {
441 return (node_def.op() == "Placeholder" &&
442 node_def.attr().find(kXlaLiftedArgOutsideCompilationAttrName) !=
443 node_def.attr().end());
444 });
445 }
446
447 // Find lifted arguments in a function body and their corresponding outside
448 // compilation nodes.
449 StatusOr<std::vector<std::pair<Node*, Node*>>>
LiftedArgsAndOutsideCompilationNodesInFunctionBody(const FunctionBody & function_body,const std::unordered_map<string,Node * > & outside_compilation_attr_to_node)450 LiftedArgsAndOutsideCompilationNodesInFunctionBody(
451 const FunctionBody& function_body,
452 const std::unordered_map<string, Node*>& outside_compilation_attr_to_node) {
453 std::vector<std::pair<Node*, Node*>>
454 lifted_arg_nodes_and_outside_compilation_nodes;
455 for (Node* n : function_body.graph->op_nodes()) {
456 string oc_cluster;
457 if (n->type_string() == "Placeholder" &&
458 GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName,
459 &oc_cluster)
460 .ok()) {
461 TF_RET_CHECK(outside_compilation_attr_to_node.find(oc_cluster) !=
462 outside_compilation_attr_to_node.end());
463 lifted_arg_nodes_and_outside_compilation_nodes.emplace_back(
464 n, outside_compilation_attr_to_node.at(oc_cluster));
465 }
466 }
467 return lifted_arg_nodes_and_outside_compilation_nodes;
468 }
469
470 // Append lifted args' types to functional control flow node's `type_attr_name`
471 // attribute.
UpdateTypesAttribute(const std::vector<std::pair<Node *,Node * >> & lifted_arg_nodes_and_outside_compilation_nodes,const string & type_attr_name,Node * n)472 StatusOr<std::vector<DataType>> UpdateTypesAttribute(
473 const std::vector<std::pair<Node*, Node*>>&
474 lifted_arg_nodes_and_outside_compilation_nodes,
475 const string& type_attr_name, Node* n) {
476 std::vector<DataType> data_types;
477 data_types.reserve(lifted_arg_nodes_and_outside_compilation_nodes.size());
478 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types));
479 for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
480 Node* outside_compilation_node = pair.second;
481 DataType data_type;
482 TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
483 outside_compilation_node->type_string() == "Placeholder");
484 if (outside_compilation_node->IsIdentity()) {
485 TF_RETURN_IF_ERROR(
486 GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
487 } else {
488 TF_RETURN_IF_ERROR(
489 GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
490 }
491 data_types.push_back(data_type);
492 }
493 n->ClearAttr(type_attr_name);
494 n->AddAttr(type_attr_name, data_types);
495
496 return data_types;
497 }
498
499 // Add edges from lifted outside compilation argument nodes to `n` in Graph `g`.
AddEdgesFromOutsideCompilationNodes(const int original_arg_count,const int arg_to_input_edge_offset,const std::vector<DataType> & data_types,const std::vector<Node * > & outside_compilation_nodes,Graph * g,Node * n)500 void AddEdgesFromOutsideCompilationNodes(
501 const int original_arg_count, const int arg_to_input_edge_offset,
502 const std::vector<DataType>& data_types,
503 const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
504 // Add edges from outside compilation nodes to While node.
505 for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
506 Node* outside_compilation_node =
507 outside_compilation_nodes[i - original_arg_count];
508 g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
509 }
510 }
511
512 // Construct _Arg that maps to lifted outside compilation argument node input.
AddOutsideCompilationInputArgToFunctionBody(const FunctionBody & function_body,const int arg_idx,const DataType & data_type)513 StatusOr<Node*> AddOutsideCompilationInputArgToFunctionBody(
514 const FunctionBody& function_body, const int arg_idx,
515 const DataType& data_type) {
516 NodeDefBuilder arg_builder(absl::StrCat("arg_", arg_idx), "_Arg");
517 arg_builder.Attr("T", data_type);
518 arg_builder.Attr("index", arg_idx);
519 NodeDef arg_def;
520 TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
521
522 TF_ASSIGN_OR_RETURN(Node * arg_node, function_body.graph->AddNode(arg_def));
523 return arg_node;
524 }
525
526 // Add _Retval node that matches newly added `arg_node` and connect `arg_node`
527 // to it.
AddMatchingRetvalNode(const FunctionBody & function_body,const int arg_idx,const DataType & data_type,Node * arg_node)528 Status AddMatchingRetvalNode(const FunctionBody& function_body,
529 const int arg_idx, const DataType& data_type,
530 Node* arg_node) {
531 NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval");
532 ret_builder.Attr("T", data_type);
533 ret_builder.Attr("index", arg_idx);
534 ret_builder.Input(arg_node->name(), 0, data_type);
535 NodeDef ret_def;
536 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
537 TF_ASSIGN_OR_RETURN(Node * ret_node, function_body.graph->AddNode(ret_def));
538 function_body.graph->AddEdge(arg_node, 0, ret_node, 0);
539
540 return OkStatus();
541 }
542
// Replaces the lifted-arg Placeholder for `arg_idx` with the newly created
// _Arg node `arg_node`: every out edge of the placeholder is re-pointed at
// `arg_node`, then the placeholder is removed from the graph. No-op when the
// placeholder does not exist (nullptr entry).
void ReplaceLiftedArgNodePlaceholderWithArg(
    const FunctionBody& function_body, const int original_arg_count,
    const int arg_idx, const std::vector<Node*>& lifted_arg_nodes,
    Node* arg_node) {
  Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count];
  // This might happen because lifted_arg_node only exists in one branch of an
  // If node, and we are handling the other branch.
  if (!lifted_arg_node) {
    return;
  }

  for (const Edge* e : lifted_arg_node->out_edges()) {
    if (e->IsControlEdge()) {
      function_body.graph->AddControlEdge(arg_node, e->dst());
    } else {
      // _Arg nodes produce a single output, so the source slot is always 0.
      function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input());
    }
  }
  function_body.graph->RemoveNode(lifted_arg_node);
}
563
564 // Adds function def to function definition library and update the function
565 // callsite operation `callsite_node` to invoke new function instead.
AddFunctionWithNewName(const std::string & new_name,const std::string & func_attr_name,const FunctionDef & function_def,NameAttrList * func_attr,Node * callsite_node,FunctionLibraryDefinition * fld)566 Status AddFunctionWithNewName(const std::string& new_name,
567 const std::string& func_attr_name,
568 const FunctionDef& function_def,
569 NameAttrList* func_attr, Node* callsite_node,
570 FunctionLibraryDefinition* fld) {
571 TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
572 func_attr->set_name(new_name);
573 callsite_node->ClearAttr(func_attr_name);
574 callsite_node->AddAttr(func_attr_name, *func_attr);
575 return OkStatus();
576 }
577
// Reconnects outside compilation lifted arguments in a functional While node
// `n` to their outside compilation tensor sources:
//   1. appends the lifted args' types to the While node's "T" attr,
//   2. adds data edges from the outside compilation nodes to the While node,
//   3. rewrites the body function: new _Arg/_Retval pairs replace the lifted
//      arg placeholders (While body args and rets must match 1:1),
//   4. rewrites the cond function with matching new _Arg nodes only.
// Both rewritten functions are added to `fld` under fresh names and the While
// node's "body"/"cond" attrs are updated to point at them.
Status PostprocessLiftedArgsForWhile(
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
    Graph* g, Node* n, FunctionLibraryDefinition* fld) {
  TF_RET_CHECK(n->IsWhileNode());

  // Check if there is any lifted args in body function.
  NameAttrList body_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "body", &body_func));
  const FunctionDef* body_function_def = fld->Find(body_func.name());
  TF_RET_CHECK(body_function_def);

  if (!HasLiftedArgs(*body_function_def)) {
    return OkStatus();
  }

  // Gather all lifted args.
  std::unique_ptr<FunctionBody> body_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_function_def,
                                             AttrSlice(&body_func.attr()), fld,
                                             &body_function_body));

  int original_arg_count = body_function_body->arg_nodes.size();

  TF_ASSIGN_OR_RETURN(
      auto lifted_arg_nodes_and_outside_compilation_nodes,
      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
          *body_function_body, outside_compilation_attr_to_node));

  // Append lifted args' types to While node's T attribute.
  TF_ASSIGN_OR_RETURN(
      std::vector<DataType> data_types,
      UpdateTypesAttribute(lifted_arg_nodes_and_outside_compilation_nodes, "T",
                           n));

  // Add edges from outside compilation nodes to While node.
  std::vector<Node*> outside_compilation_nodes;
  outside_compilation_nodes.reserve(
      lifted_arg_nodes_and_outside_compilation_nodes.size());
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(outside_compilation_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.second; });
  AddEdgesFromOutsideCompilationNodes(original_arg_count,
                                      /*arg_to_input_edge_offset=*/0,
                                      data_types, outside_compilation_nodes, g,
                                      n);

  // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg
  // nodes with the new _Arg nodes.
  std::vector<Node*> lifted_arg_nodes;
  lifted_arg_nodes.reserve(
      lifted_arg_nodes_and_outside_compilation_nodes.size());
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(lifted_arg_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.first; });
  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    TF_ASSIGN_OR_RETURN(Node * arg_node,
                        AddOutsideCompilationInputArgToFunctionBody(
                            *body_function_body, i, data_types[i]));

    // Each appended body arg must be passed through to a matching _Retval so
    // the While loop carries it unchanged.
    TF_RETURN_IF_ERROR(
        AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node));

    ReplaceLiftedArgNodePlaceholderWithArg(
        *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
  }

  const auto new_body_function_name =
      fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_body_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *body_function_body->graph, new_body_function_name,
      HostGraphControlRetMapping, &rewritten_body_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
                                            rewritten_body_function_def,
                                            &body_func, n, fld));

  // In cond_graph, just add new _Arg nodes.
  NameAttrList cond_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "cond", &cond_func));
  const FunctionDef* cond_function_def = fld->Find(cond_func.name());
  TF_RET_CHECK(cond_function_def);
  std::unique_ptr<FunctionBody> cond_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_function_def,
                                             AttrSlice(&cond_func.attr()), fld,
                                             &cond_function_body));

  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    StatusOr<Node*> arg_node_or = AddOutsideCompilationInputArgToFunctionBody(
        *cond_function_body, i, data_types[i]);
    TF_RETURN_IF_ERROR(arg_node_or.status());
  }

  const auto new_cond_function_name =
      fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_cond_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *cond_function_body->graph, new_cond_function_name,
      HostGraphControlRetMapping, &rewritten_cond_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
                                            rewritten_cond_function_def,
                                            &cond_func, n, fld));
  return OkStatus();
}
687
PostprocessLiftedArgsForIf(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)688 Status PostprocessLiftedArgsForIf(
689 const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
690 Graph* g, Node* n, FunctionLibraryDefinition* fld) {
691 TF_RET_CHECK(n->IsIfNode());
692
693 NameAttrList then_branch_func;
694 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
695 const FunctionDef* then_branch_function_def =
696 fld->Find(then_branch_func.name());
697 TF_RET_CHECK(then_branch_function_def);
698
699 NameAttrList else_branch_func;
700 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func));
701 const FunctionDef* else_branch_function_def =
702 fld->Find(else_branch_func.name());
703 TF_RET_CHECK(else_branch_function_def);
704
705 // Nothing to do if neither branch contains any lifted arguments.
706 if (!HasLiftedArgs(*then_branch_function_def) &&
707 !HasLiftedArgs(*else_branch_function_def)) {
708 return OkStatus();
709 }
710
711 std::unique_ptr<FunctionBody> then_branch_function_body;
712 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
713 *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld,
714 &then_branch_function_body));
715
716 std::unique_ptr<FunctionBody> else_branch_function_body;
717 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
718 *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld,
719 &else_branch_function_body));
720
721 // Then and else branches have same argument count and argument data types.
722 int original_arg_count = then_branch_function_body->arg_nodes.size();
723
724 TF_ASSIGN_OR_RETURN(
725 auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes,
726 LiftedArgsAndOutsideCompilationNodesInFunctionBody(
727 *then_branch_function_body, outside_compilation_attr_to_node));
728
729 TF_ASSIGN_OR_RETURN(
730 auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes,
731 LiftedArgsAndOutsideCompilationNodesInFunctionBody(
732 *else_branch_function_body, outside_compilation_attr_to_node));
733
734 // Merge lifted args from then and else branches.
735 std::vector<Node*> outside_compilation_nodes;
736 std::vector<Node*> then_branch_lifted_arg_nodes;
737 outside_compilation_nodes.reserve(
738 then_branch_lifted_arg_nodes_and_outside_compilation_nodes.size());
739 then_branch_lifted_arg_nodes.reserve(
740 then_branch_lifted_arg_nodes_and_outside_compilation_nodes.size());
741 for (const auto& pair :
742 then_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
743 outside_compilation_nodes.push_back(pair.second);
744 then_branch_lifted_arg_nodes.push_back(pair.first);
745 }
746 for (const auto& pair :
747 else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
748 if (std::find(outside_compilation_nodes.begin(),
749 outside_compilation_nodes.end(),
750 pair.second) == outside_compilation_nodes.end()) {
751 outside_compilation_nodes.push_back(pair.second);
752 // Then branch does not contain this lifted arg. Add an empty item to
753 // then_branch_lifted_arg_nodes.
754 then_branch_lifted_arg_nodes.push_back(nullptr);
755 }
756 }
757 // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes.
758 std::vector<Node*> else_branch_lifted_arg_nodes(
759 outside_compilation_nodes.size());
760 for (const auto& pair :
761 else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
762 auto iter = std::find(outside_compilation_nodes.begin(),
763 outside_compilation_nodes.end(), pair.second);
764 TF_RET_CHECK(iter != outside_compilation_nodes.end());
765 int index = iter - outside_compilation_nodes.begin();
766 else_branch_lifted_arg_nodes[index] = pair.first;
767 }
768
769 // Append lifted args' types to If node's Tin attribute.
770 std::vector<DataType> data_types;
771 data_types.reserve(outside_compilation_nodes.size());
772 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types));
773 for (Node* n : outside_compilation_nodes) {
774 data_types.push_back(n->output_type(0));
775 }
776 n->ClearAttr("Tin");
777 n->AddAttr("Tin", data_types);
778
779 // Add edges from outside compilation nodes to If node. If node's input #0
780 // is predicate input, input #1 maps to _Arg #0 of branch functions, thus
781 // arg_to_input_edge_offset is set to 1.
782 AddEdgesFromOutsideCompilationNodes(original_arg_count,
783 /*arg_to_input_edge_offset=*/1,
784 data_types, outside_compilation_nodes, g,
785 n);
786
787 for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
788 TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
789 AddOutsideCompilationInputArgToFunctionBody(
790 *then_branch_function_body, i, data_types[i]));
791
792 ReplaceLiftedArgNodePlaceholderWithArg(
793 *then_branch_function_body, original_arg_count, i,
794 then_branch_lifted_arg_nodes, then_branch_arg_node);
795
796 TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node,
797 AddOutsideCompilationInputArgToFunctionBody(
798 *else_branch_function_body, i, data_types[i]));
799
800 ReplaceLiftedArgNodePlaceholderWithArg(
801 *else_branch_function_body, original_arg_count, i,
802 else_branch_lifted_arg_nodes, else_branch_arg_node);
803 }
804
805 const auto new_then_function_name = fld->UniqueFunctionName(
806 absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
807 FunctionDef rewritten_then_branch_function_def;
808 TF_RETURN_IF_ERROR(GraphToFunctionDef(
809 *then_branch_function_body->graph, new_then_function_name,
810 HostGraphControlRetMapping, &rewritten_then_branch_function_def));
811 TF_RETURN_IF_ERROR(AddFunctionWithNewName(
812 new_then_function_name, "then_branch", rewritten_then_branch_function_def,
813 &then_branch_func, n, fld));
814
815 const auto new_else_function_name = fld->UniqueFunctionName(
816 absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
817 FunctionDef rewritten_else_branch_function_def;
818 TF_RETURN_IF_ERROR(GraphToFunctionDef(
819 *else_branch_function_body->graph, new_else_function_name,
820 HostGraphControlRetMapping, &rewritten_else_branch_function_def));
821 TF_RETURN_IF_ERROR(AddFunctionWithNewName(
822 new_else_function_name, "else_branch", rewritten_else_branch_function_def,
823 &else_branch_func, n, fld));
824 return OkStatus();
825 }
826
PostprocessLiftedArgsForCall(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)827 Status PostprocessLiftedArgsForCall(
828 const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
829 Graph* g, Node* n, FunctionLibraryDefinition* fld) {
830 const FunctionDef* fdef = fld->Find(n->type_string());
831 TF_RET_CHECK(fdef);
832
833 // Nothing to do if the function does not contain any lifted arguments.
834 if (!HasLiftedArgs(*fdef)) {
835 return OkStatus();
836 }
837
838 std::unique_ptr<FunctionBody> fbody;
839 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody));
840
841 int original_arg_count = fbody->arg_nodes.size();
842
843 TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes,
844 LiftedArgsAndOutsideCompilationNodesInFunctionBody(
845 *fbody, outside_compilation_attr_to_node));
846
847 // Append lifted args' types to call node's input data types.
848 std::vector<DataType> data_types(n->input_types().begin(),
849 n->input_types().end());
850 for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
851 Node* outside_compilation_node = pair.second;
852 DataType data_type;
853 TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
854 outside_compilation_node->type_string() == "Placeholder");
855 if (outside_compilation_node->IsIdentity()) {
856 TF_RETURN_IF_ERROR(
857 GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
858 } else {
859 TF_RETURN_IF_ERROR(
860 GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
861 }
862 data_types.push_back(data_type);
863 }
864
865 std::vector<Node*> lifted_arg_nodes;
866 lifted_arg_nodes.reserve(
867 lifted_arg_nodes_and_outside_compilation_nodes.size());
868 std::transform(
869 lifted_arg_nodes_and_outside_compilation_nodes.begin(),
870 lifted_arg_nodes_and_outside_compilation_nodes.end(),
871 std::back_inserter(lifted_arg_nodes),
872 [](const std::pair<Node*, Node*>& pair) { return pair.first; });
873 for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
874 TF_ASSIGN_OR_RETURN(
875 Node * arg_node,
876 AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));
877
878 ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i,
879 lifted_arg_nodes, arg_node);
880 }
881
882 FunctionDef rewritten_fdef;
883 TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
884 HostGraphControlRetMapping,
885 &rewritten_fdef));
886 const auto new_function_name =
887 fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
888 rewritten_fdef.mutable_signature()->set_name(new_function_name);
889 TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
890
891 // We need to recreate the node. Otherwise TF will not know n->num_inputs()
892 // has increased.
893 NodeDef node_def = n->def();
894
895 // Function name is represented via the Op's type. Reset the op type to new
896 // function def name;
897 *node_def.mutable_op() = new_function_name;
898
899 for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
900 Node* outside_compilation_node =
901 lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
902 .second;
903 node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0));
904 }
905 TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def));
906
907 // Add edges from outside compilation nodes to call node.
908 std::vector<Node*> outside_compilation_nodes;
909 outside_compilation_nodes.reserve(
910 lifted_arg_nodes_and_outside_compilation_nodes.size());
911 std::transform(
912 lifted_arg_nodes_and_outside_compilation_nodes.begin(),
913 lifted_arg_nodes_and_outside_compilation_nodes.end(),
914 std::back_inserter(outside_compilation_nodes),
915 [](const std::pair<Node*, Node*>& pair) { return pair.second; });
916 AddEdgesFromOutsideCompilationNodes(original_arg_count,
917 /*arg_to_input_edge_offset=*/0,
918 data_types, outside_compilation_nodes, g,
919 n);
920
921 return OkStatus();
922 }
923
924 // Creates a mapping from outside compilation cluster name to lifted argument
925 // placeholder.
OutsideCompilationAttrToNode(const Graph & g)926 StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
927 const Graph& g) {
928 std::unordered_map<string, Node*> outside_compilation_attr_to_node;
929 for (Node* n : g.op_nodes()) {
930 bool is_lifted_arg;
931 string outside_compilation_attr;
932 if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
933 TryGetNodeAttr(n->def(), "_xla_outside_compilation",
934 &outside_compilation_attr)) {
935 TF_RET_CHECK(is_lifted_arg);
936 TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
937 outside_compilation_attr_to_node[outside_compilation_attr] = n;
938 }
939 }
940
941 return outside_compilation_attr_to_node;
942 }
943
PostprocessLiftedArgs(Graph * g,FunctionLibraryDefinition * fld)944 Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
945 TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node,
946 OutsideCompilationAttrToNode(*g));
947
948 std::vector<Node*> call_nodes;
949 for (Node* n : g->op_nodes()) {
950 if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
951 continue;
952 }
953
954 if (n->IsWhileNode()) {
955 TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
956 outside_compilation_attr_to_node, g, n, fld));
957 }
958
959 if (n->IsIfNode()) {
960 TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
961 outside_compilation_attr_to_node, g, n, fld));
962 }
963
964 // Outside compilation host side function call will always be direct
965 // function call nodes.
966 // Function call nodes need to be handled separately because we rewrite
967 // nodes in `PostprocessLiftedArgsForCall`.
968 if (fld->Contains(n->type_string())) {
969 call_nodes.push_back(n);
970 }
971 }
972
973 for (Node* n : call_nodes) {
974 TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall(
975 outside_compilation_attr_to_node, g, n, fld));
976 }
977
978 return OkStatus();
979 }
980
981 // For an XLA computation, builds host side graph given all outside compilation
982 // graphs inside it. The host side graph contains:
983 // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and
984 // XlaSendFromHost to this sequencer node, so all outside compilation nodes
985 // will be executed *before* this sequencer).
986 // 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
987 // replace this node with compilation result node.
988 // 3) all outside compilation graphs.
Status ConstructHostGraph(
    const string& xla_cluster_name, const string& outside_compilation_attr_name,
    const std::vector<string>& outside_compilation_host_graphs,
    FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
  host_graph->reset(new Graph(fld));

  // Create sequencer node in host graph.
  NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
                                   "NoOp");
  sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name);
  NodeDef sequencer_def;
  TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
  TF_ASSIGN_OR_RETURN(Node * sequencer, (*host_graph)->AddNode(sequencer_def));

  // Create key placeholder in host graph.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));

  // For each outside compilation graph, copy them to host graph with the
  // following changes:
  // a) Use key_placeholder in host graph instead of its own.
  // b) Add control edge from host transfer nodes (XlaRecvAtHost,
  //    XlaSendFromHost, If/While nodes containing
  //    XlaRecvAtHost/XlaSendFromHost) to sequencer node.
  // c) Clear node_def.device(), so device placer won't get confused.
  for (const string& host_func : outside_compilation_host_graphs) {
    VLOG(4) << "Expanding host graph " << host_func;
    // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
    // value after we expanded all host graphs. We cannot just use placeholder
    // value here because FunctionDef instantiation does not allow placeholder
    // value for attributes.
    AttrValue device_ordinal_attr;
    device_ordinal_attr.set_i(0);
    protobuf::Map<string, AttrValue> attrs;
    attrs["_device_ordinal"] = device_ordinal_attr;
    std::unique_ptr<FunctionBody> host_fbody;
    const FunctionDef* host_fdef = fld->Find(host_func);
    TF_RET_CHECK(host_fdef);
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
                                               fld, &host_fbody));

    // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
    // reachable from sink node so all nodes will be copied.
    // TODO(b/77601805): consolidate copy graph functions.
    FixupSourceAndSinkEdges(host_fbody->graph);

    // Map from node in the per-function host graph to its copy in the merged
    // host graph. Source/sink are pre-seeded so edges to them resolve.
    std::map<const Node*, Node*> node_map;
    node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
    node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
    // The ReverseDFS visitor cannot return a status; errors inside the lambda
    // are recorded in `s` and checked after the traversal.
    Status s;
    ReverseDFS(
        *host_fbody->graph, /*enter=*/nullptr,
        [&](const Node* n) {
          if (!s.ok()) {
            return;
          }

          Node* copy;
          if (node_map.find(n) != node_map.end()) {
            // Already copied this node.
            copy = node_map.at(n);
          } else if (IsKeyPlaceholderNode(*n)) {
            // Change a).
            copy = key_placeholder;
            node_map[n] = copy;
          } else {
            // Copy the node.
            NodeDef copy_def = n->def();
            // Change c).
            copy_def.clear_device();
            copy = (*host_graph)->AddNode(copy_def, &s);
            if (!s.ok()) {
              return;
            }
            node_map[n] = copy;
          }

          // Only handle input edges. Output edges will be added later as
          // its output nodes' input edges.
          // Reverse-DFS post-order guarantees all of n's inputs were already
          // copied, so a missing image here is an internal error.
          for (auto e : n->in_edges()) {
            if (node_map.find(e->src()) == node_map.end()) {
              s = errors::Internal("Cannot find node image for ",
                                   e->src()->DebugString());
              return;
            }
            (*host_graph)
                ->AddEdge(node_map[e->src()], e->src_output(), copy,
                          e->dst_input());
          }

          // Change b).
          if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
            (*host_graph)->AddControlEdge(copy, sequencer);
          }
        },
        NodeComparatorID());

    if (!s.ok()) {
      return s;
    }
  }
  // Reset "_device_ordinal" to placeholder value.
  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(host_graph->get()));

  // sequencer and key_placeholder might be dead nodes. Prune them if necessary.
  // - sequencer should be pruned iff it has no input control edges from
  //   RecvAtHost/SendFromHost. If it has input control edge, we connect it to
  //   sink node so it won't be pruned.
  // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
  //   We don't need to do anything special.
  if (!sequencer->in_edges().empty()) {
    (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
  }
  PruneForReverseReachability(
      host_graph->get(),
      std::unordered_set<const Node*>{(*host_graph)->sink_node()});

  // Postprocess edges between different outside compilations.
  TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
      host_graph->get(), outside_compilation_attr_name));

  // Postprocess lifted arg nodes.
  TF_RETURN_IF_ERROR(PostprocessLiftedArgs(host_graph->get(), fld));

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
                                 xla_cluster_name),
                    **host_graph, fld);
  }

  return OkStatus();
}
1122
1123 // Expand XLA computation's outside compilation host side graph into main graph.
1124 // Add a control edge between sequencer node and the XLA computation node.
Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
                                    FunctionLibraryDefinition* fld,
                                    const string& host_graph_func_name,
                                    Node* xla_computation_node,
                                    Node* pivot_node) {
  // Temporarily use "0" as "_device_ordinal". It will be rewritten with the
  // correct value in a later pass. We cannot just use placeholder value here
  // because FunctionDef instantiation does not allow placeholder value for
  // attributes.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["_device_ordinal"] = device_ordinal_attr;
  std::unique_ptr<FunctionBody> fbody;
  const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
  TF_RET_CHECK(host_graph_func);
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
                                             AttrSlice(&attrs), fld, &fbody));
  Graph* host_graph = fbody->graph;

  // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
  // reachable from sink node so all nodes will be copied.
  // TODO(b/77601805): consolidate copy graph functions.
  FixupSourceAndSinkEdges(host_graph);

  // Copy all nodes.
  std::map<const Node*, Node*> node_map;
  // When a pivot node is given, the host graph's source maps to it instead of
  // the main graph's source node, anchoring the copied nodes under the pivot.
  if (pivot_node) {
    node_map[host_graph->source_node()] = pivot_node;
  } else {
    node_map[host_graph->source_node()] = main_graph->source_node();
  }
  node_map[host_graph->sink_node()] = main_graph->sink_node();
  // The ReverseDFS visitor cannot return a status; `s` carries any error out
  // of the lambda and is returned after the traversal.
  Status s = OkStatus();
  auto copy_node_fn = [&](const Node* n) {
    if (!s.ok()) {
      return;
    }

    Node* copy;
    if (node_map.find(n) != node_map.end()) {
      // Already copied this node.
      copy = node_map.at(n);
    } else {
      // Copy the node.
      NodeDef copy_def = n->def();
      copy = main_graph->AddNode(copy_def, &s);
      if (!s.ok()) {
        return;
      }
      node_map[n] = copy;
    }

    // Only handle input edges. Output edges will be added later as its output
    // nodes' input edges.
    // Post-order traversal guarantees every input's copy already exists.
    for (auto e : n->in_edges()) {
      if (node_map.find(e->src()) == node_map.end()) {
        s = errors::Internal("Cannot find node image for ",
                             e->src()->DebugString());
        return;
      }
      main_graph->AddEdge(node_map[e->src()], e->src_output(), copy,
                          e->dst_input());
    }

    // Add control edge from sequencer to XLA computation node.
    if (copy->type_string() == "NoOp" &&
        HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) {
      main_graph->AddControlEdge(copy, xla_computation_node);
    }
  };
  ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
  return s;
}
1199
1200 // Rewrites shape inference graph for outside compilation:
1201 // 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
1202 // `host_graph`. Because we might still have outside compilation to outside
1203 // compilation placeholder nodes in shape inference graph, which will prevent
1204 // us from inferring XlaSendFromHost shape. But in `host_graph`, we already
1205 // removed those placeholder nodes.
1206 // 2) Remove control edges.
1207 // 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
                                  Graph* host_graph, Node* pivot_node,
                                  FunctionLibraryDefinition* fld) {
  // Use "0" as "_device_ordinal". It does not matter for shape inference.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["_device_ordinal"] = device_ordinal_attr;
  std::unique_ptr<FunctionBody> fbody;
  const FunctionDef* shape_inference_graph =
      fld->Find(shape_inference_graph_name);
  TF_RET_CHECK(shape_inference_graph);
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
                                             AttrSlice(&attrs), fld, &fbody));
  Graph* g = fbody->graph;

  // Find SendFromHost node.
  Node* send_from_host = nullptr;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "_XlaSendFromHost") {
      send_from_host = n;
      break;
    }
  }
  if (!send_from_host) {
    return errors::Internal("Shape inference graph ",
                            shape_inference_graph_name,
                            " does not have _XlaSendFromHost node.");
  }

  // See if the SendFromHost node exists in `host_graph`. Matching is by node
  // name.
  Node* send_node_in_host_graph = nullptr;
  for (Node* n : host_graph->nodes()) {
    if (n->name() == send_from_host->name()) {
      send_node_in_host_graph = n;
      break;
    }
  }
  if (send_node_in_host_graph) {
    // This is an "top-level" outside compilation. Clear the graph, and copy
    // SendFromHost and all its predecessors from `host_graph`.
    // Snapshot op nodes first; removing nodes while iterating g->op_nodes()
    // directly would modify the collection being traversed.
    std::vector<Node*> nodes;
    nodes.reserve(g->num_op_nodes());
    for (Node* n : g->op_nodes()) {
      nodes.push_back(n);
    }
    for (Node* n : nodes) {
      g->RemoveNode(n);
    }
    Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
    // Reverse DFS from send_from_host_main_graph, and stop at start_node.
    // Iterative DFS with an explicit stack. Each node is visited twice: first
    // with is_exiting == false to push its un-copied inputs, then (post-order,
    // is_exiting == true) to copy it once all of its inputs have images.
    struct Visit {
      Node* n;
      bool is_exiting;
    };
    std::vector<Visit> stack{{send_node_in_host_graph, false}};
    std::map<Node*, Node*> node_map;
    node_map[host_graph->source_node()] = g->source_node();
    while (!stack.empty()) {
      // NOTE(review): `curr` is a reference into `stack`; in the else branch
      // below, push_back may reallocate — `curr.n` and the in_edges() range
      // are read before the first push, so no use after invalidation occurs.
      Visit& curr = stack.back();
      if (curr.is_exiting) {
        if (node_map.find(curr.n) == node_map.end()) {
          Node* copy = g->CopyNode(curr.n);
          // start_node's inputs are deliberately not copied — traversal stops
          // at the pivot/source boundary.
          if (curr.n != start_node) {
            for (const Edge* e : curr.n->in_edges()) {
              auto node_iter = node_map.find(e->src());
              if (node_iter == node_map.end()) {
                return errors::Internal("Cannot find node image for ",
                                        e->src()->DebugString());
              }
              g->AddEdge(node_iter->second, e->src_output(), copy,
                         e->dst_input());
            }
          }
          node_map[curr.n] = copy;
        }
        stack.pop_back();
      } else {
        curr.is_exiting = true;
        if (curr.n != start_node) {
          for (const Edge* e : curr.n->in_edges()) {
            if (node_map.find(e->src()) != node_map.end()) {
              continue;
            }
            stack.push_back({e->src(), false});
          }
        }
      }
    }

    send_from_host = node_map[send_node_in_host_graph];
  } else {
    // This is an outside compilation generated for If/While/gradient/etc.
    // It will be enough for shape inference. Leave `g` unchanged.
  }

  // Control edges are not useful for shape inference. Remove them.
  // NOTE(review): edges are removed while iterating g->edges(); this relies on
  // the Graph edge iteration tolerating in-place removal — confirm.
  for (auto e : g->edges()) {
    if (e->IsControlEdge()) {
      g->RemoveEdge(e);
    }
  }

  // Nodes that are not reverse reachable from SendFromHost are not useful for
  // shape inference. Prune them.
  PruneForReverseReachability(g,
                              std::unordered_set<const Node*>{send_from_host});

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(shape_inference_graph_name, *g, fld);
  }

  // Replace original shape inference graph.
  FunctionDef fdef_replace;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace));
  TF_RETURN_IF_ERROR(
      fld->ReplaceFunction(shape_inference_graph_name, fdef_replace));

  return OkStatus();
}
1329
SetMaximalSharding(NodeDefBuilder & node_builder)1330 void SetMaximalSharding(NodeDefBuilder& node_builder) {
1331 xla::OpSharding sharding;
1332 sharding.set_type(xla::OpSharding::MAXIMAL);
1333 sharding.add_tile_assignment_dimensions(1);
1334 sharding.add_tile_assignment_devices(0);
1335 node_builder.Attr("_XlaSharding", sharding.SerializeAsString());
1336 }
1337
1338 // Builds XlaSendToHost node which sends cond predicate to host.
BuildSendIfPredNode(const string & name,const string & host_transfer_key,Node * pred_node,Graph * g)1339 TF_ATTRIBUTE_NOINLINE StatusOr<Node*> BuildSendIfPredNode(
1340 const string& name, const string& host_transfer_key, Node* pred_node,
1341 Graph* g) {
1342 NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
1343 send_pred_builder.Attr("Tinput", DT_BOOL);
1344 send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
1345 send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
1346 std::vector<string>{kXlaTokenArgNodeName});
1347 send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name);
1348 SetMaximalSharding(send_pred_builder);
1349 send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
1350 NodeDef send_pred_def;
1351 TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
1352 TF_ASSIGN_OR_RETURN(Node * send_pred_node, g->AddNode(send_pred_def));
1353 g->AddEdge(pred_node, 0, send_pred_node, 0);
1354 return send_pred_node;
1355 }
1356
1357 // Replaces key placeholder node with an _Arg node.
ReplaceKeyPlaceholderWithArgNode(const string & xla_cluster_name,const string & func_name,FunctionLibraryDefinition * fld)1358 Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
1359 const string& func_name,
1360 FunctionLibraryDefinition* fld) {
1361 // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1362 // value after rewriting.
1363 AttrValue device_ordinal_attr;
1364 device_ordinal_attr.set_i(0);
1365 protobuf::Map<string, AttrValue> attrs;
1366 attrs["_device_ordinal"] = device_ordinal_attr;
1367 std::unique_ptr<FunctionBody> fbody;
1368 const FunctionDef* func = fld->Find(func_name);
1369 TF_RETURN_IF_ERROR(
1370 FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
1371 Graph* g = fbody->graph;
1372
1373 // Find or create the key placeholder node.
1374 Node* key_placeholder = nullptr;
1375 for (Node* n : g->nodes()) {
1376 if (IsKeyPlaceholderNode(*n)) {
1377 key_placeholder = n;
1378 break;
1379 }
1380 }
1381 if (!key_placeholder) {
1382 TF_ASSIGN_OR_RETURN(key_placeholder,
1383 AddHostComputeKeyPlaceholder(xla_cluster_name, g));
1384 }
1385
1386 // Build the _Arg node, and replace key placeholder node with it.
1387 NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
1388 arg_builder.Attr("T", DT_STRING);
1389 arg_builder.Attr("index", 0);
1390 NodeDef arg_def;
1391 TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
1392 TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
1393
1394 // Reset "_device_ordinal" to placeholder value.
1395 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
1396
1397 FunctionDef replace_fdef;
1398 TF_RETURN_IF_ERROR(GraphToFunctionDef(
1399 *g, func_name, HostGraphControlRetMapping, &replace_fdef));
1400 TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
1401 return OkStatus();
1402 }
1403
1404 // Builds host side graph for If node.
TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
    const string& xla_cluster_attr_name,
    const string& outside_compilation_attr_name, const string& xla_cluster_name,
    const string& if_node_name, const string& host_transfer_key,
    const string& host_graph_func_name, FunctionLibraryDefinition* fld,
    const string& then_branch_host_func_name,
    const string& else_branch_host_func_name) {
  Graph host_graph(fld);
  string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
  // "_device_ordinal" stays a placeholder here; it is propagated into the
  // branch functions' attrs and resolved later.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");

  // Step 1: add key placeholder node.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));

  // Step 2: build XlaRecvAtHost node to recv predicate.
  NodeDefBuilder recv_pred_builder(
      absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
  recv_pred_builder.Attr("key", host_transfer_key);
  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
  recv_pred_builder.Attr(outside_compilation_attr_name,
                         outside_compilation_name);
  // Marked as a host transfer node so ConstructHostGraph sequences it.
  recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
  recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
  NodeDef recv_pred_def;
  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
  TF_ASSIGN_OR_RETURN(Node * recv_pred_node, host_graph.AddNode(recv_pred_def));
  host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);

  // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
  // placeholder with an _Arg node.
  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
      xla_cluster_name, then_branch_host_func_name, fld));
  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
      xla_cluster_name, else_branch_host_func_name, fld));

  // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
  NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
  if_builder.Attr("Tcond", DT_BOOL);
  // The branches take exactly one input: the DT_STRING key (the _Arg node
  // installed in Step 3). They produce no outputs.
  if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
  if_builder.Attr("Tout", std::vector<DataType>{});
  NameAttrList host_then_branch, host_else_branch;
  host_then_branch.set_name(then_branch_host_func_name);
  (*host_then_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
  host_else_branch.set_name(else_branch_host_func_name);
  (*host_else_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
  if_builder.Attr("then_branch", host_then_branch);
  if_builder.Attr("else_branch", host_else_branch);
  if_builder.Attr(kXlaHasHostTransferAttrName, true);
  if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
  if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
  // Input #0: the predicate received at host; input #1: the key placeholder.
  if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
  std::vector<NodeDefBuilder::NodeOut> if_inputs{
      {key_placeholder->name(), 0, DT_STRING}};
  if_builder.Input(if_inputs);
  NodeDef if_def;
  TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
  TF_ASSIGN_OR_RETURN(Node * if_node, host_graph.AddNode(if_def));
  host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
  host_graph.AddEdge(key_placeholder, 0, if_node, 1);

  // Convert `host_graph` to function. Replace an existing function of the
  // same name, otherwise add a new one.
  FunctionDef oc_host_graph_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
                                        &oc_host_graph_fdef));
  if (fld->Find(host_graph_func_name)) {
    TF_RETURN_IF_ERROR(
        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
  } else {
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
  }

  return OkStatus();
}
1483
1484 // Rewrites loop cond to add a node which sends loop cond to host.
AddSendLoopPredToLoopCond(const string & cond_xla_func_name,const string & host_transfer_key,NameAttrList * loop_cond_func,FunctionLibraryDefinition * fld,Node * while_node)1485 TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
1486 const string& cond_xla_func_name, const string& host_transfer_key,
1487 NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
1488 Node* while_node) {
1489 // Instantiate the loop cond function.
1490 std::unique_ptr<FunctionBody> fbody;
1491 const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
1492 TF_RET_CHECK(loop_cond_fdef);
1493 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1494 *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
1495 Graph* g = fbody->graph;
1496
1497 // Find the _Retval node and the loop cond node.
1498 Node* ret_node = nullptr;
1499 for (Node* n : g->nodes()) {
1500 if (n->type_string() == "_Retval") {
1501 if (ret_node) {
1502 return errors::Internal("Multiple return node for loop cond function ",
1503 loop_cond_func->name(), ": ",
1504 ret_node->DebugString(), " and ",
1505 n->DebugString());
1506 } else {
1507 ret_node = n;
1508 }
1509 }
1510 }
1511 if (!ret_node) {
1512 return errors::Internal("No _Retval node for loop cond function ",
1513 loop_cond_func->name());
1514 }
1515 Node* loop_cond;
1516 TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
1517
1518 // Build the XlaSendToHost node.
1519 NodeDefBuilder send_loop_cond_builder(
1520 absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
1521 send_loop_cond_builder.Attr("Tinput", DT_BOOL);
1522 send_loop_cond_builder.Attr("key",
1523 absl::StrCat(host_transfer_key, "_dtoh_0"));
1524 send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
1525 std::vector<string>{kXlaTokenArgNodeName});
1526 send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
1527 send_loop_cond_builder.node_name());
1528 SetMaximalSharding(send_loop_cond_builder);
1529 send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
1530 NodeDef send_loop_cond_def;
1531 TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
1532 TF_ASSIGN_OR_RETURN(Node * send_loop_cond_node,
1533 g->AddNode(send_loop_cond_def));
1534 g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
1535
1536 // Replace original function if loop_cond_func already has been re-written
1537 // for outside compilation.
1538 FunctionDef replace_fdef;
1539 if (loop_cond_func->name() == cond_xla_func_name) {
1540 TF_RETURN_IF_ERROR(
1541 GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
1542 TF_RETURN_IF_ERROR(
1543 fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
1544 } else {
1545 // If original while cond function has not been modified, add a new function
1546 // with send loop predicated added and update the while node callsite
1547 // operation.
1548 const auto new_name = fld->UniqueFunctionName(
1549 absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
1550 TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
1551 TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
1552 loop_cond_func->set_name(new_name);
1553 while_node->ClearAttr("cond");
1554 while_node->AddAttr("cond", *loop_cond_func);
1555 }
1556
1557 return OkStatus();
1558 }
1559
1560 // Rewrites while loop cond function for host.
RewriteHostWhileLoopCond(const string & cond_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1561 Status RewriteHostWhileLoopCond(
1562 const string& cond_host_func_name, const string& while_node_name,
1563 const string& host_transfer_key, const string& xla_cluster_attr_name,
1564 const string& xla_cluster_name, const string& outside_compilation_attr_name,
1565 const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1566 // Replace key placeholder node with _Arg node.
1567 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1568 xla_cluster_name, cond_host_func_name, fld));
1569
1570 // Instantiate cond function.
1571 AttrValue device_ordinal_temp_value;
1572 device_ordinal_temp_value.set_i(0);
1573 protobuf::Map<string, AttrValue> attrs;
1574 attrs["_device_ordinal"] = device_ordinal_temp_value;
1575 std::unique_ptr<FunctionBody> cond_fbody;
1576 const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
1577 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
1578 fld, &cond_fbody));
1579 Graph* cond_graph = cond_fbody->graph;
1580 Node* key_arg = nullptr;
1581 for (Node* n : cond_graph->nodes()) {
1582 if (n->type_string() == "_Arg") {
1583 key_arg = n;
1584 }
1585 }
1586 if (!key_arg) {
1587 return errors::Internal(
1588 "No _Arg node found for host compute key in function ",
1589 cond_host_func_name);
1590 }
1591
1592 // Add an XlaRecvAtHost node to use as cond function return value.
1593 NodeDefBuilder recv_pred_builder(
1594 absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
1595 recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1596 recv_pred_builder.Attr("key", host_transfer_key);
1597 AttrValue device_ordinal_value;
1598 device_ordinal_value.set_placeholder("_device_ordinal");
1599 recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1600 recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1601 recv_pred_builder.Attr(outside_compilation_attr_name,
1602 outside_compilation_name);
1603 recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1604 recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
1605 NodeDef recv_pred_def;
1606 TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1607 TF_ASSIGN_OR_RETURN(Node * recv_pred_node,
1608 cond_graph->AddNode(recv_pred_def));
1609 cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
1610 NodeDefBuilder ret_builder(
1611 absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
1612 ret_builder.Attr("T", DT_BOOL);
1613 ret_builder.Attr("index", 0);
1614 ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1615 NodeDef ret_def;
1616 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1617 TF_ASSIGN_OR_RETURN(Node * ret_node, cond_graph->AddNode(ret_def));
1618 cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
1619
1620 // Reset device_ordinal to placeholder value.
1621 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
1622
1623 // Replace original function.
1624 FunctionDef cond_replace_fdef;
1625 TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_graph, cond_host_func_name,
1626 HostGraphControlRetMapping,
1627 &cond_replace_fdef));
1628 TF_RETURN_IF_ERROR(
1629 fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
1630
1631 return OkStatus();
1632 }
1633
1634 // Rewrites while loop body function for host.
RewriteHostWhileLoopBody(const string & body_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1635 Status RewriteHostWhileLoopBody(
1636 const string& body_host_func_name, const string& while_node_name,
1637 const string& host_transfer_key, const string& xla_cluster_attr_name,
1638 const string& xla_cluster_name, const string& outside_compilation_attr_name,
1639 const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1640 // Replace key placeholder node with _Arg node.
1641 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1642 xla_cluster_name, body_host_func_name, fld));
1643
1644 // Instantiate body function.
1645 AttrValue device_ordinal_temp_value;
1646 device_ordinal_temp_value.set_i(0);
1647 protobuf::Map<string, AttrValue> attrs;
1648 attrs["_device_ordinal"] = device_ordinal_temp_value;
1649 std::unique_ptr<FunctionBody> body_fbody;
1650 const FunctionDef* body_host_func = fld->Find(body_host_func_name);
1651 TF_RET_CHECK(body_host_func);
1652 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
1653 fld, &body_fbody));
1654 Graph* body_graph = body_fbody->graph;
1655 Node* key_arg = nullptr;
1656 for (Node* n : body_graph->nodes()) {
1657 if (n->type_string() == "_Arg") {
1658 key_arg = n;
1659 }
1660 }
1661 if (!key_arg) {
1662 return errors::Internal(
1663 "No _Arg node found for host compute key in function ",
1664 body_host_func_name);
1665 }
1666
1667 // Add a _Retval node to loop body.
1668 NodeDefBuilder ret_builder(
1669 absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
1670 ret_builder.Attr("T", DT_STRING);
1671 ret_builder.Attr("index", 0);
1672 ret_builder.Input(key_arg->name(), 0, DT_STRING);
1673 NodeDef ret_def;
1674 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1675 TF_ASSIGN_OR_RETURN(Node * ret_node, body_graph->AddNode(ret_def));
1676 body_graph->AddEdge(key_arg, 0, ret_node, 0);
1677
1678 // Reset device_ordinal to placeholder value.
1679 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
1680
1681 // Replace original function.
1682 FunctionDef body_replace_fdef;
1683 TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_graph, body_host_func_name,
1684 HostGraphControlRetMapping,
1685 &body_replace_fdef));
1686 TF_RETURN_IF_ERROR(
1687 fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
1688
1689 return OkStatus();
1690 }
1691
1692 // Builds host side graph for while node.
BuildHostGraphForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & while_node_name,const string & host_transfer_key,const string & host_graph_func_name,FunctionLibraryDefinition * fld,const string & cond_host_func_name,const string & body_host_func_name)1693 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
1694 const string& xla_cluster_attr_name,
1695 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1696 const string& while_node_name, const string& host_transfer_key,
1697 const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1698 const string& cond_host_func_name, const string& body_host_func_name) {
1699 Graph host_graph(fld);
1700 string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
1701
1702 // Step 1: add key placeholder node.
1703 TF_ASSIGN_OR_RETURN(
1704 Node * key_placeholder,
1705 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1706
1707 // Step 2: rewrite cond function.
1708 TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
1709 cond_host_func_name, while_node_name, host_transfer_key,
1710 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1711 outside_compilation_name, fld));
1712
1713 // Step 3: rewrite body function.
1714 TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
1715 body_host_func_name, while_node_name, host_transfer_key,
1716 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1717 outside_compilation_name, fld));
1718
1719 // Step 4: build While node.
1720 NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
1721 "While");
1722 while_builder.Attr("T", std::vector<DataType>{DT_STRING});
1723 NameAttrList func;
1724 AttrValue device_ordinal_value;
1725 device_ordinal_value.set_placeholder("_device_ordinal");
1726 (*func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1727 func.set_name(cond_host_func_name);
1728 while_builder.Attr("cond", func);
1729 func.set_name(body_host_func_name);
1730 while_builder.Attr("body", func);
1731 while_builder.Attr(kXlaHasHostTransferAttrName, true);
1732 while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1733 while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1734 // Make sure loop body of i-th iteration happens before loop cond of (i+1)-th
1735 // iteration.
1736 while_builder.Attr("parallel_iterations", 1);
1737 std::vector<NodeDefBuilder::NodeOut> while_inputs{
1738 {key_placeholder->name(), 0, DT_STRING}};
1739 while_builder.Input(while_inputs);
1740 NodeDef while_def;
1741 TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
1742 TF_ASSIGN_OR_RETURN(Node * while_node, host_graph.AddNode(while_def));
1743 host_graph.AddEdge(key_placeholder, 0, while_node, 0);
1744
1745 // Convert `host_graph` to function.
1746 FunctionDef oc_host_graph_fdef;
1747 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1748 &oc_host_graph_fdef));
1749 if (fld->Find(host_graph_func_name)) {
1750 TF_RETURN_IF_ERROR(
1751 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1752 } else {
1753 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1754 }
1755
1756 return OkStatus();
1757 }
1758
1759 // Builds host graph for func call nodes.
BuildHostGraphForFuncCallNode(const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & func_call_node_name,const string & func_call_host_func_name,const string & host_graph_func_name,FunctionLibraryDefinition * fld)1760 Status BuildHostGraphForFuncCallNode(
1761 const string& xla_cluster_attr_name, const string& xla_cluster_name,
1762 const string& outside_compilation_attr_name,
1763 const string& func_call_node_name, const string& func_call_host_func_name,
1764 const string& host_graph_func_name, FunctionLibraryDefinition* fld) {
1765 Graph host_graph(fld);
1766 AttrValue device_ordinal_value;
1767 device_ordinal_value.set_placeholder("_device_ordinal");
1768
1769 // Step 1: add key placeholder node.
1770 TF_ASSIGN_OR_RETURN(
1771 Node * key_placeholder,
1772 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1773
1774 // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg
1775 // node.
1776 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1777 xla_cluster_name, func_call_host_func_name, fld));
1778
1779 // Step 3: build a function call node with `host_func_name`, with
1780 // `key_placeholder` as input.
1781 NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
1782 func_call_host_func_name, fld);
1783 call_builder.Input(key_placeholder->name(), 0, DT_STRING);
1784 call_builder.Attr("_device_ordinal", device_ordinal_value);
1785 call_builder.Attr(kXlaHasHostTransferAttrName, true);
1786 call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1787 call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
1788 NodeDef call_def;
1789 TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
1790 TF_ASSIGN_OR_RETURN(Node * call_node, host_graph.AddNode(call_def));
1791 host_graph.AddEdge(key_placeholder, 0, call_node, 0);
1792
1793 // Convert `host_graph` to function.
1794 FunctionDef oc_host_graph_fdef;
1795 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1796 HostGraphControlRetMapping,
1797 &oc_host_graph_fdef));
1798 if (fld->Find(host_graph_func_name)) {
1799 TF_RETURN_IF_ERROR(
1800 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1801 } else {
1802 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1803 }
1804
1805 return OkStatus();
1806 }
1807
ExtractOutsideCompilationForFuncCallNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1808 TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
1809 const string& xla_cluster_attr_name,
1810 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1811 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1812 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1813 std::vector<string>* host_graphs,
1814 std::vector<string>* shape_inference_graphs,
1815 bool* has_outside_compilation) {
1816 bool func_has_outside_compilation = false;
1817 NameAttrList func;
1818 if (fld->Contains(n->type_string())) {
1819 func.set_name(n->type_string());
1820 typedef protobuf::Map<string, AttrValue> AttrMap;
1821 *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
1822 } else if (n->IsPartitionedCall()) {
1823 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
1824 } else {
1825 TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
1826 func.set_name(FunctionLibraryDefinition::kGradientOp);
1827 *func.mutable_attr() = n->def().attr();
1828 }
1829 string canonical_func_name;
1830 if (func.name() == FunctionLibraryDefinition::kGradientOp) {
1831 NameAttrList forward_func;
1832 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func));
1833 canonical_func_name = absl::StrCat("gradient_", forward_func.name());
1834 } else {
1835 canonical_func_name = func.name();
1836 }
1837 string new_func_name = absl::StrCat(canonical_func_name, "_oc");
1838 string host_func_name =
1839 absl::StrCat("oc_func_call_host_", canonical_func_name);
1840 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1841 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1842 func, new_func_name, host_func_name, host_compute_core, flr, fld,
1843 shape_inference_graphs, &func_has_outside_compilation));
1844
1845 // If the function call does not have outside compilation, nothing to do.
1846 if (!func_has_outside_compilation) {
1847 return OkStatus();
1848 }
1849
1850 *has_outside_compilation = true;
1851
1852 // Change `n` to call the new function directly.
1853 auto replace_builder =
1854 std::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
1855 std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
1856 for (const Edge* e : n->in_edges()) {
1857 if (e->IsControlEdge()) {
1858 continue;
1859 }
1860
1861 const bool input_size_check =
1862 e->dst_input() < static_cast<int>(inputs.size());
1863 TF_RET_CHECK(e->dst_input() >= 0 && input_size_check);
1864 inputs[e->dst_input()] =
1865 NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1866 e->src()->output_type(e->src_output())};
1867 }
1868 for (const auto& input : inputs) {
1869 replace_builder->Input(input);
1870 }
1871 for (const auto& attr : n->attrs()) {
1872 replace_builder->Attr(attr.first, attr.second);
1873 }
1874 auto replace_def = std::make_unique<NodeDef>();
1875 TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
1876 TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
1877 replace->AddAttr(kXlaTokenInputNodesAttrName,
1878 std::vector<string>{kXlaTokenArgNodeName});
1879 replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name());
1880
1881 // Build host side graph for the function call.
1882 string oc_host_graph_name =
1883 absl::StrCat("oc_func_host_graph_", replace->name());
1884 TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
1885 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1886 replace->name(), host_func_name, oc_host_graph_name, fld));
1887
1888 // Record the host graph.
1889 host_graphs->push_back(oc_host_graph_name);
1890
1891 return OkStatus();
1892 }
1893
ExtractOutsideCompilationForIfNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1894 Status ExtractOutsideCompilationForIfNode(
1895 const string& xla_cluster_attr_name,
1896 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1897 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1898 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1899 std::vector<string>* host_graphs,
1900 std::vector<string>* shape_inference_graphs,
1901 bool* has_outside_compilation) {
1902 // Instantiate "then_branch" and "else_branch".
1903 NameAttrList then_branch, else_branch;
1904 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
1905 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
1906
1907 // Extract outside compilation for then_branch and else_branch.
1908 bool then_branch_has_outside_compilation = false;
1909 bool else_branch_has_outside_compilation = false;
1910 string then_branch_host_func_name =
1911 absl::StrCat("oc_then_branch_host_if_", then_branch.name()),
1912 else_branch_host_func_name =
1913 absl::StrCat("oc_else_branch_host_if_", else_branch.name());
1914 string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
1915 else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
1916 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1917 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1918 then_branch, then_branch_xla_func_name, then_branch_host_func_name,
1919 host_compute_core, flr, fld, shape_inference_graphs,
1920 &then_branch_has_outside_compilation));
1921 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1922 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1923 else_branch, else_branch_xla_func_name, else_branch_host_func_name,
1924 host_compute_core, flr, fld, shape_inference_graphs,
1925 &else_branch_has_outside_compilation));
1926
1927 // If then/else branch do not have outside compilation, nothing to do.
1928 if (!then_branch_has_outside_compilation &&
1929 !else_branch_has_outside_compilation) {
1930 return OkStatus();
1931 }
1932
1933 *has_outside_compilation = true;
1934
1935 // Change If node to call the new functions.
1936 if (then_branch_has_outside_compilation) {
1937 then_branch.set_name(then_branch_xla_func_name);
1938 n->ClearAttr("then_branch");
1939 n->AddAttr("then_branch", then_branch);
1940 }
1941 if (else_branch_has_outside_compilation) {
1942 else_branch.set_name(else_branch_xla_func_name);
1943 n->ClearAttr("else_branch");
1944 n->AddAttr("else_branch", else_branch);
1945 }
1946 n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
1947
1948 string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
1949
1950 // XLA computation: add a SendToHost node to send cond predicate.
1951 Node* pred_node;
1952 TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
1953 TF_ASSIGN_OR_RETURN(
1954 Node * send_pred_node,
1955 BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
1956 host_transfer_key, pred_node, g));
1957 n->AddAttr(kXlaTokenInputNodesAttrName,
1958 std::vector<string>{send_pred_node->name()});
1959
1960 // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
1961 // visit If node after `send_pred_node`, thus the token output for
1962 // `send_pred_node` has been generated.
1963 g->AddControlEdge(send_pred_node, n);
1964
1965 // Build host side graph for the "If" node.
1966 // If then/else branch does not have outside compilation, we won't build host
1967 // graph for the branch. But here we need a host graph for both branches, so
1968 // we need to create a no-op host graph.
1969 if (!then_branch_has_outside_compilation) {
1970 std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
1971 std::vector<string> then_branch_host_graphs;
1972 TF_RETURN_IF_ERROR(ConstructHostGraph(
1973 xla_cluster_name, outside_compilation_attr_name,
1974 then_branch_host_graphs, fld, &then_branch_host_graph));
1975 FunctionDef then_branch_host_fdef;
1976 TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
1977 then_branch_host_func_name,
1978 &then_branch_host_fdef));
1979 if (fld->Find(then_branch_host_func_name)) {
1980 TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
1981 then_branch_host_fdef));
1982 } else {
1983 TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
1984 }
1985 }
1986 if (!else_branch_has_outside_compilation) {
1987 std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
1988 std::vector<string> else_branch_host_graphs;
1989 TF_RETURN_IF_ERROR(ConstructHostGraph(
1990 xla_cluster_name, outside_compilation_attr_name,
1991 else_branch_host_graphs, fld, &else_branch_host_graph));
1992 FunctionDef else_branch_host_fdef;
1993 TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
1994 else_branch_host_func_name,
1995 &else_branch_host_fdef));
1996 if (fld->Find(else_branch_host_func_name)) {
1997 TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
1998 else_branch_host_fdef));
1999 } else {
2000 TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
2001 }
2002 }
2003 string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
2004 TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
2005 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2006 n->name(), host_transfer_key, oc_host_graph_name, fld,
2007 then_branch_host_func_name, else_branch_host_func_name));
2008 host_graphs->push_back(oc_host_graph_name);
2009
2010 return OkStatus();
2011 }
2012
ExtractOutsideCompilationForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2013 Status ExtractOutsideCompilationForWhileNode(
2014 const string& xla_cluster_attr_name,
2015 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2016 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
2017 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2018 std::vector<string>* host_graphs,
2019 std::vector<string>* shape_inference_graphs,
2020 bool* has_outside_compilation) {
2021 // Instantiate "cond" and "body".
2022 NameAttrList cond, body;
2023 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
2024 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
2025
2026 // Extract outside compilation for cond and body.
2027 bool cond_has_outside_compilation = false;
2028 bool body_has_outside_compilation = false;
2029 string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()),
2030 body_host_func_name = absl::StrCat("oc_body_host_while_", body.name());
2031 string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
2032 body_xla_func_name = absl::StrCat(body.name(), "_oc");
2033 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2034 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2035 cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
2036 fld, shape_inference_graphs, &cond_has_outside_compilation));
2037 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2038 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2039 body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
2040 fld, shape_inference_graphs, &body_has_outside_compilation));
2041
2042 // If cond/body do not have outside compilation, nothing to do.
2043 if (!cond_has_outside_compilation && !body_has_outside_compilation) {
2044 return OkStatus();
2045 }
2046
2047 *has_outside_compilation = true;
2048
2049 // Change While node to call the new functions.
2050 if (cond_has_outside_compilation) {
2051 cond.set_name(cond_xla_func_name);
2052 n->ClearAttr("cond");
2053 n->AddAttr("cond", cond);
2054 }
2055 if (body_has_outside_compilation) {
2056 body.set_name(body_xla_func_name);
2057 n->ClearAttr("body");
2058 n->AddAttr("body", body);
2059 }
2060 n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
2061
2062 string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
2063
2064 // XLA computation: rewrite cond function to add a SendToHost node to send
2065 // loop predicate.
2066 TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
2067 cond_xla_func_name, host_transfer_key, &cond, fld, n));
2068 n->AddAttr(kXlaTokenInputNodesAttrName,
2069 std::vector<string>{kXlaTokenArgNodeName});
2070
2071 // Build host side graph for the "While" node.
2072 if (!cond_has_outside_compilation) {
2073 std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
2074 std::vector<string> host_graphs;
2075 TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2076 outside_compilation_attr_name,
2077 host_graphs, fld, &cond_host_graph));
2078 FunctionDef cond_host_fdef;
2079 TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
2080 &cond_host_fdef));
2081 if (fld->Find(cond_host_func_name)) {
2082 TF_RETURN_IF_ERROR(
2083 fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
2084 } else {
2085 TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
2086 }
2087 }
2088 if (!body_has_outside_compilation) {
2089 std::unique_ptr<Graph> body_host_graph(new Graph(fld));
2090 std::vector<string> host_graphs;
2091 TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2092 outside_compilation_attr_name,
2093 host_graphs, fld, &body_host_graph));
2094 FunctionDef body_host_fdef;
2095 TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
2096 &body_host_fdef));
2097 if (fld->Find(body_host_func_name)) {
2098 TF_RETURN_IF_ERROR(
2099 fld->ReplaceFunction(body_host_func_name, body_host_fdef));
2100 } else {
2101 TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
2102 }
2103 }
2104 string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
2105 TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
2106 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2107 n->name(), host_transfer_key, oc_host_graph_name, fld,
2108 cond_host_func_name, body_host_func_name));
2109 host_graphs->push_back(oc_host_graph_name);
2110
2111 return OkStatus();
2112 }
2113
ExtractOutsideCompilationForNodesWithAssociatedFunctions(Graph * g,const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2114 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2115 Graph* g, const string& xla_cluster_attr_name,
2116 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2117 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2118 FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
2119 std::vector<string>* shape_inference_graphs,
2120 bool* has_outside_compilation) {
2121 std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
2122 for (Node* n : g->nodes()) {
2123 if (n->IsIfNode()) {
2124 if_nodes.push_back(n);
2125 } else if (n->IsWhileNode()) {
2126 while_nodes.push_back(n);
2127 } else if (IsFunctionCall(*fld, *n)) {
2128 func_call_nodes.push_back(n);
2129 }
2130 }
2131
2132 for (Node* n : func_call_nodes) {
2133 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
2134 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2135 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2136 has_outside_compilation));
2137 }
2138
2139 for (Node* n : if_nodes) {
2140 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
2141 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2142 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2143 has_outside_compilation));
2144 }
2145
2146 for (Node* n : while_nodes) {
2147 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
2148 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2149 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2150 has_outside_compilation));
2151 }
2152
2153 return OkStatus();
2154 }
2155
CopyOutsideCompilationConstNodes(Graph * g,const string & outside_compilation_attr_name)2156 Status CopyOutsideCompilationConstNodes(
2157 Graph* g, const string& outside_compilation_attr_name) {
2158 for (Node* n : g->op_nodes()) {
2159 if (!n->IsConstant() ||
2160 !HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2161 continue;
2162 }
2163
2164 std::vector<const Edge*> out_edges(n->out_edges().begin(),
2165 n->out_edges().end());
2166 bool has_non_oc_output = false;
2167 for (const Edge* e : out_edges) {
2168 if (!e->IsControlEdge() &&
2169 !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2170 has_non_oc_output = true;
2171 break;
2172 }
2173 }
2174 if (!has_non_oc_output) {
2175 continue;
2176 }
2177
2178 NodeDef copy_def = n->def();
2179 copy_def.set_name(g->NewName(n->name()));
2180 copy_def.mutable_attr()->erase(outside_compilation_attr_name);
2181 TF_ASSIGN_OR_RETURN(Node * copy_node, g->AddNode(copy_def));
2182 for (const Edge* e : n->in_edges()) {
2183 if (e->IsControlEdge()) {
2184 g->AddControlEdge(e->src(), copy_node);
2185 }
2186 }
2187 for (const Edge* e : out_edges) {
2188 if (!e->IsControlEdge() &&
2189 !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2190 Node* dst = e->dst();
2191 int dst_input = e->dst_input();
2192 g->RemoveEdge(e);
2193 g->AddEdge(copy_node, 0, dst, dst_input);
2194 }
2195 }
2196 }
2197
2198 return OkStatus();
2199 }
2200
2201 } // namespace
2202
// Rewrites one encapsulated outside compilation subgraph into its host-side
// form. Invoked by EncapsulateSubgraphsInFunctions for each cluster:
// replaces _Arg/_Retval nodes with RecvAtHost/SendFromHost nodes, tags the
// remaining nodes with XLA cluster / outside compilation attributes, prunes
// dead nodes, and annotates `node_def` (the generated function call node)
// with the attributes needed to later replace it with an XlaHostCompute op.
// `input_permutation`/`output_permutation`/`arg_source_tensors` are part of
// the rewrite-fn contract but are not modified here.
Status RewriteOutsideCompilationSubgraphFn::operator()(
    const std::vector<OutputTensor>& arg_source_tensors,
    std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
    std::vector<int>* output_permutation, NodeDef* node_def) {
  // `old_name` is the outside compilation cluster name; derive a unique new
  // op/function name from the XLA cluster and new-function names.
  string old_name = node_def->op();
  string new_name =
      absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name);
  node_def->set_op(new_name);
  node_def->set_name(new_name);

  // Later we will run PruneForReverseReachability(), so make sure all original
  // nodes are reachable from sink node and won't be removed.
  FixupSourceAndSinkEdges(graph->get());

  // Step 1: create a key placeholder node. It carries the compilation key
  // that the Recv/Send nodes below use to rendezvous with the XLA cluster.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get()));

  // Step 2: build RecvAtHost node, and replace all _Arg nodes with it.
  std::vector<DataType> recv_at_host_dtypes;
  TF_ASSIGN_OR_RETURN(
      Node * recv_at_host_node,
      ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name,
                                        &recv_at_host_dtypes, key_placeholder));

  // Step 3: build SendFromHost node, and replace all _Retval nodes with it.
  std::vector<DataType> send_from_host_dtypes;
  TF_ASSIGN_OR_RETURN(
      Node * send_from_host_node,
      ReplaceRetNodesWithSendFromHostNode(
          graph->get(), new_name, &send_from_host_dtypes, key_placeholder));

  // Step 4: add XLA cluster and outside compilation attr. The key placeholder
  // is skipped: it belongs to the host graph, not to any cluster.
  for (Node* n : (*graph)->nodes()) {
    if (IsKeyPlaceholderNode(*n)) {
      continue;
    }

    n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_);
    n->AddAttr(outside_compilation_attr_name_, old_name);
  }

  // Check whether we have all input shapes for XlaSendFromHost. If we do, we
  // will set `shapes` attr for the call node; otherwise we will save the
  // shape inference graph and set `shape_inference_graph` for the call node.
  std::optional<std::vector<PartialTensorShape>> shapes =
      GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node);
  // The inferred-shapes attribute has served its purpose; drop it everywhere.
  for (Node* n : (*graph)->nodes()) {
    n->ClearAttr(kXlaInferredShapesAttrName);
  }

  // Step 5: add control edges for originally XLA <-> outside compilation
  // control edges. The marker attributes were set by the earlier edge
  // preprocessing and are consumed (cleared) here.
  for (Node* n : (*graph)->nodes()) {
    if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) {
      (*graph)->AddControlEdge(n, send_from_host_node);
      n->ClearAttr(kXlaConnectedToXlaComputationAttrName);
    }
    if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) {
      (*graph)->AddControlEdge(recv_at_host_node, n);
      n->ClearAttr(kXlaConnectedFromXlaComputationAttrName);
    }
  }

  // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune
  // them if necessary.
  // - RecvAtHost should be pruned iff it has no output data/control edges. If
  //   it has any output edge, it will be reverse reachable from sink node. We
  //   don't need to do anything special.
  // - SendFromHost should be pruned iff it has no input data/control edges. If
  //   it has input edges other than key_placeholder, we connect it to sink
  //   node so it won't be pruned.
  // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned.
  //   We don't need to do anything special.
  if (send_from_host_node->in_edges().size() > 1) {
    (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node());
  }
  PruneForReverseReachability(
      graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()});

  // Step 7: add necessary attributes to function call node, so we can replace
  // it with HostCompute node later.
  AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
  if (shapes) {
    // Output shapes fully known: record them directly and leave
    // `shape_inference_graph` empty (an unnamed NameAttrList).
    NameAttrList shape_inference_graph;
    AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
    AddNodeAttr("shapes", *shapes, node_def);
  } else {
    // Shapes unknown: name a shape inference function to be run at compile
    // time; the corresponding FunctionDef is created by the caller.
    string shape_inference_func_name =
        absl::StrCat("_outside_compilation_shape_inference_", new_name);
    NameAttrList shape_inference_graph;
    shape_inference_graph.set_name(shape_inference_func_name);
    AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
    AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def);
  }
  AddNodeAttr("ancestors", std::vector<string>{}, node_def);
  AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def);
  AddNodeAttr("Toutputs", send_from_host_dtypes, node_def);
  AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def);

  return OkStatus();
}
2306
// Extracts outside compilation from the function named by `func_name_attrs`.
// If the function body contains outside compilation clusters, they are
// encapsulated and replaced with XlaHostCompute nodes, the rewritten function
// is stored in `fld` as `new_func_name`, and the corresponding host-side graph
// is stored in `fld` as `host_graph_func_name`. Nodes with associated
// functions (If/While/calls) are processed recursively either way. Names of
// shape inference graphs created along the way are appended to
// `shape_inference_graphs`, and `has_outside_compilation` reports whether any
// outside compilation was found.
Status ExtractOutsideCompilationForFunction(
    const string& xla_cluster_attr_name,
    const string& outside_compilation_attr_name, const string& xla_cluster_name,
    const NameAttrList& func_name_attrs, const string& new_func_name,
    const string& host_graph_func_name,
    const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
    FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
    bool* has_outside_compilation) {
  // Convert the function to graph.
  const string& func_name = func_name_attrs.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle));
  Status ret_status = OkStatus();
  // Release the FLR handle on every exit path; a release failure is folded
  // into the returned status rather than dropped.
  auto cleanup_handle = gtl::MakeCleanup([&]() {
    auto s = flr->ReleaseHandle(handle);
    if (!s.ok()) {
      ret_status.Update(s);
    }
  });
  const FunctionBody* fbody = flr->GetFunctionBody(handle);

  // Check if we have outside compilation nodes.
  *has_outside_compilation = false;
  for (Node* n : fbody->graph->nodes()) {
    if (HasNodeAttr(n->def(), outside_compilation_attr_name)) {
      *has_outside_compilation = true;
      break;
    }
  }
  // We cannot early return here, because we might have outside compilation in
  // If/While function body.

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(
        absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
        *fbody->graph, fld);
  }

  std::unique_ptr<Graph> graph_out;
  std::vector<string> outside_compilation_host_graphs;
  std::vector<string> shape_inference_graphs_to_rewrite;
  if (*has_outside_compilation) {
    // Copy outside compilation Const nodes with non outside compilation users.
    TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
        fbody->graph, outside_compilation_attr_name));

    // Find dependencies between outside compilation clusters.
    TF_ASSIGN_OR_RETURN(auto cluster_deps,
                        OutsideCompilationClusterDependencies(
                            fbody->graph, outside_compilation_attr_name));

    // Preprocess edges between different outside compilations. They will be
    // restored in `ConstructHostGraph()`.
    TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
        fbody->graph, outside_compilation_attr_name));

    // Encapsulate outside_compilation cluster into function call node.
    auto rewrite_fn = std::make_unique<RewriteOutsideCompilationSubgraphFn>(
        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
        new_func_name);
    TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
        outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
        /*reuse_existing_functions=*/true, &graph_out, fld));

    // Replace outside_compilation function nodes with HostCompute ops.
    std::vector<Node*> outside_compilation_nodes;
    for (Node* n : graph_out->nodes()) {
      if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
        outside_compilation_nodes.push_back(n);
        outside_compilation_host_graphs.push_back(n->name());

        // If we could not infer shapes for XlaSendFromHost inputs statically,
        // we will set the "shape_inference_graph" attribute. In that case, copy
        // outside compilation subgraph as shape inference graph in `fld`.
        auto shape_inference_graph = std::make_unique<NameAttrList>();
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
                                       shape_inference_graph.get()));
        if (!shape_inference_graph->name().empty()) {
          shape_inference_graphs->push_back(shape_inference_graph->name());
          shape_inference_graphs_to_rewrite.push_back(
              shape_inference_graph->name());

          // The encapsulated cluster function (named after the call node)
          // doubles as the starting point for the shape inference graph.
          const FunctionDef* xla_fdef = fld->Find(n->name());
          if (!xla_fdef) {
            return errors::Internal("Cannot find XLA function ", n->name());
          }
          auto shape_inference_fdef = std::make_unique<FunctionDef>(*xla_fdef);
          shape_inference_fdef->mutable_signature()->set_name(
              shape_inference_graph->name());
          if (fld->Find(shape_inference_graph->name())) {
            TF_RETURN_IF_ERROR(fld->ReplaceFunction(
                shape_inference_graph->name(), *shape_inference_fdef));
          } else {
            TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
          }
        }
      }
    }
    std::map<string, Node*> host_compute_nodes;
    for (Node* n : outside_compilation_nodes) {
      auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
          graph_out.get(), n, host_compute_core, *cluster_deps);
      TF_RETURN_IF_ERROR(host_compute_node_or.status());
      Node* host_compute_node = host_compute_node_or.ValueOrDie();
      host_compute_nodes[host_compute_node->name()] = host_compute_node;
    }
    // For XlaHostCompute nodes with dependencies, add control edges between
    // them so XlaCompiler can handle them in correct order.
    for (const auto& iter : host_compute_nodes) {
      Node* host_compute_node = iter.second;
      std::vector<string> token_input_node_names;
      TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
                                     kXlaTokenInputNodesAttrName,
                                     &token_input_node_names));
      for (const string& node_name : token_input_node_names) {
        // kXlaTokenArgNodeName denotes the implicit token argument, not a
        // sibling HostCompute node.
        if (node_name == kXlaTokenArgNodeName) {
          continue;
        }

        auto iter = host_compute_nodes.find(node_name);
        TF_RET_CHECK(iter != host_compute_nodes.end());
        graph_out->AddControlEdge(iter->second, host_compute_node);
      }
    }
  }

  // Handle nodes with associated functions.
  Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
  TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
      g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
      host_compute_core, flr, fld, &outside_compilation_host_graphs,
      shape_inference_graphs, has_outside_compilation));

  if (*has_outside_compilation) {
    // Construct host graph.
    std::unique_ptr<Graph> host_graph;
    TF_RETURN_IF_ERROR(
        ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
                           outside_compilation_host_graphs, fld, &host_graph));
    auto host_graph_fdef = std::make_unique<FunctionDef>();
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
                                          HostGraphControlRetMapping,
                                          host_graph_fdef.get()));
    if (fld->Find(host_graph_func_name)) {
      TF_RETURN_IF_ERROR(
          fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
    } else {
      TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
    }

    // Shape inference graphs might contain Placeholder nodes for outside
    // compilation to outside compilation edges. Rewrite shape inference graphs
    // to remove such nodes.
    for (const string& shape_inference_graph :
         shape_inference_graphs_to_rewrite) {
      TF_RETURN_IF_ERROR(
          RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
                                     /*pivot_node=*/nullptr, fld));
    }

    // Remove the outside compilation graphs from function library.
    for (const string& func : outside_compilation_host_graphs) {
      TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
    }

    // Replace original function.
    auto updated_fdef = std::make_unique<FunctionDef>();
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
    // HostCompute communication is a side effect; mark the function stateful
    // so it is not pruned or constant-folded away.
    updated_fdef->mutable_signature()->set_is_stateful(true);
    // Carry over any attrs from the original function definition.
    const FunctionDef* original_fdef = fld->Find(func_name);
    if (original_fdef) {
      for (const auto& attr : original_fdef->attr()) {
        (*updated_fdef->mutable_attr())[attr.first] = attr.second;
      }
    }
    if (fld->Find(new_func_name)) {
      TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
    } else {
      TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
    }
    if (VLOG_IS_ON(4)) {
      DumpGraphToFile(
          absl::StrCat("extract_outside_compilation_for_func_after_",
                       func_name),
          *g, fld);
    }
  }

  // `ret_status` may have been updated by the handle-release cleanup above.
  return ret_status;
}
2499
ExtractOutsideCompilation(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const std::unordered_map<string,XlaClusterInfo> & clusters,Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,bool * modified)2500 Status ExtractOutsideCompilation(
2501 const string& xla_cluster_attr_name,
2502 const string& outside_compilation_attr_name,
2503 const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
2504 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2505 bool* modified) {
2506 if (VLOG_IS_ON(4)) {
2507 DumpGraphToFile("extract_outside_compilation_before", *g, fld);
2508 }
2509
2510 *modified = false;
2511 auto node_name_index = g->BuildNodeNameIndex();
2512 for (auto& iter : clusters) {
2513 string xla_cluster_name = iter.first;
2514 Node* n = iter.second.node;
2515 auto const& func_name_attrs = iter.second.func_name_attrs;
2516 auto const& host_compute_core = iter.second.host_compute_core;
2517
2518 std::vector<string> shape_inference_graphs;
2519 bool has_outside_compilation;
2520 string host_graph_func_name =
2521 absl::StrCat("oc_host_graph_", xla_cluster_name);
2522 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2523 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2524 func_name_attrs, func_name_attrs.name(), host_graph_func_name,
2525 host_compute_core, flr, fld, &shape_inference_graphs,
2526 &has_outside_compilation));
2527 *modified |= has_outside_compilation;
2528
2529 if (has_outside_compilation) {
2530 string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
2531 Node* pivot_node = node_name_index[pivot_name];
2532 TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
2533 g, fld, host_graph_func_name, n, pivot_node));
2534
2535 TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
2536
2537 for (const auto& shape_inference_graph_name : shape_inference_graphs) {
2538 TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
2539 shape_inference_graph_name, g, pivot_node, fld));
2540 }
2541 }
2542 }
2543
2544 if (VLOG_IS_ON(4)) {
2545 DumpGraphToFile("extract_outside_compilation_after", *g, fld);
2546 }
2547 return OkStatus();
2548 }
2549
2550 } // namespace tensorflow
2551