xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/extract_outside_compilation_pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
22 #include "tensorflow/compiler/jit/encapsulate_util.h"
23 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
24 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/graph_to_functiondef.h"
30 #include "tensorflow/core/framework/node_def_builder.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/gtl/cleanup.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/util/dump_graph.h"
38 #include "tensorflow/stream_executor/lib/statusor.h"
39 
40 namespace tensorflow {
41 
42 namespace {
43 
44 // Control return mapping function for outside compilation host graphs.
45 // All nodes with kXlaHasHostTransfer attribute are control outputs.
HostGraphControlRetMapping(const Node * n)46 std::optional<string> HostGraphControlRetMapping(const Node* n) {
47   if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
48     return n->name();
49   }
50   return std::nullopt;
51 }
52 
53 // Add a key placeholder node to the graph. The key placeholder node will be
54 // used as input for XlaRecvAtHost/XlaSendFromHost nodes.
AddHostComputeKeyPlaceholder(const string & xla_cluster_name,Graph * g)55 StatusOr<Node*> AddHostComputeKeyPlaceholder(const string& xla_cluster_name,
56                                              Graph* g) {
57   NodeDef key_def;
58   NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
59                          "Placeholder");
60   builder.Attr("dtype", DT_STRING);
61   builder.Attr("shape", PartialTensorShape({2}));
62   builder.Attr("_host_compute_call_node", xla_cluster_name);
63   Status s = builder.Finalize(&key_def);
64   if (!s.ok()) return s;
65 
66   Node* n = g->AddNode(key_def, &s);
67   if (!s.ok()) return s;
68   return n;
69 }
70 
71 // Returns if the node is a XLA computation key placeholder.
IsKeyPlaceholderNode(const Node & n)72 bool IsKeyPlaceholderNode(const Node& n) {
73   return n.type_string() == "Placeholder" &&
74          absl::EndsWith(n.name(), "_key_placeholder");
75 }
76 
77 // Returns nodes with given type.
GatherNodesWithType(const Graph & g,const string & type)78 std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
79   std::vector<Node*> result;
80   for (Node* n : g.nodes()) {
81     if (n->type_string() == type) {
82       result.push_back(n);
83     }
84   }
85   return result;
86 }
87 
88 // Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
GetArgDataTypes(const std::vector<Node * > & arg_nodes,std::vector<DataType> * recv_at_host_dtypes)89 Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
90                        std::vector<DataType>* recv_at_host_dtypes) {
91   recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
92   for (auto* n : arg_nodes) {
93     int index;
94     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
95     DataType dtype;
96     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
97     (*recv_at_host_dtypes)[index] = dtype;
98   }
99   for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) {
100     if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
101       return errors::Internal("Cannot get datatype for input ", i);
102     }
103   }
104   return OkStatus();
105 }
106 
107 // Builds XlaRecvAtHost node.
BuildRecvAtHostNode(Graph * g,const string & oc_cluster_name,const std::vector<DataType> & recv_at_host_dtypes,Node * key_placeholder)108 StatusOr<Node*> BuildRecvAtHostNode(
109     Graph* g, const string& oc_cluster_name,
110     const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
111   NodeDefBuilder recv_at_host_builder(
112       absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
113       "_XlaRecvAtHost");
114   NodeDef recv_at_host_def;
115   recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
116   // The correct device_ordinal will be inserted during replication in a
117   // subsequent rewrite.
118   AttrValue device_ordinal_value;
119   device_ordinal_value.set_placeholder("_device_ordinal");
120   recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
121   recv_at_host_builder.Attr(
122       "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
123   recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
124   recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
125   TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
126   TF_ASSIGN_OR_RETURN(Node * recv_at_host_node, g->AddNode(recv_at_host_def));
127   return recv_at_host_node;
128 }
129 
130 // Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
ReplaceArgNodesWithRecvAtHostNode(Graph * g,const string & oc_cluster_name,std::vector<DataType> * recv_at_host_dtypes,Node * key_placeholder)131 StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
132     Graph* g, const string& oc_cluster_name,
133     std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
134   // TODO(b/77601805): use out nodes for source node, instead of traversing all
135   // nodes.
136   std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
137   TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
138   TF_ASSIGN_OR_RETURN(
139       Node * recv_at_host_node,
140       BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
141                           key_placeholder));
142   for (auto* n : arg_nodes) {
143     int index;
144     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
145     // Record out edges and remove `n` before adding those edges to RecvAtHost.
146     // This is to avoid multiple producers.
147     std::vector<OutEdgeInfo> out_edge_info;
148     out_edge_info.reserve(n->out_edges().size());
149     for (auto edge : n->out_edges()) {
150       out_edge_info.push_back(
151           {edge->dst(), edge->src_output(), edge->dst_input()});
152     }
153     g->RemoveNode(n);
154     for (const OutEdgeInfo& edge : out_edge_info) {
155       if (edge.dst_input == Graph::kControlSlot) {
156         g->AddControlEdge(recv_at_host_node, edge.dst);
157       } else {
158         g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
159       }
160     }
161 
162     // Rewrite dst nodes because their input changed.
163     for (int i = 0, end = out_edge_info.size(); i < end; i++) {
164       const OutEdgeInfo edge = out_edge_info[i];
165       if (edge.dst_input == Graph::kControlSlot) {
166         continue;
167       }
168 
169       Node* dst = edge.dst;
170       NodeDef new_def = dst->def();
171       *new_def.mutable_input(edge.dst_input) =
172           absl::StrCat(recv_at_host_node->name(), ":", index);
173       TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));
174 
175       // Other edges might have `dst` as dst node as well. Update those edges
176       // with `dst_replace`.
177       for (int j = i + 1, end = out_edge_info.size(); j < end; j++) {
178         if (out_edge_info[j].dst == dst) {
179           out_edge_info[j].dst = dst_replace;
180         }
181       }
182     }
183   }
184   g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
185   return recv_at_host_node;
186 }
187 
188 // Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
GetRetDataTypes(const std::vector<Node * > & ret_nodes,std::vector<DataType> * send_from_host_dtypes)189 Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
190                        std::vector<DataType>* send_from_host_dtypes) {
191   send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
192   for (auto* n : ret_nodes) {
193     int index;
194     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
195     DataType dtype;
196     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
197     (*send_from_host_dtypes)[index] = dtype;
198   }
199   for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) {
200     if ((*send_from_host_dtypes)[i] == DT_INVALID) {
201       return errors::Internal("Cannot get datatype for output ", i);
202     }
203   }
204   return OkStatus();
205 }
206 
207 // Builds XlaSendFromHost node.
BuildSendFromHostNode(Graph * g,const string & oc_cluster_name,const std::vector<Node * > & ret_nodes,const std::vector<DataType> & send_from_host_dtypes,Node * key_placeholder)208 StatusOr<Node*> BuildSendFromHostNode(
209     Graph* g, const string& oc_cluster_name,
210     const std::vector<Node*>& ret_nodes,
211     const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) {
212   NodeDefBuilder send_from_host_builder(
213       absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
214       "_XlaSendFromHost");
215   NodeDef send_from_host_def;
216   send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
217   // The correct device_ordinal will be inserted during replication in a
218   // subsequent rewrite.
219   AttrValue device_ordinal_value;
220   device_ordinal_value.set_placeholder("_device_ordinal");
221   send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
222   send_from_host_builder.Attr(
223       "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
224   send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
225   std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size());
226   for (auto* n : ret_nodes) {
227     int index;
228     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
229     const int num_dtypes = send_from_host_dtypes.size();
230     if (index < 0 || index >= num_dtypes) {
231       return errors::Internal("Invalid _Retval index: ", index);
232     }
233     for (auto edge : n->in_edges()) {
234       inputs[index] =
235           NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
236                                   edge->src()->output_type(edge->src_output())};
237     }
238   }
239   send_from_host_builder.Input(inputs);
240   send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
241   TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
242   TF_ASSIGN_OR_RETURN(Node * send_from_host_node,
243                       g->AddNode(send_from_host_def));
244   return send_from_host_node;
245 }
246 
247 // Builds XlaSendFromHost node, and replaces all _Retval nodes with it.
ReplaceRetNodesWithSendFromHostNode(Graph * g,const string & oc_cluster_name,std::vector<DataType> * send_from_host_dtypes,Node * key_placeholder)248 StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
249     Graph* g, const string& oc_cluster_name,
250     std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
251   // TODO(b/77601805): use in nodes for sink node, instead of traversing all
252   // nodes.
253   std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
254   TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
255   TF_ASSIGN_OR_RETURN(
256       Node * send_from_host_node,
257       BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
258                             *send_from_host_dtypes, key_placeholder));
259   for (auto* n : ret_nodes) {
260     int index;
261     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
262     for (auto edge : n->in_edges()) {
263       if (edge->src_output() == Graph::kControlSlot) {
264         g->AddControlEdge(edge->src(), send_from_host_node);
265       } else {
266         g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
267       }
268     }
269     g->RemoveNode(n);
270   }
271   g->AddEdge(key_placeholder, 0, send_from_host_node,
272              send_from_host_dtypes->size());
273   return send_from_host_node;
274 }
275 
276 // Returns input shapes (excluding key placeholder) for `send_from_host_node`
277 // if they are all fully defined; std::nullopt otherwise.
GetInferredInputShapes(int num_inputs,Node * send_from_host_node)278 std::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
279     int num_inputs, Node* send_from_host_node) {
280   std::vector<PartialTensorShape> results(num_inputs);
281   for (int i = 0; i < num_inputs; i++) {
282     const Edge* e;
283     if (!send_from_host_node->input_edge(i, &e).ok()) {
284       return std::nullopt;
285     }
286 
287     std::vector<PartialTensorShape> shapes;
288     if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
289              .ok()) {
290       return std::nullopt;
291     }
292 
293     const PartialTensorShape shape = shapes[e->src_output()];
294     if (!shape.IsFullyDefined()) {
295       return std::nullopt;
296     }
297 
298     results[e->dst_input()] = shape;
299   }
300   return results;
301 }
302 
// Returns the name used for the XlaHostCompute node of the outside compilation
// cluster `original_oc_name`:
// "outside_compilation_<original_oc_name>_host_compute".
string host_compute_node_name(const string& original_oc_name) {
  static constexpr char kPrefix[] = "outside_compilation_";
  static constexpr char kSuffix[] = "_host_compute";
  return absl::StrCat(kPrefix, original_oc_name, kSuffix);
}
307 
308 // Builds XlaHostCompute NodeDef from the outside compilation call node.
BuildXlaHostComputeNodeDef(const Node * call_node,const std::map<string,int> & host_compute_core,const absl::flat_hash_map<string,std::vector<string>> & cluster_deps)309 StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
310     const Node* call_node, const std::map<string, int>& host_compute_core,
311     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
312   string original_oc_name;
313   TF_RETURN_IF_ERROR(GetNodeAttr(
314       call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
315   NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
316                                       "XlaHostCompute");
317   // In XlaCompiler, if XlaHostCompute node is in a function call node and that
318   // function is inlined, name of the XlaHostCompute node will be changed. So
319   // we cannot rely on node name; use an attribute instead.
320   host_compute_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
321                             host_compute_builder.node_name());
322 
323   // Copy all attributes.
324   for (const auto& attr : call_node->attrs()) {
325     host_compute_builder.Attr(attr.first, attr.second);
326   }
327 
328   // Populate tpu_core assignment.
329   const auto iter = host_compute_core.find(original_oc_name);
330   if (iter != host_compute_core.end()) {
331     int core = iter->second;
332     host_compute_builder.Attr("tpu_core", core);
333   }
334 
335   // Set input tokens and other outside compilation clusters that current
336   // cluster depends in `kXlaTokenArgNodeName`. This is needed because when
337   // outside compilation subgraphs are encapsulated and moved to host graph,
338   // control/data edges between them will only be reflected in host graph.
339   // From XLA's perspective, two originally dependent clusters are no longer
340   // connected, which makes them look like they can be scheduled for execution
341   // in arbitrary order even though in fact they must be executed in order
342   // according to their host-side graph dependency. This can cause deadlock.
343   // Therefore, we hint XLA what the correct ordering of these clusters should
344   // be to avoid deadlocks.
345   std::vector<string> xla_token_input_nodes;
346   xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
347   auto cluster_deps_it = cluster_deps.find(original_oc_name);
348   if (cluster_deps_it != cluster_deps.end()) {
349     for (const auto& dep : cluster_deps_it->second) {
350       xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
351     }
352   }
353   host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
354 
355   // Populate inputs.
356   std::vector<DataType> input_dtypes;
357   TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
358   std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size());
359   for (auto e : call_node->in_edges()) {
360     if (e->IsControlEdge()) {
361       continue;
362     }
363 
364     const int input_dtypes_size = input_dtypes.size();
365     if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) {
366       return errors::Internal("Invalid dst_input: ", e->dst_input());
367     }
368     inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
369         e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
370   }
371   host_compute_builder.Input(inputs);
372 
373   NodeDef new_def;
374   TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
375   return new_def;
376 }
377 
378 // Replace outside compilation function call node with XlaHostCompute node.
ReplaceOutsideCompilationCallNode(Graph * g,Node * call_node,const std::map<string,int> & host_compute_core,const absl::flat_hash_map<string,std::vector<string>> & cluster_deps)379 TF_ATTRIBUTE_NOINLINE StatusOr<Node*> ReplaceOutsideCompilationCallNode(
380     Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
381     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
382   // Build XlaHostCompute NodeDef.
383   TF_ASSIGN_OR_RETURN(
384       NodeDef node_def,
385       BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
386   TF_ASSIGN_OR_RETURN(Node * host_compute_node,
387                       ReplaceNode(g, call_node, node_def));
388   VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
389 
390   return host_compute_node;
391 }
392 
393 // Resets "_device_ordinal" attr to placeholder value for related nodes
394 // (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes
395 // containing XlaRecvAtHost/XlaSendFromHost).
ResetDeviceOrdinalToPlaceholderValue(Graph * g)396 Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
397   AttrValue device_ordinal_value;
398   device_ordinal_value.set_placeholder("_device_ordinal");
399   for (Node* n : g->nodes()) {
400     if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
401       continue;
402     }
403 
404     if (n->type_string() == "_XlaRecvAtHost" ||
405         n->type_string() == "_XlaSendFromHost") {
406       n->ClearAttr("device_ordinal");
407       n->AddAttr("device_ordinal", device_ordinal_value);
408     } else if (n->IsIfNode()) {
409       for (const string& attr_name :
410            std::vector<string>{"then_branch", "else_branch"}) {
411         NameAttrList branch_func;
412         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
413         (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
414         n->ClearAttr(attr_name);
415         n->AddAttr(attr_name, branch_func);
416       }
417     } else if (n->IsWhileNode()) {
418       for (const string& attr_name : std::vector<string>{"cond", "body"}) {
419         NameAttrList branch_func;
420         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
421         (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
422         n->ClearAttr(attr_name);
423         n->AddAttr(attr_name, branch_func);
424       }
425     } else if (HasNodeAttr(n->def(), "_device_ordinal")) {
426       // Function call node containing outside compilation.
427       n->ClearAttr("_device_ordinal");
428       n->AddAttr("_device_ordinal", device_ordinal_value);
429     } else {
430       return errors::Internal("Unknown node marked with ",
431                               kXlaHasHostTransferAttrName, ": ",
432                               n->DebugString());
433     }
434   }
435   return OkStatus();
436 }
437 
438 // Cheap check to tell whether FunctionDef contains a lifted argument.
HasLiftedArgs(const FunctionDef & function_def)439 bool HasLiftedArgs(const FunctionDef& function_def) {
440   return absl::c_any_of(function_def.node_def(), [](const NodeDef& node_def) {
441     return (node_def.op() == "Placeholder" &&
442             node_def.attr().find(kXlaLiftedArgOutsideCompilationAttrName) !=
443                 node_def.attr().end());
444   });
445 }
446 
447 // Find lifted arguments in a function body and their corresponding outside
448 // compilation nodes.
449 StatusOr<std::vector<std::pair<Node*, Node*>>>
LiftedArgsAndOutsideCompilationNodesInFunctionBody(const FunctionBody & function_body,const std::unordered_map<string,Node * > & outside_compilation_attr_to_node)450 LiftedArgsAndOutsideCompilationNodesInFunctionBody(
451     const FunctionBody& function_body,
452     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node) {
453   std::vector<std::pair<Node*, Node*>>
454       lifted_arg_nodes_and_outside_compilation_nodes;
455   for (Node* n : function_body.graph->op_nodes()) {
456     string oc_cluster;
457     if (n->type_string() == "Placeholder" &&
458         GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName,
459                     &oc_cluster)
460             .ok()) {
461       TF_RET_CHECK(outside_compilation_attr_to_node.find(oc_cluster) !=
462                    outside_compilation_attr_to_node.end());
463       lifted_arg_nodes_and_outside_compilation_nodes.emplace_back(
464           n, outside_compilation_attr_to_node.at(oc_cluster));
465     }
466   }
467   return lifted_arg_nodes_and_outside_compilation_nodes;
468 }
469 
470 // Append lifted args' types to functional control flow node's `type_attr_name`
471 // attribute.
UpdateTypesAttribute(const std::vector<std::pair<Node *,Node * >> & lifted_arg_nodes_and_outside_compilation_nodes,const string & type_attr_name,Node * n)472 StatusOr<std::vector<DataType>> UpdateTypesAttribute(
473     const std::vector<std::pair<Node*, Node*>>&
474         lifted_arg_nodes_and_outside_compilation_nodes,
475     const string& type_attr_name, Node* n) {
476   std::vector<DataType> data_types;
477   data_types.reserve(lifted_arg_nodes_and_outside_compilation_nodes.size());
478   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types));
479   for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
480     Node* outside_compilation_node = pair.second;
481     DataType data_type;
482     TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
483                  outside_compilation_node->type_string() == "Placeholder");
484     if (outside_compilation_node->IsIdentity()) {
485       TF_RETURN_IF_ERROR(
486           GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
487     } else {
488       TF_RETURN_IF_ERROR(
489           GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
490     }
491     data_types.push_back(data_type);
492   }
493   n->ClearAttr(type_attr_name);
494   n->AddAttr(type_attr_name, data_types);
495 
496   return data_types;
497 }
498 
499 // Add edges from lifted outside compilation argument nodes to `n` in Graph `g`.
AddEdgesFromOutsideCompilationNodes(const int original_arg_count,const int arg_to_input_edge_offset,const std::vector<DataType> & data_types,const std::vector<Node * > & outside_compilation_nodes,Graph * g,Node * n)500 void AddEdgesFromOutsideCompilationNodes(
501     const int original_arg_count, const int arg_to_input_edge_offset,
502     const std::vector<DataType>& data_types,
503     const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
504   // Add edges from outside compilation nodes to While node.
505   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
506     Node* outside_compilation_node =
507         outside_compilation_nodes[i - original_arg_count];
508     g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
509   }
510 }
511 
512 // Construct _Arg that maps to lifted outside compilation argument node input.
AddOutsideCompilationInputArgToFunctionBody(const FunctionBody & function_body,const int arg_idx,const DataType & data_type)513 StatusOr<Node*> AddOutsideCompilationInputArgToFunctionBody(
514     const FunctionBody& function_body, const int arg_idx,
515     const DataType& data_type) {
516   NodeDefBuilder arg_builder(absl::StrCat("arg_", arg_idx), "_Arg");
517   arg_builder.Attr("T", data_type);
518   arg_builder.Attr("index", arg_idx);
519   NodeDef arg_def;
520   TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
521 
522   TF_ASSIGN_OR_RETURN(Node * arg_node, function_body.graph->AddNode(arg_def));
523   return arg_node;
524 }
525 
526 // Add _Retval node that matches newly added `arg_node` and connect `arg_node`
527 // to it.
AddMatchingRetvalNode(const FunctionBody & function_body,const int arg_idx,const DataType & data_type,Node * arg_node)528 Status AddMatchingRetvalNode(const FunctionBody& function_body,
529                              const int arg_idx, const DataType& data_type,
530                              Node* arg_node) {
531   NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval");
532   ret_builder.Attr("T", data_type);
533   ret_builder.Attr("index", arg_idx);
534   ret_builder.Input(arg_node->name(), 0, data_type);
535   NodeDef ret_def;
536   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
537   TF_ASSIGN_OR_RETURN(Node * ret_node, function_body.graph->AddNode(ret_def));
538   function_body.graph->AddEdge(arg_node, 0, ret_node, 0);
539 
540   return OkStatus();
541 }
542 
ReplaceLiftedArgNodePlaceholderWithArg(const FunctionBody & function_body,const int original_arg_count,const int arg_idx,const std::vector<Node * > & lifted_arg_nodes,Node * arg_node)543 void ReplaceLiftedArgNodePlaceholderWithArg(
544     const FunctionBody& function_body, const int original_arg_count,
545     const int arg_idx, const std::vector<Node*>& lifted_arg_nodes,
546     Node* arg_node) {
547   Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count];
548   // This might happen because lifted_arg_node only exists in one branch of an
549   // If node, and we are handling the other branch.
550   if (!lifted_arg_node) {
551     return;
552   }
553 
554   for (const Edge* e : lifted_arg_node->out_edges()) {
555     if (e->IsControlEdge()) {
556       function_body.graph->AddControlEdge(arg_node, e->dst());
557     } else {
558       function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input());
559     }
560   }
561   function_body.graph->RemoveNode(lifted_arg_node);
562 }
563 
564 // Adds function def to function definition library and update the function
565 // callsite operation `callsite_node` to invoke new function instead.
AddFunctionWithNewName(const std::string & new_name,const std::string & func_attr_name,const FunctionDef & function_def,NameAttrList * func_attr,Node * callsite_node,FunctionLibraryDefinition * fld)566 Status AddFunctionWithNewName(const std::string& new_name,
567                               const std::string& func_attr_name,
568                               const FunctionDef& function_def,
569                               NameAttrList* func_attr, Node* callsite_node,
570                               FunctionLibraryDefinition* fld) {
571   TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
572   func_attr->set_name(new_name);
573   callsite_node->ClearAttr(func_attr_name);
574   callsite_node->AddAttr(func_attr_name, *func_attr);
575   return OkStatus();
576 }
577 
578 // Reconnect outside compilation lifted arguments in a functional While node to
579 // its outside compilation tensor sources.
PostprocessLiftedArgsForWhile(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)580 Status PostprocessLiftedArgsForWhile(
581     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
582     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
583   TF_RET_CHECK(n->IsWhileNode());
584 
585   // Check if there is any lifted args in body function.
586   NameAttrList body_func;
587   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "body", &body_func));
588   const FunctionDef* body_function_def = fld->Find(body_func.name());
589   TF_RET_CHECK(body_function_def);
590 
591   if (!HasLiftedArgs(*body_function_def)) {
592     return OkStatus();
593   }
594 
595   // Gather all lifted args.
596   std::unique_ptr<FunctionBody> body_function_body;
597   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_function_def,
598                                              AttrSlice(&body_func.attr()), fld,
599                                              &body_function_body));
600 
601   int original_arg_count = body_function_body->arg_nodes.size();
602 
603   TF_ASSIGN_OR_RETURN(
604       auto lifted_arg_nodes_and_outside_compilation_nodes,
605       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
606           *body_function_body, outside_compilation_attr_to_node));
607 
608   // Append lifted args' types to While node's T attribute.
609   TF_ASSIGN_OR_RETURN(
610       std::vector<DataType> data_types,
611       UpdateTypesAttribute(lifted_arg_nodes_and_outside_compilation_nodes, "T",
612                            n));
613 
614   // Add edges from outside compilation nodes to While node.
615   std::vector<Node*> outside_compilation_nodes;
616   outside_compilation_nodes.reserve(
617       lifted_arg_nodes_and_outside_compilation_nodes.size());
618   std::transform(
619       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
620       lifted_arg_nodes_and_outside_compilation_nodes.end(),
621       std::back_inserter(outside_compilation_nodes),
622       [](const std::pair<Node*, Node*>& pair) { return pair.second; });
623   AddEdgesFromOutsideCompilationNodes(original_arg_count,
624                                       /*arg_to_input_edge_offset=*/0,
625                                       data_types, outside_compilation_nodes, g,
626                                       n);
627 
628   // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg
629   // nodes with the new _Arg nodes.
630   std::vector<Node*> lifted_arg_nodes;
631   lifted_arg_nodes.reserve(
632       lifted_arg_nodes_and_outside_compilation_nodes.size());
633   std::transform(
634       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
635       lifted_arg_nodes_and_outside_compilation_nodes.end(),
636       std::back_inserter(lifted_arg_nodes),
637       [](const std::pair<Node*, Node*>& pair) { return pair.first; });
638   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
639     TF_ASSIGN_OR_RETURN(Node * arg_node,
640                         AddOutsideCompilationInputArgToFunctionBody(
641                             *body_function_body, i, data_types[i]));
642 
643     TF_RETURN_IF_ERROR(
644         AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node));
645 
646     ReplaceLiftedArgNodePlaceholderWithArg(
647         *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
648   }
649 
650   const auto new_body_function_name =
651       fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
652   FunctionDef rewritten_body_function_def;
653   TF_RETURN_IF_ERROR(GraphToFunctionDef(
654       *body_function_body->graph, new_body_function_name,
655       HostGraphControlRetMapping, &rewritten_body_function_def));
656   TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
657                                             rewritten_body_function_def,
658                                             &body_func, n, fld));
659 
660   // In cond_graph, just add new _Arg nodes.
661   NameAttrList cond_func;
662   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "cond", &cond_func));
663   const FunctionDef* cond_function_def = fld->Find(cond_func.name());
664   TF_RET_CHECK(cond_function_def);
665   std::unique_ptr<FunctionBody> cond_function_body;
666   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_function_def,
667                                              AttrSlice(&cond_func.attr()), fld,
668                                              &cond_function_body));
669 
670   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
671     StatusOr<Node*> arg_node_or = AddOutsideCompilationInputArgToFunctionBody(
672         *cond_function_body, i, data_types[i]);
673     TF_RETURN_IF_ERROR(arg_node_or.status());
674   }
675 
676   const auto new_cond_function_name =
677       fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
678   FunctionDef rewritten_cond_function_def;
679   TF_RETURN_IF_ERROR(GraphToFunctionDef(
680       *cond_function_body->graph, new_cond_function_name,
681       HostGraphControlRetMapping, &rewritten_cond_function_def));
682   TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
683                                             rewritten_cond_function_def,
684                                             &cond_func, n, fld));
685   return OkStatus();
686 }
687 
PostprocessLiftedArgsForIf(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)688 Status PostprocessLiftedArgsForIf(
689     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
690     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
691   TF_RET_CHECK(n->IsIfNode());
692 
693   NameAttrList then_branch_func;
694   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
695   const FunctionDef* then_branch_function_def =
696       fld->Find(then_branch_func.name());
697   TF_RET_CHECK(then_branch_function_def);
698 
699   NameAttrList else_branch_func;
700   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func));
701   const FunctionDef* else_branch_function_def =
702       fld->Find(else_branch_func.name());
703   TF_RET_CHECK(else_branch_function_def);
704 
705   // Nothing to do if neither branch contains any lifted arguments.
706   if (!HasLiftedArgs(*then_branch_function_def) &&
707       !HasLiftedArgs(*else_branch_function_def)) {
708     return OkStatus();
709   }
710 
711   std::unique_ptr<FunctionBody> then_branch_function_body;
712   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
713       *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld,
714       &then_branch_function_body));
715 
716   std::unique_ptr<FunctionBody> else_branch_function_body;
717   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
718       *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld,
719       &else_branch_function_body));
720 
721   // Then and else branches have same argument count and argument data types.
722   int original_arg_count = then_branch_function_body->arg_nodes.size();
723 
724   TF_ASSIGN_OR_RETURN(
725       auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes,
726       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
727           *then_branch_function_body, outside_compilation_attr_to_node));
728 
729   TF_ASSIGN_OR_RETURN(
730       auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes,
731       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
732           *else_branch_function_body, outside_compilation_attr_to_node));
733 
734   // Merge lifted args from then and else branches.
735   std::vector<Node*> outside_compilation_nodes;
736   std::vector<Node*> then_branch_lifted_arg_nodes;
737   outside_compilation_nodes.reserve(
738       then_branch_lifted_arg_nodes_and_outside_compilation_nodes.size());
739   then_branch_lifted_arg_nodes.reserve(
740       then_branch_lifted_arg_nodes_and_outside_compilation_nodes.size());
741   for (const auto& pair :
742        then_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
743     outside_compilation_nodes.push_back(pair.second);
744     then_branch_lifted_arg_nodes.push_back(pair.first);
745   }
746   for (const auto& pair :
747        else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
748     if (std::find(outside_compilation_nodes.begin(),
749                   outside_compilation_nodes.end(),
750                   pair.second) == outside_compilation_nodes.end()) {
751       outside_compilation_nodes.push_back(pair.second);
752       // Then branch does not contain this lifted arg. Add an empty item to
753       // then_branch_lifted_arg_nodes.
754       then_branch_lifted_arg_nodes.push_back(nullptr);
755     }
756   }
757   // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes.
758   std::vector<Node*> else_branch_lifted_arg_nodes(
759       outside_compilation_nodes.size());
760   for (const auto& pair :
761        else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
762     auto iter = std::find(outside_compilation_nodes.begin(),
763                           outside_compilation_nodes.end(), pair.second);
764     TF_RET_CHECK(iter != outside_compilation_nodes.end());
765     int index = iter - outside_compilation_nodes.begin();
766     else_branch_lifted_arg_nodes[index] = pair.first;
767   }
768 
769   // Append lifted args' types to If node's Tin attribute.
770   std::vector<DataType> data_types;
771   data_types.reserve(outside_compilation_nodes.size());
772   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types));
773   for (Node* n : outside_compilation_nodes) {
774     data_types.push_back(n->output_type(0));
775   }
776   n->ClearAttr("Tin");
777   n->AddAttr("Tin", data_types);
778 
779   // Add edges from outside compilation nodes to If node. If node's input #0
780   // is predicate input, input #1 maps to _Arg #0 of branch functions, thus
781   // arg_to_input_edge_offset is set to 1.
782   AddEdgesFromOutsideCompilationNodes(original_arg_count,
783                                       /*arg_to_input_edge_offset=*/1,
784                                       data_types, outside_compilation_nodes, g,
785                                       n);
786 
787   for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
788     TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
789                         AddOutsideCompilationInputArgToFunctionBody(
790                             *then_branch_function_body, i, data_types[i]));
791 
792     ReplaceLiftedArgNodePlaceholderWithArg(
793         *then_branch_function_body, original_arg_count, i,
794         then_branch_lifted_arg_nodes, then_branch_arg_node);
795 
796     TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node,
797                         AddOutsideCompilationInputArgToFunctionBody(
798                             *else_branch_function_body, i, data_types[i]));
799 
800     ReplaceLiftedArgNodePlaceholderWithArg(
801         *else_branch_function_body, original_arg_count, i,
802         else_branch_lifted_arg_nodes, else_branch_arg_node);
803   }
804 
805   const auto new_then_function_name = fld->UniqueFunctionName(
806       absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
807   FunctionDef rewritten_then_branch_function_def;
808   TF_RETURN_IF_ERROR(GraphToFunctionDef(
809       *then_branch_function_body->graph, new_then_function_name,
810       HostGraphControlRetMapping, &rewritten_then_branch_function_def));
811   TF_RETURN_IF_ERROR(AddFunctionWithNewName(
812       new_then_function_name, "then_branch", rewritten_then_branch_function_def,
813       &then_branch_func, n, fld));
814 
815   const auto new_else_function_name = fld->UniqueFunctionName(
816       absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
817   FunctionDef rewritten_else_branch_function_def;
818   TF_RETURN_IF_ERROR(GraphToFunctionDef(
819       *else_branch_function_body->graph, new_else_function_name,
820       HostGraphControlRetMapping, &rewritten_else_branch_function_def));
821   TF_RETURN_IF_ERROR(AddFunctionWithNewName(
822       new_else_function_name, "else_branch", rewritten_else_branch_function_def,
823       &else_branch_func, n, fld));
824   return OkStatus();
825 }
826 
PostprocessLiftedArgsForCall(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)827 Status PostprocessLiftedArgsForCall(
828     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
829     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
830   const FunctionDef* fdef = fld->Find(n->type_string());
831   TF_RET_CHECK(fdef);
832 
833   // Nothing to do if the function does not contain any lifted arguments.
834   if (!HasLiftedArgs(*fdef)) {
835     return OkStatus();
836   }
837 
838   std::unique_ptr<FunctionBody> fbody;
839   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody));
840 
841   int original_arg_count = fbody->arg_nodes.size();
842 
843   TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes,
844                       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
845                           *fbody, outside_compilation_attr_to_node));
846 
847   // Append lifted args' types to call node's input data types.
848   std::vector<DataType> data_types(n->input_types().begin(),
849                                    n->input_types().end());
850   for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
851     Node* outside_compilation_node = pair.second;
852     DataType data_type;
853     TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
854                  outside_compilation_node->type_string() == "Placeholder");
855     if (outside_compilation_node->IsIdentity()) {
856       TF_RETURN_IF_ERROR(
857           GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
858     } else {
859       TF_RETURN_IF_ERROR(
860           GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
861     }
862     data_types.push_back(data_type);
863   }
864 
865   std::vector<Node*> lifted_arg_nodes;
866   lifted_arg_nodes.reserve(
867       lifted_arg_nodes_and_outside_compilation_nodes.size());
868   std::transform(
869       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
870       lifted_arg_nodes_and_outside_compilation_nodes.end(),
871       std::back_inserter(lifted_arg_nodes),
872       [](const std::pair<Node*, Node*>& pair) { return pair.first; });
873   for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
874     TF_ASSIGN_OR_RETURN(
875         Node * arg_node,
876         AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));
877 
878     ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i,
879                                            lifted_arg_nodes, arg_node);
880   }
881 
882   FunctionDef rewritten_fdef;
883   TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
884                                         HostGraphControlRetMapping,
885                                         &rewritten_fdef));
886   const auto new_function_name =
887       fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
888   rewritten_fdef.mutable_signature()->set_name(new_function_name);
889   TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
890 
891   // We need to recreate the node. Otherwise TF will not know n->num_inputs()
892   // has increased.
893   NodeDef node_def = n->def();
894 
895   // Function name is represented via the Op's type. Reset the op type to new
896   // function def name;
897   *node_def.mutable_op() = new_function_name;
898 
899   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
900     Node* outside_compilation_node =
901         lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
902             .second;
903     node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0));
904   }
905   TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def));
906 
907   // Add edges from outside compilation nodes to call node.
908   std::vector<Node*> outside_compilation_nodes;
909   outside_compilation_nodes.reserve(
910       lifted_arg_nodes_and_outside_compilation_nodes.size());
911   std::transform(
912       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
913       lifted_arg_nodes_and_outside_compilation_nodes.end(),
914       std::back_inserter(outside_compilation_nodes),
915       [](const std::pair<Node*, Node*>& pair) { return pair.second; });
916   AddEdgesFromOutsideCompilationNodes(original_arg_count,
917                                       /*arg_to_input_edge_offset=*/0,
918                                       data_types, outside_compilation_nodes, g,
919                                       n);
920 
921   return OkStatus();
922 }
923 
924 // Creates a mapping from outside compilation cluster name to lifted argument
925 // placeholder.
OutsideCompilationAttrToNode(const Graph & g)926 StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
927     const Graph& g) {
928   std::unordered_map<string, Node*> outside_compilation_attr_to_node;
929   for (Node* n : g.op_nodes()) {
930     bool is_lifted_arg;
931     string outside_compilation_attr;
932     if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
933         TryGetNodeAttr(n->def(), "_xla_outside_compilation",
934                        &outside_compilation_attr)) {
935       TF_RET_CHECK(is_lifted_arg);
936       TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
937       outside_compilation_attr_to_node[outside_compilation_attr] = n;
938     }
939   }
940 
941   return outside_compilation_attr_to_node;
942 }
943 
PostprocessLiftedArgs(Graph * g,FunctionLibraryDefinition * fld)944 Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
945   TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node,
946                       OutsideCompilationAttrToNode(*g));
947 
948   std::vector<Node*> call_nodes;
949   for (Node* n : g->op_nodes()) {
950     if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
951       continue;
952     }
953 
954     if (n->IsWhileNode()) {
955       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
956           outside_compilation_attr_to_node, g, n, fld));
957     }
958 
959     if (n->IsIfNode()) {
960       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
961           outside_compilation_attr_to_node, g, n, fld));
962     }
963 
964     // Outside compilation host side function call will always be direct
965     // function call nodes.
966     // Function call nodes need to be handled separately because we rewrite
967     // nodes in `PostprocessLiftedArgsForCall`.
968     if (fld->Contains(n->type_string())) {
969       call_nodes.push_back(n);
970     }
971   }
972 
973   for (Node* n : call_nodes) {
974     TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall(
975         outside_compilation_attr_to_node, g, n, fld));
976   }
977 
978   return OkStatus();
979 }
980 
981 // For an XLA computation, builds host side graph given all outside compilation
982 // graphs inside it. The host side graph contains:
983 // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and
984 //    XlaSendFromHost to this sequencer node, so all outside compilation nodes
985 //    will be executed *before* this sequencer).
986 // 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
987 //    replace this node with compilation result node.
988 // 3) all outside compilation graphs.
ConstructHostGraph(const string & xla_cluster_name,const string & outside_compilation_attr_name,const std::vector<string> & outside_compilation_host_graphs,FunctionLibraryDefinition * fld,std::unique_ptr<Graph> * host_graph)989 Status ConstructHostGraph(
990     const string& xla_cluster_name, const string& outside_compilation_attr_name,
991     const std::vector<string>& outside_compilation_host_graphs,
992     FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
993   host_graph->reset(new Graph(fld));
994 
995   // Create sequencer node in host graph.
996   NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
997                                    "NoOp");
998   sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name);
999   NodeDef sequencer_def;
1000   TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
1001   TF_ASSIGN_OR_RETURN(Node * sequencer, (*host_graph)->AddNode(sequencer_def));
1002 
1003   // Create key placeholder in host graph.
1004   TF_ASSIGN_OR_RETURN(
1005       Node * key_placeholder,
1006       AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));
1007 
1008   // For each outside compilation graph, copy them to host graph with the
1009   // following changes:
1010   // a) Use key_placeholder in host graph instead of its own.
1011   // b) Add control edge from host transfer nodes (XlaRecvAtHost,
1012   //    XlaSendFromHost, If/While nodes containing
1013   //    XlaRecvAtHost/XlaSendFromHost) to sequencer node.
1014   // c) Clear node_def.device(), so device placer won't get confused.
1015   for (const string& host_func : outside_compilation_host_graphs) {
1016     VLOG(4) << "Expanding host graph " << host_func;
1017     // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1018     // value after we expanded all host graphs. We cannot just use placeholder
1019     // value here because FunctionDef instantiation does not allow placeholder
1020     // value for attributes.
1021     AttrValue device_ordinal_attr;
1022     device_ordinal_attr.set_i(0);
1023     protobuf::Map<string, AttrValue> attrs;
1024     attrs["_device_ordinal"] = device_ordinal_attr;
1025     std::unique_ptr<FunctionBody> host_fbody;
1026     const FunctionDef* host_fdef = fld->Find(host_func);
1027     TF_RET_CHECK(host_fdef);
1028     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
1029                                                fld, &host_fbody));
1030 
1031     // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
1032     // reachable from sink node so all nodes will be copied.
1033     // TODO(b/77601805): consolidate copy graph functions.
1034     FixupSourceAndSinkEdges(host_fbody->graph);
1035 
1036     std::map<const Node*, Node*> node_map;
1037     node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
1038     node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
1039     Status s;
1040     ReverseDFS(
1041         *host_fbody->graph, /*enter=*/nullptr,
1042         [&](const Node* n) {
1043           if (!s.ok()) {
1044             return;
1045           }
1046 
1047           Node* copy;
1048           if (node_map.find(n) != node_map.end()) {
1049             // Already copied this node.
1050             copy = node_map.at(n);
1051           } else if (IsKeyPlaceholderNode(*n)) {
1052             // Change a).
1053             copy = key_placeholder;
1054             node_map[n] = copy;
1055           } else {
1056             // Copy the node.
1057             NodeDef copy_def = n->def();
1058             // Change c).
1059             copy_def.clear_device();
1060             copy = (*host_graph)->AddNode(copy_def, &s);
1061             if (!s.ok()) {
1062               return;
1063             }
1064             node_map[n] = copy;
1065           }
1066 
1067           // Only handle input edges. Output edges will be added later as
1068           // its output nodes' input edges.
1069           for (auto e : n->in_edges()) {
1070             if (node_map.find(e->src()) == node_map.end()) {
1071               s = errors::Internal("Cannot find node image for ",
1072                                    e->src()->DebugString());
1073               return;
1074             }
1075             (*host_graph)
1076                 ->AddEdge(node_map[e->src()], e->src_output(), copy,
1077                           e->dst_input());
1078           }
1079 
1080           // Change b).
1081           if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
1082             (*host_graph)->AddControlEdge(copy, sequencer);
1083           }
1084         },
1085         NodeComparatorID());
1086 
1087     if (!s.ok()) {
1088       return s;
1089     }
1090   }
1091   // Reset "_device_ordinal" to placeholder value.
1092   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(host_graph->get()));
1093 
1094   // sequencer and key_placeholder might be dead nodes. Prune them if necessary.
1095   // - sequencer should be pruned iff it has no input control edges from
1096   //   RecvAtHost/SendFromHost. If it has input control edge, we connect it to
1097   //   sink node so it won't be pruned.
1098   // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
1099   //   We don't need to do anything special.
1100   if (!sequencer->in_edges().empty()) {
1101     (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
1102   }
1103   PruneForReverseReachability(
1104       host_graph->get(),
1105       std::unordered_set<const Node*>{(*host_graph)->sink_node()});
1106 
1107   // Postprocess edges between different outside compilations.
1108   TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
1109       host_graph->get(), outside_compilation_attr_name));
1110 
1111   // Postprocess lifted arg nodes.
1112   TF_RETURN_IF_ERROR(PostprocessLiftedArgs(host_graph->get(), fld));
1113 
1114   if (VLOG_IS_ON(4)) {
1115     DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
1116                                  xla_cluster_name),
1117                     **host_graph, fld);
1118   }
1119 
1120   return OkStatus();
1121 }
1122 
1123 // Expand XLA computation's outside compilation host side graph into main graph.
1124 // Add a control edge between sequencer node and the XLA computation node.
ExpandHostGraphIntoMainGraph(Graph * main_graph,FunctionLibraryDefinition * fld,const string & host_graph_func_name,Node * xla_computation_node,Node * pivot_node)1125 Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
1126                                     FunctionLibraryDefinition* fld,
1127                                     const string& host_graph_func_name,
1128                                     Node* xla_computation_node,
1129                                     Node* pivot_node) {
1130   // Temporarily use "0" as "_device_ordinal". It will be rewritten with the
1131   // correct value in a later pass. We cannot just use placeholder value here
1132   // because FunctionDef instantiation does not allow placeholder value for
1133   // attributes.
1134   AttrValue device_ordinal_attr;
1135   device_ordinal_attr.set_i(0);
1136   protobuf::Map<string, AttrValue> attrs;
1137   attrs["_device_ordinal"] = device_ordinal_attr;
1138   std::unique_ptr<FunctionBody> fbody;
1139   const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
1140   TF_RET_CHECK(host_graph_func);
1141   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
1142                                              AttrSlice(&attrs), fld, &fbody));
1143   Graph* host_graph = fbody->graph;
1144 
1145   // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
1146   // reachable from sink node so all nodes will be copied.
1147   // TODO(b/77601805): consolidate copy graph functions.
1148   FixupSourceAndSinkEdges(host_graph);
1149 
1150   // Copy all nodes.
1151   std::map<const Node*, Node*> node_map;
1152   if (pivot_node) {
1153     node_map[host_graph->source_node()] = pivot_node;
1154   } else {
1155     node_map[host_graph->source_node()] = main_graph->source_node();
1156   }
1157   node_map[host_graph->sink_node()] = main_graph->sink_node();
1158   Status s = OkStatus();
1159   auto copy_node_fn = [&](const Node* n) {
1160     if (!s.ok()) {
1161       return;
1162     }
1163 
1164     Node* copy;
1165     if (node_map.find(n) != node_map.end()) {
1166       // Already copied this node.
1167       copy = node_map.at(n);
1168     } else {
1169       // Copy the node.
1170       NodeDef copy_def = n->def();
1171       copy = main_graph->AddNode(copy_def, &s);
1172       if (!s.ok()) {
1173         return;
1174       }
1175       node_map[n] = copy;
1176     }
1177 
1178     // Only handle input edges. Output edges will be added later as its output
1179     // nodes' input edges.
1180     for (auto e : n->in_edges()) {
1181       if (node_map.find(e->src()) == node_map.end()) {
1182         s = errors::Internal("Cannot find node image for ",
1183                              e->src()->DebugString());
1184         return;
1185       }
1186       main_graph->AddEdge(node_map[e->src()], e->src_output(), copy,
1187                           e->dst_input());
1188     }
1189 
1190     // Add control edge from sequencer to XLA computation node.
1191     if (copy->type_string() == "NoOp" &&
1192         HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) {
1193       main_graph->AddControlEdge(copy, xla_computation_node);
1194     }
1195   };
1196   ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
1197   return s;
1198 }
1199 
1200 // Rewrites shape inference graph for outside compilation:
1201 // 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
1202 //    `host_graph`. Because we might still have outside compilation to outside
1203 //    compilation placeholder nodes in shape inference graph, which will prevent
1204 //    us from inferring XlaSendFromHost shape. But in `host_graph`, we already
1205 //    removed those placeholder nodes.
1206 // 2) Remove control edges.
1207 // 3) Prune nodes that are not useful for shape inference.
RewriteShapeInferenceGraph(const string & shape_inference_graph_name,Graph * host_graph,Node * pivot_node,FunctionLibraryDefinition * fld)1208 Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
1209                                   Graph* host_graph, Node* pivot_node,
1210                                   FunctionLibraryDefinition* fld) {
1211   // Use "0" as "_device_ordinal". It does not matter for shape inference.
1212   AttrValue device_ordinal_attr;
1213   device_ordinal_attr.set_i(0);
1214   protobuf::Map<string, AttrValue> attrs;
1215   attrs["_device_ordinal"] = device_ordinal_attr;
1216   std::unique_ptr<FunctionBody> fbody;
1217   const FunctionDef* shape_inference_graph =
1218       fld->Find(shape_inference_graph_name);
1219   TF_RET_CHECK(shape_inference_graph);
1220   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
1221                                              AttrSlice(&attrs), fld, &fbody));
1222   Graph* g = fbody->graph;
1223 
1224   // Find SendFromHost node.
1225   Node* send_from_host = nullptr;
1226   for (Node* n : g->nodes()) {
1227     if (n->type_string() == "_XlaSendFromHost") {
1228       send_from_host = n;
1229       break;
1230     }
1231   }
1232   if (!send_from_host) {
1233     return errors::Internal("Shape inference graph ",
1234                             shape_inference_graph_name,
1235                             " does not have _XlaSendFromHost node.");
1236   }
1237 
1238   // See if the SendFromHost node exists in `host_graph`.
1239   Node* send_node_in_host_graph = nullptr;
1240   for (Node* n : host_graph->nodes()) {
1241     if (n->name() == send_from_host->name()) {
1242       send_node_in_host_graph = n;
1243       break;
1244     }
1245   }
1246   if (send_node_in_host_graph) {
1247     // This is an "top-level" outside compilation. Clear the graph, and copy
1248     // SendFromHost and all its predecessors from `host_graph`.
1249     std::vector<Node*> nodes;
1250     nodes.reserve(g->num_op_nodes());
1251     for (Node* n : g->op_nodes()) {
1252       nodes.push_back(n);
1253     }
1254     for (Node* n : nodes) {
1255       g->RemoveNode(n);
1256     }
1257     Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
1258     // Reverse DFS from send_from_host_main_graph, and stop at start_node.
1259     struct Visit {
1260       Node* n;
1261       bool is_exiting;
1262     };
1263     std::vector<Visit> stack{{send_node_in_host_graph, false}};
1264     std::map<Node*, Node*> node_map;
1265     node_map[host_graph->source_node()] = g->source_node();
1266     while (!stack.empty()) {
1267       Visit& curr = stack.back();
1268       if (curr.is_exiting) {
1269         if (node_map.find(curr.n) == node_map.end()) {
1270           Node* copy = g->CopyNode(curr.n);
1271           if (curr.n != start_node) {
1272             for (const Edge* e : curr.n->in_edges()) {
1273               auto node_iter = node_map.find(e->src());
1274               if (node_iter == node_map.end()) {
1275                 return errors::Internal("Cannot find node image for ",
1276                                         e->src()->DebugString());
1277               }
1278               g->AddEdge(node_iter->second, e->src_output(), copy,
1279                          e->dst_input());
1280             }
1281           }
1282           node_map[curr.n] = copy;
1283         }
1284         stack.pop_back();
1285       } else {
1286         curr.is_exiting = true;
1287         if (curr.n != start_node) {
1288           for (const Edge* e : curr.n->in_edges()) {
1289             if (node_map.find(e->src()) != node_map.end()) {
1290               continue;
1291             }
1292             stack.push_back({e->src(), false});
1293           }
1294         }
1295       }
1296     }
1297 
1298     send_from_host = node_map[send_node_in_host_graph];
1299   } else {
1300     // This is an outside compilation generated for If/While/gradient/etc.
1301     // It will be enough for shape inference. Leave `g` unchanged.
1302   }
1303 
1304   // Control edges are not useful for shape inference. Remove them.
1305   for (auto e : g->edges()) {
1306     if (e->IsControlEdge()) {
1307       g->RemoveEdge(e);
1308     }
1309   }
1310 
1311   // Nodes that are not reverse reachable from SendFromHost are not useful for
1312   // shape inference. Prune them.
1313   PruneForReverseReachability(g,
1314                               std::unordered_set<const Node*>{send_from_host});
1315 
1316   if (VLOG_IS_ON(4)) {
1317     DumpGraphToFile(shape_inference_graph_name, *g, fld);
1318   }
1319 
1320   // Replace original shape inference graph.
1321   FunctionDef fdef_replace;
1322   TF_RETURN_IF_ERROR(
1323       GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace));
1324   TF_RETURN_IF_ERROR(
1325       fld->ReplaceFunction(shape_inference_graph_name, fdef_replace));
1326 
1327   return OkStatus();
1328 }
1329 
SetMaximalSharding(NodeDefBuilder & node_builder)1330 void SetMaximalSharding(NodeDefBuilder& node_builder) {
1331   xla::OpSharding sharding;
1332   sharding.set_type(xla::OpSharding::MAXIMAL);
1333   sharding.add_tile_assignment_dimensions(1);
1334   sharding.add_tile_assignment_devices(0);
1335   node_builder.Attr("_XlaSharding", sharding.SerializeAsString());
1336 }
1337 
1338 // Builds XlaSendToHost node which sends cond predicate to host.
BuildSendIfPredNode(const string & name,const string & host_transfer_key,Node * pred_node,Graph * g)1339 TF_ATTRIBUTE_NOINLINE StatusOr<Node*> BuildSendIfPredNode(
1340     const string& name, const string& host_transfer_key, Node* pred_node,
1341     Graph* g) {
1342   NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
1343   send_pred_builder.Attr("Tinput", DT_BOOL);
1344   send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
1345   send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
1346                          std::vector<string>{kXlaTokenArgNodeName});
1347   send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name);
1348   SetMaximalSharding(send_pred_builder);
1349   send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
1350   NodeDef send_pred_def;
1351   TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
1352   TF_ASSIGN_OR_RETURN(Node * send_pred_node, g->AddNode(send_pred_def));
1353   g->AddEdge(pred_node, 0, send_pred_node, 0);
1354   return send_pred_node;
1355 }
1356 
1357 // Replaces key placeholder node with an _Arg node.
ReplaceKeyPlaceholderWithArgNode(const string & xla_cluster_name,const string & func_name,FunctionLibraryDefinition * fld)1358 Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
1359                                         const string& func_name,
1360                                         FunctionLibraryDefinition* fld) {
1361   // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1362   // value after rewriting.
1363   AttrValue device_ordinal_attr;
1364   device_ordinal_attr.set_i(0);
1365   protobuf::Map<string, AttrValue> attrs;
1366   attrs["_device_ordinal"] = device_ordinal_attr;
1367   std::unique_ptr<FunctionBody> fbody;
1368   const FunctionDef* func = fld->Find(func_name);
1369   TF_RETURN_IF_ERROR(
1370       FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
1371   Graph* g = fbody->graph;
1372 
1373   // Find or create the key placeholder node.
1374   Node* key_placeholder = nullptr;
1375   for (Node* n : g->nodes()) {
1376     if (IsKeyPlaceholderNode(*n)) {
1377       key_placeholder = n;
1378       break;
1379     }
1380   }
1381   if (!key_placeholder) {
1382     TF_ASSIGN_OR_RETURN(key_placeholder,
1383                         AddHostComputeKeyPlaceholder(xla_cluster_name, g));
1384   }
1385 
1386   // Build the _Arg node, and replace key placeholder node with it.
1387   NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
1388   arg_builder.Attr("T", DT_STRING);
1389   arg_builder.Attr("index", 0);
1390   NodeDef arg_def;
1391   TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
1392   TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
1393 
1394   // Reset "_device_ordinal" to placeholder value.
1395   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
1396 
1397   FunctionDef replace_fdef;
1398   TF_RETURN_IF_ERROR(GraphToFunctionDef(
1399       *g, func_name, HostGraphControlRetMapping, &replace_fdef));
1400   TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
1401   return OkStatus();
1402 }
1403 
1404 // Builds host side graph for If node.
BuildHostGraphForIfNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & if_node_name,const string & host_transfer_key,const string & host_graph_func_name,FunctionLibraryDefinition * fld,const string & then_branch_host_func_name,const string & else_branch_host_func_name)1405 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
1406     const string& xla_cluster_attr_name,
1407     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1408     const string& if_node_name, const string& host_transfer_key,
1409     const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1410     const string& then_branch_host_func_name,
1411     const string& else_branch_host_func_name) {
1412   Graph host_graph(fld);
1413   string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
1414   AttrValue device_ordinal_value;
1415   device_ordinal_value.set_placeholder("_device_ordinal");
1416 
1417   // Step 1: add key placeholder node.
1418   TF_ASSIGN_OR_RETURN(
1419       Node * key_placeholder,
1420       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1421 
1422   // Step 2: build XlaRecvAtHost node to recv predicate.
1423   NodeDefBuilder recv_pred_builder(
1424       absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
1425   recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1426   recv_pred_builder.Attr("key", host_transfer_key);
1427   recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1428   recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1429   recv_pred_builder.Attr(outside_compilation_attr_name,
1430                          outside_compilation_name);
1431   recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1432   recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
1433   NodeDef recv_pred_def;
1434   TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1435   TF_ASSIGN_OR_RETURN(Node * recv_pred_node, host_graph.AddNode(recv_pred_def));
1436   host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
1437 
1438   // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
1439   // placeholder with an _Arg node.
1440   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1441       xla_cluster_name, then_branch_host_func_name, fld));
1442   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1443       xla_cluster_name, else_branch_host_func_name, fld));
1444 
1445   // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
1446   NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
1447   if_builder.Attr("Tcond", DT_BOOL);
1448   if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
1449   if_builder.Attr("Tout", std::vector<DataType>{});
1450   NameAttrList host_then_branch, host_else_branch;
1451   host_then_branch.set_name(then_branch_host_func_name);
1452   (*host_then_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1453   host_else_branch.set_name(else_branch_host_func_name);
1454   (*host_else_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1455   if_builder.Attr("then_branch", host_then_branch);
1456   if_builder.Attr("else_branch", host_else_branch);
1457   if_builder.Attr(kXlaHasHostTransferAttrName, true);
1458   if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1459   if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1460   if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1461   std::vector<NodeDefBuilder::NodeOut> if_inputs{
1462       {key_placeholder->name(), 0, DT_STRING}};
1463   if_builder.Input(if_inputs);
1464   NodeDef if_def;
1465   TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
1466   TF_ASSIGN_OR_RETURN(Node * if_node, host_graph.AddNode(if_def));
1467   host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
1468   host_graph.AddEdge(key_placeholder, 0, if_node, 1);
1469 
1470   // Convert `host_graph` to function.
1471   FunctionDef oc_host_graph_fdef;
1472   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1473                                         &oc_host_graph_fdef));
1474   if (fld->Find(host_graph_func_name)) {
1475     TF_RETURN_IF_ERROR(
1476         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1477   } else {
1478     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1479   }
1480 
1481   return OkStatus();
1482 }
1483 
1484 // Rewrites loop cond to add a node which sends loop cond to host.
AddSendLoopPredToLoopCond(const string & cond_xla_func_name,const string & host_transfer_key,NameAttrList * loop_cond_func,FunctionLibraryDefinition * fld,Node * while_node)1485 TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
1486     const string& cond_xla_func_name, const string& host_transfer_key,
1487     NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
1488     Node* while_node) {
1489   // Instantiate the loop cond function.
1490   std::unique_ptr<FunctionBody> fbody;
1491   const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
1492   TF_RET_CHECK(loop_cond_fdef);
1493   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1494       *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
1495   Graph* g = fbody->graph;
1496 
1497   // Find the _Retval node and the loop cond node.
1498   Node* ret_node = nullptr;
1499   for (Node* n : g->nodes()) {
1500     if (n->type_string() == "_Retval") {
1501       if (ret_node) {
1502         return errors::Internal("Multiple return node for loop cond function ",
1503                                 loop_cond_func->name(), ": ",
1504                                 ret_node->DebugString(), " and ",
1505                                 n->DebugString());
1506       } else {
1507         ret_node = n;
1508       }
1509     }
1510   }
1511   if (!ret_node) {
1512     return errors::Internal("No _Retval node for loop cond function ",
1513                             loop_cond_func->name());
1514   }
1515   Node* loop_cond;
1516   TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
1517 
1518   // Build the XlaSendToHost node.
1519   NodeDefBuilder send_loop_cond_builder(
1520       absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
1521   send_loop_cond_builder.Attr("Tinput", DT_BOOL);
1522   send_loop_cond_builder.Attr("key",
1523                               absl::StrCat(host_transfer_key, "_dtoh_0"));
1524   send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
1525                               std::vector<string>{kXlaTokenArgNodeName});
1526   send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
1527                               send_loop_cond_builder.node_name());
1528   SetMaximalSharding(send_loop_cond_builder);
1529   send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
1530   NodeDef send_loop_cond_def;
1531   TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
1532   TF_ASSIGN_OR_RETURN(Node * send_loop_cond_node,
1533                       g->AddNode(send_loop_cond_def));
1534   g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
1535 
1536   // Replace original function if loop_cond_func already has been re-written
1537   // for outside compilation.
1538   FunctionDef replace_fdef;
1539   if (loop_cond_func->name() == cond_xla_func_name) {
1540     TF_RETURN_IF_ERROR(
1541         GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
1542     TF_RETURN_IF_ERROR(
1543         fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
1544   } else {
1545     // If original while cond function has not been modified, add a new function
1546     // with send loop predicated added and update the while node callsite
1547     // operation.
1548     const auto new_name = fld->UniqueFunctionName(
1549         absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
1550     TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
1551     TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
1552     loop_cond_func->set_name(new_name);
1553     while_node->ClearAttr("cond");
1554     while_node->AddAttr("cond", *loop_cond_func);
1555   }
1556 
1557   return OkStatus();
1558 }
1559 
1560 // Rewrites while loop cond function for host.
RewriteHostWhileLoopCond(const string & cond_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1561 Status RewriteHostWhileLoopCond(
1562     const string& cond_host_func_name, const string& while_node_name,
1563     const string& host_transfer_key, const string& xla_cluster_attr_name,
1564     const string& xla_cluster_name, const string& outside_compilation_attr_name,
1565     const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1566   // Replace key placeholder node with _Arg node.
1567   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1568       xla_cluster_name, cond_host_func_name, fld));
1569 
1570   // Instantiate cond function.
1571   AttrValue device_ordinal_temp_value;
1572   device_ordinal_temp_value.set_i(0);
1573   protobuf::Map<string, AttrValue> attrs;
1574   attrs["_device_ordinal"] = device_ordinal_temp_value;
1575   std::unique_ptr<FunctionBody> cond_fbody;
1576   const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
1577   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
1578                                              fld, &cond_fbody));
1579   Graph* cond_graph = cond_fbody->graph;
1580   Node* key_arg = nullptr;
1581   for (Node* n : cond_graph->nodes()) {
1582     if (n->type_string() == "_Arg") {
1583       key_arg = n;
1584     }
1585   }
1586   if (!key_arg) {
1587     return errors::Internal(
1588         "No _Arg node found for host compute key in function ",
1589         cond_host_func_name);
1590   }
1591 
1592   // Add an XlaRecvAtHost node to use as cond function return value.
1593   NodeDefBuilder recv_pred_builder(
1594       absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
1595   recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1596   recv_pred_builder.Attr("key", host_transfer_key);
1597   AttrValue device_ordinal_value;
1598   device_ordinal_value.set_placeholder("_device_ordinal");
1599   recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1600   recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1601   recv_pred_builder.Attr(outside_compilation_attr_name,
1602                          outside_compilation_name);
1603   recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1604   recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
1605   NodeDef recv_pred_def;
1606   TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1607   TF_ASSIGN_OR_RETURN(Node * recv_pred_node,
1608                       cond_graph->AddNode(recv_pred_def));
1609   cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
1610   NodeDefBuilder ret_builder(
1611       absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
1612   ret_builder.Attr("T", DT_BOOL);
1613   ret_builder.Attr("index", 0);
1614   ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1615   NodeDef ret_def;
1616   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1617   TF_ASSIGN_OR_RETURN(Node * ret_node, cond_graph->AddNode(ret_def));
1618   cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
1619 
1620   // Reset device_ordinal to placeholder value.
1621   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
1622 
1623   // Replace original function.
1624   FunctionDef cond_replace_fdef;
1625   TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_graph, cond_host_func_name,
1626                                         HostGraphControlRetMapping,
1627                                         &cond_replace_fdef));
1628   TF_RETURN_IF_ERROR(
1629       fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
1630 
1631   return OkStatus();
1632 }
1633 
1634 // Rewrites while loop body function for host.
RewriteHostWhileLoopBody(const string & body_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1635 Status RewriteHostWhileLoopBody(
1636     const string& body_host_func_name, const string& while_node_name,
1637     const string& host_transfer_key, const string& xla_cluster_attr_name,
1638     const string& xla_cluster_name, const string& outside_compilation_attr_name,
1639     const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1640   // Replace key placeholder node with _Arg node.
1641   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1642       xla_cluster_name, body_host_func_name, fld));
1643 
1644   // Instantiate body function.
1645   AttrValue device_ordinal_temp_value;
1646   device_ordinal_temp_value.set_i(0);
1647   protobuf::Map<string, AttrValue> attrs;
1648   attrs["_device_ordinal"] = device_ordinal_temp_value;
1649   std::unique_ptr<FunctionBody> body_fbody;
1650   const FunctionDef* body_host_func = fld->Find(body_host_func_name);
1651   TF_RET_CHECK(body_host_func);
1652   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
1653                                              fld, &body_fbody));
1654   Graph* body_graph = body_fbody->graph;
1655   Node* key_arg = nullptr;
1656   for (Node* n : body_graph->nodes()) {
1657     if (n->type_string() == "_Arg") {
1658       key_arg = n;
1659     }
1660   }
1661   if (!key_arg) {
1662     return errors::Internal(
1663         "No _Arg node found for host compute key in function ",
1664         body_host_func_name);
1665   }
1666 
1667   // Add a _Retval node to loop body.
1668   NodeDefBuilder ret_builder(
1669       absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
1670   ret_builder.Attr("T", DT_STRING);
1671   ret_builder.Attr("index", 0);
1672   ret_builder.Input(key_arg->name(), 0, DT_STRING);
1673   NodeDef ret_def;
1674   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1675   TF_ASSIGN_OR_RETURN(Node * ret_node, body_graph->AddNode(ret_def));
1676   body_graph->AddEdge(key_arg, 0, ret_node, 0);
1677 
1678   // Reset device_ordinal to placeholder value.
1679   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
1680 
1681   // Replace original function.
1682   FunctionDef body_replace_fdef;
1683   TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_graph, body_host_func_name,
1684                                         HostGraphControlRetMapping,
1685                                         &body_replace_fdef));
1686   TF_RETURN_IF_ERROR(
1687       fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
1688 
1689   return OkStatus();
1690 }
1691 
1692 // Builds host side graph for while node.
BuildHostGraphForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & while_node_name,const string & host_transfer_key,const string & host_graph_func_name,FunctionLibraryDefinition * fld,const string & cond_host_func_name,const string & body_host_func_name)1693 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
1694     const string& xla_cluster_attr_name,
1695     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1696     const string& while_node_name, const string& host_transfer_key,
1697     const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1698     const string& cond_host_func_name, const string& body_host_func_name) {
1699   Graph host_graph(fld);
1700   string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
1701 
1702   // Step 1: add key placeholder node.
1703   TF_ASSIGN_OR_RETURN(
1704       Node * key_placeholder,
1705       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1706 
1707   // Step 2: rewrite cond function.
1708   TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
1709       cond_host_func_name, while_node_name, host_transfer_key,
1710       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1711       outside_compilation_name, fld));
1712 
1713   // Step 3: rewrite body function.
1714   TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
1715       body_host_func_name, while_node_name, host_transfer_key,
1716       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1717       outside_compilation_name, fld));
1718 
1719   // Step 4: build While node.
1720   NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
1721                                "While");
1722   while_builder.Attr("T", std::vector<DataType>{DT_STRING});
1723   NameAttrList func;
1724   AttrValue device_ordinal_value;
1725   device_ordinal_value.set_placeholder("_device_ordinal");
1726   (*func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1727   func.set_name(cond_host_func_name);
1728   while_builder.Attr("cond", func);
1729   func.set_name(body_host_func_name);
1730   while_builder.Attr("body", func);
1731   while_builder.Attr(kXlaHasHostTransferAttrName, true);
1732   while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1733   while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1734   // Make sure loop body of i-th iteration happens before loop cond of (i+1)-th
1735   // iteration.
1736   while_builder.Attr("parallel_iterations", 1);
1737   std::vector<NodeDefBuilder::NodeOut> while_inputs{
1738       {key_placeholder->name(), 0, DT_STRING}};
1739   while_builder.Input(while_inputs);
1740   NodeDef while_def;
1741   TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
1742   TF_ASSIGN_OR_RETURN(Node * while_node, host_graph.AddNode(while_def));
1743   host_graph.AddEdge(key_placeholder, 0, while_node, 0);
1744 
1745   // Convert `host_graph` to function.
1746   FunctionDef oc_host_graph_fdef;
1747   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1748                                         &oc_host_graph_fdef));
1749   if (fld->Find(host_graph_func_name)) {
1750     TF_RETURN_IF_ERROR(
1751         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1752   } else {
1753     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1754   }
1755 
1756   return OkStatus();
1757 }
1758 
1759 // Builds host graph for func call nodes.
BuildHostGraphForFuncCallNode(const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & func_call_node_name,const string & func_call_host_func_name,const string & host_graph_func_name,FunctionLibraryDefinition * fld)1760 Status BuildHostGraphForFuncCallNode(
1761     const string& xla_cluster_attr_name, const string& xla_cluster_name,
1762     const string& outside_compilation_attr_name,
1763     const string& func_call_node_name, const string& func_call_host_func_name,
1764     const string& host_graph_func_name, FunctionLibraryDefinition* fld) {
1765   Graph host_graph(fld);
1766   AttrValue device_ordinal_value;
1767   device_ordinal_value.set_placeholder("_device_ordinal");
1768 
1769   // Step 1: add key placeholder node.
1770   TF_ASSIGN_OR_RETURN(
1771       Node * key_placeholder,
1772       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1773 
1774   // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg
1775   // node.
1776   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1777       xla_cluster_name, func_call_host_func_name, fld));
1778 
1779   // Step 3: build a function call node with `host_func_name`, with
1780   // `key_placeholder` as input.
1781   NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
1782                               func_call_host_func_name, fld);
1783   call_builder.Input(key_placeholder->name(), 0, DT_STRING);
1784   call_builder.Attr("_device_ordinal", device_ordinal_value);
1785   call_builder.Attr(kXlaHasHostTransferAttrName, true);
1786   call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1787   call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
1788   NodeDef call_def;
1789   TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
1790   TF_ASSIGN_OR_RETURN(Node * call_node, host_graph.AddNode(call_def));
1791   host_graph.AddEdge(key_placeholder, 0, call_node, 0);
1792 
1793   // Convert `host_graph` to function.
1794   FunctionDef oc_host_graph_fdef;
1795   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1796                                         HostGraphControlRetMapping,
1797                                         &oc_host_graph_fdef));
1798   if (fld->Find(host_graph_func_name)) {
1799     TF_RETURN_IF_ERROR(
1800         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1801   } else {
1802     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1803   }
1804 
1805   return OkStatus();
1806 }
1807 
ExtractOutsideCompilationForFuncCallNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1808 TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
1809     const string& xla_cluster_attr_name,
1810     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1811     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1812     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1813     std::vector<string>* host_graphs,
1814     std::vector<string>* shape_inference_graphs,
1815     bool* has_outside_compilation) {
1816   bool func_has_outside_compilation = false;
1817   NameAttrList func;
1818   if (fld->Contains(n->type_string())) {
1819     func.set_name(n->type_string());
1820     typedef protobuf::Map<string, AttrValue> AttrMap;
1821     *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
1822   } else if (n->IsPartitionedCall()) {
1823     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
1824   } else {
1825     TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
1826     func.set_name(FunctionLibraryDefinition::kGradientOp);
1827     *func.mutable_attr() = n->def().attr();
1828   }
1829   string canonical_func_name;
1830   if (func.name() == FunctionLibraryDefinition::kGradientOp) {
1831     NameAttrList forward_func;
1832     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func));
1833     canonical_func_name = absl::StrCat("gradient_", forward_func.name());
1834   } else {
1835     canonical_func_name = func.name();
1836   }
1837   string new_func_name = absl::StrCat(canonical_func_name, "_oc");
1838   string host_func_name =
1839       absl::StrCat("oc_func_call_host_", canonical_func_name);
1840   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1841       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1842       func, new_func_name, host_func_name, host_compute_core, flr, fld,
1843       shape_inference_graphs, &func_has_outside_compilation));
1844 
1845   // If the function call does not have outside compilation, nothing to do.
1846   if (!func_has_outside_compilation) {
1847     return OkStatus();
1848   }
1849 
1850   *has_outside_compilation = true;
1851 
1852   // Change `n` to call the new function directly.
1853   auto replace_builder =
1854       std::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
1855   std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
1856   for (const Edge* e : n->in_edges()) {
1857     if (e->IsControlEdge()) {
1858       continue;
1859     }
1860 
1861     const bool input_size_check =
1862         e->dst_input() < static_cast<int>(inputs.size());
1863     TF_RET_CHECK(e->dst_input() >= 0 && input_size_check);
1864     inputs[e->dst_input()] =
1865         NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1866                                 e->src()->output_type(e->src_output())};
1867   }
1868   for (const auto& input : inputs) {
1869     replace_builder->Input(input);
1870   }
1871   for (const auto& attr : n->attrs()) {
1872     replace_builder->Attr(attr.first, attr.second);
1873   }
1874   auto replace_def = std::make_unique<NodeDef>();
1875   TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
1876   TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
1877   replace->AddAttr(kXlaTokenInputNodesAttrName,
1878                    std::vector<string>{kXlaTokenArgNodeName});
1879   replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name());
1880 
1881   // Build host side graph for the function call.
1882   string oc_host_graph_name =
1883       absl::StrCat("oc_func_host_graph_", replace->name());
1884   TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
1885       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1886       replace->name(), host_func_name, oc_host_graph_name, fld));
1887 
1888   // Record the host graph.
1889   host_graphs->push_back(oc_host_graph_name);
1890 
1891   return OkStatus();
1892 }
1893 
ExtractOutsideCompilationForIfNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1894 Status ExtractOutsideCompilationForIfNode(
1895     const string& xla_cluster_attr_name,
1896     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1897     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1898     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1899     std::vector<string>* host_graphs,
1900     std::vector<string>* shape_inference_graphs,
1901     bool* has_outside_compilation) {
1902   // Instantiate "then_branch" and "else_branch".
1903   NameAttrList then_branch, else_branch;
1904   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
1905   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
1906 
1907   // Extract outside compilation for then_branch and else_branch.
1908   bool then_branch_has_outside_compilation = false;
1909   bool else_branch_has_outside_compilation = false;
1910   string then_branch_host_func_name =
1911              absl::StrCat("oc_then_branch_host_if_", then_branch.name()),
1912          else_branch_host_func_name =
1913              absl::StrCat("oc_else_branch_host_if_", else_branch.name());
1914   string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
1915          else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
1916   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1917       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1918       then_branch, then_branch_xla_func_name, then_branch_host_func_name,
1919       host_compute_core, flr, fld, shape_inference_graphs,
1920       &then_branch_has_outside_compilation));
1921   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1922       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1923       else_branch, else_branch_xla_func_name, else_branch_host_func_name,
1924       host_compute_core, flr, fld, shape_inference_graphs,
1925       &else_branch_has_outside_compilation));
1926 
1927   // If then/else branch do not have outside compilation, nothing to do.
1928   if (!then_branch_has_outside_compilation &&
1929       !else_branch_has_outside_compilation) {
1930     return OkStatus();
1931   }
1932 
1933   *has_outside_compilation = true;
1934 
1935   // Change If node to call the new functions.
1936   if (then_branch_has_outside_compilation) {
1937     then_branch.set_name(then_branch_xla_func_name);
1938     n->ClearAttr("then_branch");
1939     n->AddAttr("then_branch", then_branch);
1940   }
1941   if (else_branch_has_outside_compilation) {
1942     else_branch.set_name(else_branch_xla_func_name);
1943     n->ClearAttr("else_branch");
1944     n->AddAttr("else_branch", else_branch);
1945   }
1946   n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
1947 
1948   string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
1949 
1950   // XLA computation: add a SendToHost node to send cond predicate.
1951   Node* pred_node;
1952   TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
1953   TF_ASSIGN_OR_RETURN(
1954       Node * send_pred_node,
1955       BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
1956                           host_transfer_key, pred_node, g));
1957   n->AddAttr(kXlaTokenInputNodesAttrName,
1958              std::vector<string>{send_pred_node->name()});
1959 
1960   // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
1961   // visit If node after `send_pred_node`, thus the token output for
1962   // `send_pred_node` has been generated.
1963   g->AddControlEdge(send_pred_node, n);
1964 
1965   // Build host side graph for the "If" node.
1966   // If then/else branch does not have outside compilation, we won't build host
1967   // graph for the branch. But here we need a host graph for both branches, so
1968   // we need to create a no-op host graph.
1969   if (!then_branch_has_outside_compilation) {
1970     std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
1971     std::vector<string> then_branch_host_graphs;
1972     TF_RETURN_IF_ERROR(ConstructHostGraph(
1973         xla_cluster_name, outside_compilation_attr_name,
1974         then_branch_host_graphs, fld, &then_branch_host_graph));
1975     FunctionDef then_branch_host_fdef;
1976     TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
1977                                           then_branch_host_func_name,
1978                                           &then_branch_host_fdef));
1979     if (fld->Find(then_branch_host_func_name)) {
1980       TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
1981                                               then_branch_host_fdef));
1982     } else {
1983       TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
1984     }
1985   }
1986   if (!else_branch_has_outside_compilation) {
1987     std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
1988     std::vector<string> else_branch_host_graphs;
1989     TF_RETURN_IF_ERROR(ConstructHostGraph(
1990         xla_cluster_name, outside_compilation_attr_name,
1991         else_branch_host_graphs, fld, &else_branch_host_graph));
1992     FunctionDef else_branch_host_fdef;
1993     TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
1994                                           else_branch_host_func_name,
1995                                           &else_branch_host_fdef));
1996     if (fld->Find(else_branch_host_func_name)) {
1997       TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
1998                                               else_branch_host_fdef));
1999     } else {
2000       TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
2001     }
2002   }
2003   string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
2004   TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
2005       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2006       n->name(), host_transfer_key, oc_host_graph_name, fld,
2007       then_branch_host_func_name, else_branch_host_func_name));
2008   host_graphs->push_back(oc_host_graph_name);
2009 
2010   return OkStatus();
2011 }
2012 
ExtractOutsideCompilationForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2013 Status ExtractOutsideCompilationForWhileNode(
2014     const string& xla_cluster_attr_name,
2015     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2016     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
2017     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2018     std::vector<string>* host_graphs,
2019     std::vector<string>* shape_inference_graphs,
2020     bool* has_outside_compilation) {
2021   // Instantiate "cond" and "body".
2022   NameAttrList cond, body;
2023   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
2024   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
2025 
2026   // Extract outside compilation for cond and body.
2027   bool cond_has_outside_compilation = false;
2028   bool body_has_outside_compilation = false;
2029   string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()),
2030          body_host_func_name = absl::StrCat("oc_body_host_while_", body.name());
2031   string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
2032          body_xla_func_name = absl::StrCat(body.name(), "_oc");
2033   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2034       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2035       cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
2036       fld, shape_inference_graphs, &cond_has_outside_compilation));
2037   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2038       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2039       body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
2040       fld, shape_inference_graphs, &body_has_outside_compilation));
2041 
2042   // If cond/body do not have outside compilation, nothing to do.
2043   if (!cond_has_outside_compilation && !body_has_outside_compilation) {
2044     return OkStatus();
2045   }
2046 
2047   *has_outside_compilation = true;
2048 
2049   // Change While node to call the new functions.
2050   if (cond_has_outside_compilation) {
2051     cond.set_name(cond_xla_func_name);
2052     n->ClearAttr("cond");
2053     n->AddAttr("cond", cond);
2054   }
2055   if (body_has_outside_compilation) {
2056     body.set_name(body_xla_func_name);
2057     n->ClearAttr("body");
2058     n->AddAttr("body", body);
2059   }
2060   n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
2061 
2062   string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
2063 
2064   // XLA computation: rewrite cond function to add a SendToHost node to send
2065   // loop predicate.
2066   TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
2067       cond_xla_func_name, host_transfer_key, &cond, fld, n));
2068   n->AddAttr(kXlaTokenInputNodesAttrName,
2069              std::vector<string>{kXlaTokenArgNodeName});
2070 
2071   // Build host side graph for the "While" node.
2072   if (!cond_has_outside_compilation) {
2073     std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
2074     std::vector<string> host_graphs;
2075     TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2076                                           outside_compilation_attr_name,
2077                                           host_graphs, fld, &cond_host_graph));
2078     FunctionDef cond_host_fdef;
2079     TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
2080                                           &cond_host_fdef));
2081     if (fld->Find(cond_host_func_name)) {
2082       TF_RETURN_IF_ERROR(
2083           fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
2084     } else {
2085       TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
2086     }
2087   }
2088   if (!body_has_outside_compilation) {
2089     std::unique_ptr<Graph> body_host_graph(new Graph(fld));
2090     std::vector<string> host_graphs;
2091     TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2092                                           outside_compilation_attr_name,
2093                                           host_graphs, fld, &body_host_graph));
2094     FunctionDef body_host_fdef;
2095     TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
2096                                           &body_host_fdef));
2097     if (fld->Find(body_host_func_name)) {
2098       TF_RETURN_IF_ERROR(
2099           fld->ReplaceFunction(body_host_func_name, body_host_fdef));
2100     } else {
2101       TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
2102     }
2103   }
2104   string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
2105   TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
2106       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2107       n->name(), host_transfer_key, oc_host_graph_name, fld,
2108       cond_host_func_name, body_host_func_name));
2109   host_graphs->push_back(oc_host_graph_name);
2110 
2111   return OkStatus();
2112 }
2113 
ExtractOutsideCompilationForNodesWithAssociatedFunctions(Graph * g,const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2114 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2115     Graph* g, const string& xla_cluster_attr_name,
2116     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2117     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2118     FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
2119     std::vector<string>* shape_inference_graphs,
2120     bool* has_outside_compilation) {
2121   std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
2122   for (Node* n : g->nodes()) {
2123     if (n->IsIfNode()) {
2124       if_nodes.push_back(n);
2125     } else if (n->IsWhileNode()) {
2126       while_nodes.push_back(n);
2127     } else if (IsFunctionCall(*fld, *n)) {
2128       func_call_nodes.push_back(n);
2129     }
2130   }
2131 
2132   for (Node* n : func_call_nodes) {
2133     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
2134         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2135         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2136         has_outside_compilation));
2137   }
2138 
2139   for (Node* n : if_nodes) {
2140     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
2141         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2142         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2143         has_outside_compilation));
2144   }
2145 
2146   for (Node* n : while_nodes) {
2147     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
2148         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2149         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2150         has_outside_compilation));
2151   }
2152 
2153   return OkStatus();
2154 }
2155 
CopyOutsideCompilationConstNodes(Graph * g,const string & outside_compilation_attr_name)2156 Status CopyOutsideCompilationConstNodes(
2157     Graph* g, const string& outside_compilation_attr_name) {
2158   for (Node* n : g->op_nodes()) {
2159     if (!n->IsConstant() ||
2160         !HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2161       continue;
2162     }
2163 
2164     std::vector<const Edge*> out_edges(n->out_edges().begin(),
2165                                        n->out_edges().end());
2166     bool has_non_oc_output = false;
2167     for (const Edge* e : out_edges) {
2168       if (!e->IsControlEdge() &&
2169           !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2170         has_non_oc_output = true;
2171         break;
2172       }
2173     }
2174     if (!has_non_oc_output) {
2175       continue;
2176     }
2177 
2178     NodeDef copy_def = n->def();
2179     copy_def.set_name(g->NewName(n->name()));
2180     copy_def.mutable_attr()->erase(outside_compilation_attr_name);
2181     TF_ASSIGN_OR_RETURN(Node * copy_node, g->AddNode(copy_def));
2182     for (const Edge* e : n->in_edges()) {
2183       if (e->IsControlEdge()) {
2184         g->AddControlEdge(e->src(), copy_node);
2185       }
2186     }
2187     for (const Edge* e : out_edges) {
2188       if (!e->IsControlEdge() &&
2189           !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2190         Node* dst = e->dst();
2191         int dst_input = e->dst_input();
2192         g->RemoveEdge(e);
2193         g->AddEdge(copy_node, 0, dst, dst_input);
2194       }
2195     }
2196   }
2197 
2198   return OkStatus();
2199 }
2200 
2201 }  // namespace
2202 
operator ()(const std::vector<OutputTensor> & arg_source_tensors,std::unique_ptr<Graph> * graph,std::vector<int> * input_permutation,std::vector<int> * output_permutation,NodeDef * node_def)2203 Status RewriteOutsideCompilationSubgraphFn::operator()(
2204     const std::vector<OutputTensor>& arg_source_tensors,
2205     std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
2206     std::vector<int>* output_permutation, NodeDef* node_def) {
2207   string old_name = node_def->op();
2208   string new_name =
2209       absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name);
2210   node_def->set_op(new_name);
2211   node_def->set_name(new_name);
2212 
2213   // Later we will run PruneForReverseReachability(), so make sure all original
2214   // nodes are reachable from sink node and won't be removed.
2215   FixupSourceAndSinkEdges(graph->get());
2216 
2217   // Step 1: create a key placeholder node.
2218   TF_ASSIGN_OR_RETURN(
2219       Node * key_placeholder,
2220       AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get()));
2221 
2222   // Step 2: build RecvAtHost node, and replace all _Arg nodes with it.
2223   std::vector<DataType> recv_at_host_dtypes;
2224   TF_ASSIGN_OR_RETURN(
2225       Node * recv_at_host_node,
2226       ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name,
2227                                         &recv_at_host_dtypes, key_placeholder));
2228 
2229   // Step 3: build SendFromHost node, and replace all _Retval nodes with it.
2230   std::vector<DataType> send_from_host_dtypes;
2231   TF_ASSIGN_OR_RETURN(
2232       Node * send_from_host_node,
2233       ReplaceRetNodesWithSendFromHostNode(
2234           graph->get(), new_name, &send_from_host_dtypes, key_placeholder));
2235 
2236   // Step 4: add XLA cluster and outside compilation attr.
2237   for (Node* n : (*graph)->nodes()) {
2238     if (IsKeyPlaceholderNode(*n)) {
2239       continue;
2240     }
2241 
2242     n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_);
2243     n->AddAttr(outside_compilation_attr_name_, old_name);
2244   }
2245 
2246   // Check whether we have all input shapes for XlaSendFromHost. If we do, we
2247   // will set `shapes` attr for the call node; otherwise we will save the
2248   // shape inference graph and set `shape_inference_graph` for the call node.
2249   std::optional<std::vector<PartialTensorShape>> shapes =
2250       GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node);
2251   for (Node* n : (*graph)->nodes()) {
2252     n->ClearAttr(kXlaInferredShapesAttrName);
2253   }
2254 
2255   // Step 5: add control edges for originally XLA <-> outside compilation
2256   // control edges.
2257   for (Node* n : (*graph)->nodes()) {
2258     if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) {
2259       (*graph)->AddControlEdge(n, send_from_host_node);
2260       n->ClearAttr(kXlaConnectedToXlaComputationAttrName);
2261     }
2262     if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) {
2263       (*graph)->AddControlEdge(recv_at_host_node, n);
2264       n->ClearAttr(kXlaConnectedFromXlaComputationAttrName);
2265     }
2266   }
2267 
2268   // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune
2269   // them if necessary.
2270   // - RecvAtHost should be pruned iff it has no output data/control edges. If
2271   //   it has any output edge, it will be reverse reachable from sink node. We
2272   //   don't need to do anything special.
2273   // - SendFromHost should be pruned iff it has no input data/control edges. If
2274   //   it has input edges other than key_placeholder, we connect it to sink
2275   //   node so it won't be pruned.
2276   // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned.
2277   //   We don't need to do anything special.
2278   if (send_from_host_node->in_edges().size() > 1) {
2279     (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node());
2280   }
2281   PruneForReverseReachability(
2282       graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()});
2283 
2284   // Step 7: add necessary attributes to function call node, so we can replace
2285   // it with HostCompute node later.
2286   AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
2287   if (shapes) {
2288     NameAttrList shape_inference_graph;
2289     AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2290     AddNodeAttr("shapes", *shapes, node_def);
2291   } else {
2292     string shape_inference_func_name =
2293         absl::StrCat("_outside_compilation_shape_inference_", new_name);
2294     NameAttrList shape_inference_graph;
2295     shape_inference_graph.set_name(shape_inference_func_name);
2296     AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2297     AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def);
2298   }
2299   AddNodeAttr("ancestors", std::vector<string>{}, node_def);
2300   AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def);
2301   AddNodeAttr("Toutputs", send_from_host_dtypes, node_def);
2302   AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def);
2303 
2304   return OkStatus();
2305 }
2306 
ExtractOutsideCompilationForFunction(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const NameAttrList & func_name_attrs,const string & new_func_name,const string & host_graph_func_name,const std::map<string,int> & host_compute_core,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2307 Status ExtractOutsideCompilationForFunction(
2308     const string& xla_cluster_attr_name,
2309     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2310     const NameAttrList& func_name_attrs, const string& new_func_name,
2311     const string& host_graph_func_name,
2312     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2313     FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
2314     bool* has_outside_compilation) {
2315   // Convert the function to graph.
2316   const string& func_name = func_name_attrs.name();
2317   FunctionLibraryRuntime::Handle handle;
2318   TF_RETURN_IF_ERROR(
2319       flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle));
2320   Status ret_status = OkStatus();
2321   auto cleanup_handle = gtl::MakeCleanup([&]() {
2322     auto s = flr->ReleaseHandle(handle);
2323     if (!s.ok()) {
2324       ret_status.Update(s);
2325     }
2326   });
2327   const FunctionBody* fbody = flr->GetFunctionBody(handle);
2328 
2329   // Check if we have outside compilation nodes.
2330   *has_outside_compilation = false;
2331   for (Node* n : fbody->graph->nodes()) {
2332     if (HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2333       *has_outside_compilation = true;
2334       break;
2335     }
2336   }
2337   // We cannot early return here, because we might have outside compilation in
2338   // If/While function body.
2339 
2340   if (VLOG_IS_ON(4)) {
2341     DumpGraphToFile(
2342         absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
2343         *fbody->graph, fld);
2344   }
2345 
2346   std::unique_ptr<Graph> graph_out;
2347   std::vector<string> outside_compilation_host_graphs;
2348   std::vector<string> shape_inference_graphs_to_rewrite;
2349   if (*has_outside_compilation) {
2350     // Copy outside compilation Const nodes with non outside compilation users.
2351     TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
2352         fbody->graph, outside_compilation_attr_name));
2353 
2354     // Find dependencies between outside compilation clusters.
2355     TF_ASSIGN_OR_RETURN(auto cluster_deps,
2356                         OutsideCompilationClusterDependencies(
2357                             fbody->graph, outside_compilation_attr_name));
2358 
2359     // Preprocess edges between different outside compilations. They will be
2360     // restored in `ConstructHostGraph()`.
2361     TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
2362         fbody->graph, outside_compilation_attr_name));
2363 
2364     // Encapsulate outside_compilation cluster into function call node.
2365     auto rewrite_fn = std::make_unique<RewriteOutsideCompilationSubgraphFn>(
2366         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2367         new_func_name);
2368     TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
2369         outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
2370         /*reuse_existing_functions=*/true, &graph_out, fld));
2371 
2372     // Replace outside_compilation function nodes with HostCompute ops.
2373     std::vector<Node*> outside_compilation_nodes;
2374     for (Node* n : graph_out->nodes()) {
2375       if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
2376         outside_compilation_nodes.push_back(n);
2377         outside_compilation_host_graphs.push_back(n->name());
2378 
2379         // If we could not infer shapes for XlaSendFromHost inputs statically,
2380         // we will set the "shape_inference_graph" attribute. In that case, copy
2381         // outside compilation subgraph as shape inference graph in `fld`.
2382         auto shape_inference_graph = std::make_unique<NameAttrList>();
2383         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
2384                                        shape_inference_graph.get()));
2385         if (!shape_inference_graph->name().empty()) {
2386           shape_inference_graphs->push_back(shape_inference_graph->name());
2387           shape_inference_graphs_to_rewrite.push_back(
2388               shape_inference_graph->name());
2389 
2390           const FunctionDef* xla_fdef = fld->Find(n->name());
2391           if (!xla_fdef) {
2392             return errors::Internal("Cannot find XLA function ", n->name());
2393           }
2394           auto shape_inference_fdef = std::make_unique<FunctionDef>(*xla_fdef);
2395           shape_inference_fdef->mutable_signature()->set_name(
2396               shape_inference_graph->name());
2397           if (fld->Find(shape_inference_graph->name())) {
2398             TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2399                 shape_inference_graph->name(), *shape_inference_fdef));
2400           } else {
2401             TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
2402           }
2403         }
2404       }
2405     }
2406     std::map<string, Node*> host_compute_nodes;
2407     for (Node* n : outside_compilation_nodes) {
2408       auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
2409           graph_out.get(), n, host_compute_core, *cluster_deps);
2410       TF_RETURN_IF_ERROR(host_compute_node_or.status());
2411       Node* host_compute_node = host_compute_node_or.ValueOrDie();
2412       host_compute_nodes[host_compute_node->name()] = host_compute_node;
2413     }
2414     // For XlaHostCompute nodes with dependencies, add control edges between
2415     // them so XlaCompiler can handle them in correct order.
2416     for (const auto& iter : host_compute_nodes) {
2417       Node* host_compute_node = iter.second;
2418       std::vector<string> token_input_node_names;
2419       TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
2420                                      kXlaTokenInputNodesAttrName,
2421                                      &token_input_node_names));
2422       for (const string& node_name : token_input_node_names) {
2423         if (node_name == kXlaTokenArgNodeName) {
2424           continue;
2425         }
2426 
2427         auto iter = host_compute_nodes.find(node_name);
2428         TF_RET_CHECK(iter != host_compute_nodes.end());
2429         graph_out->AddControlEdge(iter->second, host_compute_node);
2430       }
2431     }
2432   }
2433 
2434   // Handle nodes with associated functions.
2435   Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
2436   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2437       g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2438       host_compute_core, flr, fld, &outside_compilation_host_graphs,
2439       shape_inference_graphs, has_outside_compilation));
2440 
2441   if (*has_outside_compilation) {
2442     // Construct host graph.
2443     std::unique_ptr<Graph> host_graph;
2444     TF_RETURN_IF_ERROR(
2445         ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
2446                            outside_compilation_host_graphs, fld, &host_graph));
2447     auto host_graph_fdef = std::make_unique<FunctionDef>();
2448     TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
2449                                           HostGraphControlRetMapping,
2450                                           host_graph_fdef.get()));
2451     if (fld->Find(host_graph_func_name)) {
2452       TF_RETURN_IF_ERROR(
2453           fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
2454     } else {
2455       TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
2456     }
2457 
2458     // Shape inference graphs might contain Placeholder nodes for outside
2459     // compilation to outside compilation edges. Rewrite shape inference graphs
2460     // to remove such nodes.
2461     for (const string& shape_inference_graph :
2462          shape_inference_graphs_to_rewrite) {
2463       TF_RETURN_IF_ERROR(
2464           RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
2465                                      /*pivot_node=*/nullptr, fld));
2466     }
2467 
2468     // Remove the outside compilation graphs from function library.
2469     for (const string& func : outside_compilation_host_graphs) {
2470       TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
2471     }
2472 
2473     // Replace original function.
2474     auto updated_fdef = std::make_unique<FunctionDef>();
2475     TF_RETURN_IF_ERROR(
2476         GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
2477     updated_fdef->mutable_signature()->set_is_stateful(true);
2478     const FunctionDef* original_fdef = fld->Find(func_name);
2479     if (original_fdef) {
2480       for (const auto& attr : original_fdef->attr()) {
2481         (*updated_fdef->mutable_attr())[attr.first] = attr.second;
2482       }
2483     }
2484     if (fld->Find(new_func_name)) {
2485       TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
2486     } else {
2487       TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
2488     }
2489     if (VLOG_IS_ON(4)) {
2490       DumpGraphToFile(
2491           absl::StrCat("extract_outside_compilation_for_func_after_",
2492                        func_name),
2493           *g, fld);
2494     }
2495   }
2496 
2497   return ret_status;
2498 }
2499 
ExtractOutsideCompilation(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const std::unordered_map<string,XlaClusterInfo> & clusters,Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,bool * modified)2500 Status ExtractOutsideCompilation(
2501     const string& xla_cluster_attr_name,
2502     const string& outside_compilation_attr_name,
2503     const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
2504     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2505     bool* modified) {
2506   if (VLOG_IS_ON(4)) {
2507     DumpGraphToFile("extract_outside_compilation_before", *g, fld);
2508   }
2509 
2510   *modified = false;
2511   auto node_name_index = g->BuildNodeNameIndex();
2512   for (auto& iter : clusters) {
2513     string xla_cluster_name = iter.first;
2514     Node* n = iter.second.node;
2515     auto const& func_name_attrs = iter.second.func_name_attrs;
2516     auto const& host_compute_core = iter.second.host_compute_core;
2517 
2518     std::vector<string> shape_inference_graphs;
2519     bool has_outside_compilation;
2520     string host_graph_func_name =
2521         absl::StrCat("oc_host_graph_", xla_cluster_name);
2522     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2523         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2524         func_name_attrs, func_name_attrs.name(), host_graph_func_name,
2525         host_compute_core, flr, fld, &shape_inference_graphs,
2526         &has_outside_compilation));
2527     *modified |= has_outside_compilation;
2528 
2529     if (has_outside_compilation) {
2530       string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
2531       Node* pivot_node = node_name_index[pivot_name];
2532       TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
2533           g, fld, host_graph_func_name, n, pivot_node));
2534 
2535       TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
2536 
2537       for (const auto& shape_inference_graph_name : shape_inference_graphs) {
2538         TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
2539             shape_inference_graph_name, g, pivot_node, fld));
2540       }
2541     }
2542   }
2543 
2544   if (VLOG_IS_ON(4)) {
2545     DumpGraphToFile("extract_outside_compilation_after", *g, fld);
2546   }
2547   return OkStatus();
2548 }
2549 
2550 }  // namespace tensorflow
2551