1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
17 
18 #include <queue>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/container/node_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
26 #include "tensorflow/compiler/jit/encapsulate_util.h"
27 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
28 #include "tensorflow/compiler/jit/xla_cluster_util.h"
29 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
30 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/function.h"
34 #include "tensorflow/core/framework/graph_to_functiondef.h"
35 #include "tensorflow/core/framework/node_def.pb.h"
36 #include "tensorflow/core/framework/node_def_builder.h"
37 #include "tensorflow/core/framework/node_def_util.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/gtl/cleanup.h"
41 #include "tensorflow/core/lib/gtl/flatset.h"
42 #include "tensorflow/core/lib/hash/hash.h"
43 #include "tensorflow/core/lib/strings/proto_serialization.h"
44 #include "tensorflow/core/lib/strings/str_util.h"
45 #include "tensorflow/core/public/session_options.h"
46 #include "tensorflow/core/public/version.h"
47 #include "tensorflow/core/tpu/tpu_compile_interface.h"
48 #include "tensorflow/core/tpu/tpu_defs.h"
49 #include "tensorflow/core/util/dump_graph.h"
50 
51 namespace tensorflow {
52 
53 namespace {
54 
// Op type names and attribute keys this pass recognizes when rewriting TPU
// replication constructs in the graph.
const char* const kTPUReplicatedInput = "TPUReplicatedInput";
const char* const kTPUReplicatedOutput = "TPUReplicatedOutput";
// Attribute naming the pivot node of the enclosing control-flow cluster.
const char* const kPivotForClusterAttr = "_pivot_for_cluster";
const char* const kTPUPartitionedInput = "TPUPartitionedInput";
59 
60 // Finds the `index` of an _Arg or _Retval node.
GetIndexAttr(const Node & n,int num_args,int * index)61 Status GetIndexAttr(const Node& n, int num_args, int* index) {
62   TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
63   if (*index < 0 || *index >= num_args) {
64     return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
65                                    *index);
66   }
67   return OkStatus();
68 }
69 
70 // Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
71 // the arguments into the order expected by TPUReplicate computations:
72 // 1) replicated arguments
73 // 2) non-replicated (broadcast) arguments
74 // 3) resource variable arguments
75 // See the documentation of EncapsulateSubgraphsInFunctions for the meaning
76 // of the arguments.
// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
// the arguments into the order expected by TPUReplicate computations:
// 1) replicated arguments
// 2) non-replicated (broadcast) arguments
// 3) resource variable arguments
// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
// of the arguments.
//
// `arg_source_tensors` holds, per _Arg index, the tensor in the outer graph
// that feeds that argument; `graph_ptr` is the encapsulated subgraph being
// rewritten in place; `input_permutation`/`output_permutation` are filled so
// that entry `old_index` maps to the new position; `call_def` is the call node
// in the outer graph whose attributes (and op name) are updated here.
Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
                       std::unique_ptr<Graph>* graph_ptr,
                       std::vector<int>* input_permutation,
                       std::vector<int>* output_permutation,
                       NodeDef* call_def) {
  // Replicated inputs have TPUReplicatedInput nodes as predecessors in the
  // input graph. When `is_packed` is supplied, it is set from the source
  // node's "is_packed" attr (false if the input is not replicated or the attr
  // is absent).
  auto is_replicated_input = [&](const Node& n, bool* is_packed = nullptr) {
    CHECK_EQ("_Arg", n.type_string());
    int index;
    TF_CHECK_OK(GetIndexAttr(n, arg_source_tensors.size(), &index));
    bool ret =
        arg_source_tensors.at(index).node->type_string() == kTPUReplicatedInput;
    if (is_packed) {
      if (!ret || !GetNodeAttr(arg_source_tensors.at(index).node->attrs(),
                               "is_packed", is_packed)
                       .ok()) {
        *is_packed = false;
      }
    }
    return ret;
  };

  // An _Arg is a guaranteed constant only if it carries the
  // "_is_guaranteed_constant" attr and is not a replicated input.
  auto is_guaranteed_constant = [&](const Node& n) {
    bool guaranteed_constant = false;
    if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
             .ok()) {
      return false;
    }
    // Replicated input nodes can be marked as guaranteed constants if they are
    // const.
    return guaranteed_constant && !is_replicated_input(n);
  };

  Graph* graph = graph_ptr->get();
  Node* metadata_node = nullptr;
  const int num_args = input_permutation->size();
  const int num_retvals = output_permutation->size();

  // Collect the _Arg/_Retval nodes and the (single) TPUReplicateMetadata
  // node; every remaining node gets assigned to a TPU_REPLICATED_CORE device
  // if it is not already on one.
  std::vector<Node*> args;
  std::vector<Node*> retvals;
  args.reserve(num_args);
  retvals.reserve(num_retvals);
  for (Node* n : graph->nodes()) {
    if (n->type_string() == "_Arg") {
      args.push_back(n);
    } else if (n->type_string() == "_Retval") {
      retvals.push_back(n);
    } else if (n->type_string() == "TPUReplicateMetadata") {
      metadata_node = n;
    } else if (!str_util::StrContains(n->requested_device(),
                                      DEVICE_TPU_REPLICATED_CORE)) {
      // If an operator isn't assigned to a TPU core device, assign it to
      // TPU_REPLICATED_CORE without a specific core ID. For some operators,
      // such as variable reads/writes, the operator may be assigned to non-TPU
      // devices due to colocation.
      n->set_assigned_device_name(
          strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE));
    }
  }

  // Read the metadata node and remove it from the graph.
  if (metadata_node == nullptr) {
    return errors::InvalidArgument("Missing TPUReplicateMetadata node");
  }

  // Copy every metadata attribute onto the call node.
  for (const auto& attr : metadata_node->attrs()) {
    if (attr.first == "computation_shape") {
      // Convert the deprecated computation_shape attribute into a
      // num_cores_per_replica value. If a computation_shape is present, it
      // overrides num_cores_per_replica.
      std::vector<int> shape;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(metadata_node->attrs(), "computation_shape", &shape));
      if (!shape.empty()) {
        int64_t num_cores_per_replica = 1LL;
        for (int dim : shape) {
          num_cores_per_replica *= dim;
        }
        call_def->mutable_attr()->erase("num_cores_per_replica");
        AddNodeAttr("num_cores_per_replica", num_cores_per_replica, call_def);
      }
    } else {
      call_def->mutable_attr()->insert(attr);
    }
  }
  MergeDebugInfo(NodeDebugInfo(metadata_node->def()), call_def);
  graph->RemoveNode(metadata_node);

  if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
    return errors::InvalidArgument("Missing or non-consecutive arguments");
  }

  // Reorders the arguments. The sort key orders:
  // non-constant < guaranteed-constant, replicated < non-replicated,
  // non-packed < packed, non-resource < resource, then by name.
  std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
    // Non-constants appear before constants
    bool a_is_guaranteed_constant = is_guaranteed_constant(*a);
    bool b_is_guaranteed_constant = is_guaranteed_constant(*b);
    // Non-packed values appear before packed values.
    bool a_is_packed;
    bool b_is_packed;
    // Replicated values appear before non-replicated values.
    bool a_not_replicated = !is_replicated_input(*a, &a_is_packed);
    bool b_not_replicated = !is_replicated_input(*b, &b_is_packed);
    // Non-resources appear before resources
    bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
    bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
    // Uses the name as a tiebreaker so the output is deterministic.
    StringPiece a_name(a->name());
    StringPiece b_name(b->name());
    return std::tie(a_is_guaranteed_constant, a_not_replicated, a_is_packed,
                    a_is_resource, a_name) <
           std::tie(b_is_guaranteed_constant, b_not_replicated, b_is_packed,
                    b_is_resource, b_name);
  });
  // Sorts the retvals by name so the order is deterministic.
  std::sort(retvals.begin(), retvals.end(),
            [](Node* a, Node* b) { return a->name() < b->name(); });

  // Computes the permutation to produce the correct argument order, and update
  // the argument indices.
  int variable_start_index = num_args;
  int guaranteed_const_start_index = num_args;
  for (int i = 0; i < num_args; ++i) {
    int index;
    TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
    // Record where the first non-replicated resource (variable) and the first
    // guaranteed constant land in the sorted order.
    if (args[i]->output_type(0) == DT_RESOURCE &&
        !is_replicated_input(*args[i]) && variable_start_index == num_args) {
      variable_start_index = i;
    } else if (is_guaranteed_constant(*args[i]) &&
               guaranteed_const_start_index == num_args) {
      guaranteed_const_start_index = i;
    }
    (*input_permutation)[index] = i;
    args[i]->AddAttr("index", i);
  }
  VLOG(4) << "variable_start_index: " << variable_start_index
          << " guaranteed_const_start_index: " << guaranteed_const_start_index;

  // Computes the permutation to produce the correct retval order, and update
  // the argument indices.
  for (int i = 0; i < num_retvals; ++i) {
    int index;
    TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
    (*output_permutation)[index] = i;
    retvals[i]->AddAttr("index", i);
  }

  AddNodeAttr(kTPUReplicateAttr, call_def->name(), call_def);
  AddNodeAttr("_variable_start_index", variable_start_index, call_def);
  AddNodeAttr("_guaranteed_const_start_index", guaranteed_const_start_index,
              call_def);

  // Uniquify the function name by fingerprinting the function.
  // Nondeterminism in serialization would not lead to incorrect results, but
  // may cause spurious cache misses. DeterministicSerialization is a
  // best-effort deterministic serialization.
  TF_ASSIGN_OR_RETURN(string serialized, SerializeGraphDeterministic(*graph));
  uint64 fingerprint =
      TpuCompileInterface::Get()->FingerprintString(serialized);
  LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
  call_def->set_op(strings::StrCat(call_def->op(), "_", fingerprint));
  return OkStatus();
}
241 
EdgeType(const Edge * edge)242 DataType EdgeType(const Edge* edge) {
243   return edge->dst()->input_type(edge->dst_input());
244 }
245 
246 // Adds the control inputs of `node` to `*deps`.
AddControlInputs(const Node & node,gtl::FlatSet<Node * > * deps)247 void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
248   for (const Edge* edge : node.in_edges()) {
249     if (edge->IsControlEdge()) {
250       deps->insert(edge->src());
251     }
252   }
253 }
254 
255 // Adds the control outputs of `node` to `*deps`.
AddControlOutputs(const Node & node,gtl::FlatSet<Node * > * deps)256 void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
257   for (const Edge* edge : node.out_edges()) {
258     if (edge->IsControlEdge()) {
259       deps->insert(edge->dst());
260     }
261   }
262 }
263 
264 // We add Identity nodes for _Arg/_Retval in XLA computation. Remove those
265 // Identity nodes to simplify furthur processing.
RemoveIdentityNodesForArgRetval(Graph * g)266 Status RemoveIdentityNodesForArgRetval(Graph* g) {
267   // Collect Identity nodes for _Arg/_Retval.
268   std::vector<Node*> identity_nodes;
269   for (Node* n : g->nodes()) {
270     if (n->type_string() == "Identity" &&
271         (HasNodeAttr(n->def(), "_tpu_input_identity") ||
272          HasNodeAttr(n->def(), "_tpu_output_identity"))) {
273       identity_nodes.push_back(n);
274     }
275   }
276 
277   // Remove those Identity nodes.
278   for (Node* n : identity_nodes) {
279     const Edge* input_edge;
280     TF_RETURN_IF_ERROR(n->input_edge(0, &input_edge));
281 
282     std::vector<const Edge*> output_edges;
283     for (const Edge* e : n->out_edges()) {
284       output_edges.push_back(e);
285     }
286     for (const Edge* e : output_edges) {
287       if (e->IsControlEdge()) {
288         Node* dst = e->dst();
289         g->RemoveEdge(e);
290         g->AddControlEdge(input_edge->src(), dst);
291       } else {
292         Node* dst = e->dst();
293         int dst_input = e->dst_input();
294         g->RemoveEdge(e);
295         g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
296       }
297     }
298     g->RemoveNode(n);
299   }
300 
301   return OkStatus();
302 }
303 
304 // Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when
305 // 'additional_per_replicate_inputs' are added to the inputs of `xla_node`.
UpdateMirroredVariableIndices(int additional_per_replica_inputs,Node * xla_node)306 Status UpdateMirroredVariableIndices(int additional_per_replica_inputs,
307                                      Node* xla_node) {
308   std::vector<int> mirrored_variable_indices;
309   if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
310       nullptr) {
311     TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
312                                    TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
313                                    &mirrored_variable_indices));
314   }
315 
316   if (!mirrored_variable_indices.empty()) {
317     for (int i = 0; i < mirrored_variable_indices.size(); ++i)
318       mirrored_variable_indices[i] += additional_per_replica_inputs;
319     xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
320     xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
321                       mirrored_variable_indices);
322   }
323   return OkStatus();
324 }
325 
326 // Move outside compilation nodes at the beginning of XLA computation to host.
327 // For XLA computation graph, we will add new _Arg nodes to replace those
328 // outside compilation nodes.
329 // For host graph, we will move those outside compilation nodes to host,
330 // replicate them, and use them as XLA node's input.
// Move outside compilation nodes at the beginning of XLA computation to host.
// For XLA computation graph, we will add new _Arg nodes to replace those
// outside compilation nodes.
// For host graph, we will move those outside compilation nodes to host,
// replicate them, and use them as XLA node's input.
//
// `outside_compilation_attr_name` marks outside compilation nodes;
// `xla_func_name` is used only to name dumped debug graphs; `cluster_name` is
// attached (as kTPUReplicateAttr) to the replicated host copies; `g` is the
// host graph containing `xla_node` (the replicate call node); `xla_graph` is
// the function body being rewritten; `pivot_node`, if non-null, is the
// control-flow pivot the host copies must be anchored under.
Status MoveHeadOutsideCompilationToHost(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
    Node* pivot_node) {
  // Find outside compilation nodes that only have _Arg or other outside
  // compilation nodes as input. These nodes will be moved to host graph.
  std::vector<Node*> oc_nodes_at_head;
  const string kOnlyArgOrOcInputAttrName = "_xla_only_arg_or_oc_input";
  // ReverseDFS visits producers before consumers, so a node's inputs already
  // carry kOnlyArgOrOcInputAttrName (or not) by the time it is visited.
  ReverseDFS(
      *xla_graph, /*enter=*/nullptr,
      [&](Node* n) {
        bool has_non_arg_or_oc_input = false;
        for (const Edge* e : n->in_edges()) {
          if (e->src() == xla_graph->source_node()) {
            continue;
          }
          if (!e->src()->IsArg() &&
              (!HasNodeAttr(e->src()->def(), outside_compilation_attr_name) ||
               !HasNodeAttr(e->src()->def(), kOnlyArgOrOcInputAttrName))) {
            has_non_arg_or_oc_input = true;
            break;
          }
        }
        if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
            !has_non_arg_or_oc_input &&
            !HasNodeAttr(n->def(), kXlaIsPlaceholderForArg)) {
          n->AddAttr(kOnlyArgOrOcInputAttrName, true);
          oc_nodes_at_head.push_back(n);
        }
      },
      NodeComparatorName());
  std::vector<Node*> const_nodes_to_remove;
  for (Node* n : oc_nodes_at_head) {
    // If a Const node is in "oc_nodes_at_head" but some of its successors are
    // not, copy this Const node and use the copied node for those successors.
    if (n->type_string() != "Const") {
      continue;
    }

    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() &&
          HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) &&
          !HasNodeAttr(e->dst()->def(), kOnlyArgOrOcInputAttrName)) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }

    Node* const_copy = xla_graph->CopyNode(n);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      xla_graph->RemoveEdge(e);
      xla_graph->AddEdge(const_copy, 0, dst, dst_input);
    }
    // Make sure the copied node can be traced from source node.
    xla_graph->AddControlEdge(xla_graph->source_node(), const_copy);

    // If this Const node has no data output any more, remove it later.
    bool has_output_edge = false;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge()) {
        has_output_edge = true;
        break;
      }
    }
    if (!has_output_edge) {
      const_nodes_to_remove.push_back(n);
    }
  }
  for (Node* n : const_nodes_to_remove) {
    xla_graph->RemoveNode(n);
    oc_nodes_at_head.erase(
        std::remove(oc_nodes_at_head.begin(), oc_nodes_at_head.end(), n),
        oc_nodes_at_head.end());
  }
  if (VLOG_IS_ON(5)) {
    for (Node* n : oc_nodes_at_head) {
      VLOG(5) << "oc_nodes_at_head: " << n->DebugString();
    }
  }

  // Copy all nodes in `oc_nodes_at_head` to host graph, and also replicate
  // them.

  // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
  // will become very expensive in this case because it is doing a linear
  // search inside. Create an input_edges vector ahead to make the lookups
  // faster.
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  // Tinputs lists per-replica inputs for all replicas, followed by the
  // distributed variables; back out the per-replica count.
  int old_num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;
  VLOG(5) << "old_num_per_replica_inputs: " << old_num_per_replica_inputs;
  // Maps each head node to its host-side copies, one per replica.
  std::map<Node*, std::vector<Node*>> node_images;
  for (Node* n : oc_nodes_at_head) {
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      NodeDef copy_def = n->def();
      copy_def.set_name(absl::StrCat(n->name(), "_head_oc/R", replica_id));
      copy_def.clear_device();

      TF_ASSIGN_OR_RETURN(Node * copy_node, g->AddNode(copy_def));

      copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
      copy_node->AddAttr(kTPUReplicateAttr, cluster_name);

      for (const Edge* e : n->in_edges()) {
        if (e->src() == xla_graph->source_node()) {
          continue;
        }
        // Either e->src() is _Arg node, or it's in `node_images`.
        if (e->src()->IsArg()) {
          int index;
          TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->attrs(), "index", &index));
          // Map the _Arg index to the flat input index on `xla_node`:
          // per-replica args are laid out replica-major; the rest follow all
          // per-replica blocks.
          const int new_index =
              (index < old_num_per_replica_inputs)
                  ? (old_num_per_replica_inputs * replica_id + index)
                  : (old_num_per_replica_inputs * num_replicas +
                     (index - old_num_per_replica_inputs));
          const Edge* original_edge = input_edges.at(new_index);
          g->AddEdge(original_edge->src(), original_edge->src_output(),
                     copy_node, e->dst_input());
        } else {
          g->AddEdge(node_images[e->src()][replica_id], e->src_output(),
                     copy_node, e->dst_input());
        }
      }

      // Add control edge between `copy_node` and `xla_node`, so these outside
      // compilation nodes will be executed before XLA computation happens.
      g->AddControlEdge(copy_node, xla_node);

      // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
      // belongs to same while loop as `xla_node`.
      if (pivot_node) {
        g->AddControlEdge(pivot_node, copy_node);
      }

      node_images[n].push_back(copy_node);
    }
  }

  // Record output edges from `oc_nodes_at_head`. We will create an _Arg node
  // for each of these edges. An obvious optimization here is to deduplicate
  // these edges by <src, src_output>. But that optimization will complicate
  // the code, and in practice we usually do not have output edges with the
  // same <src, src_output>.
  std::vector<const Edge*> oc_output_edges;
  std::vector<DataType> new_arg_types;
  for (Node* n : oc_nodes_at_head) {
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() &&
          node_images.find(e->dst()) == node_images.end()) {
        VLOG(5) << "oc_output_edges: " << e->DebugString();
        oc_output_edges.push_back(e);
        new_arg_types.push_back(e->src()->output_type(e->src_output()));
      }
    }
  }
  int new_num_per_replica_inputs =
      old_num_per_replica_inputs + oc_output_edges.size();
  VLOG(5) << "new_num_per_replica_inputs: " << new_num_per_replica_inputs;

  // Process input edges for XLA node.
  int num_variables;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "NumVariables", &num_variables));
  std::vector<DataType> broadcast_input_types, guaranteed_constant_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tbroadcast_inputs",
                                 &broadcast_input_types));
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tguaranteed_constants",
                                 &guaranteed_constant_types));
  int num_other_inputs = num_distributed_vars + num_variables +
                         broadcast_input_types.size() +
                         guaranteed_constant_types.size();
  VLOG(5) << "num_other_inputs: " << num_other_inputs;

  // Update `Tinputs` attribute for `xla_node`.
  std::vector<DataType> new_input_types;
  // Order of new_input_types: old per-replica inputs -> new per-replica inputs
  // -> distributed variables
  new_input_types.reserve(num_replicas * new_num_per_replica_inputs +
                          num_distributed_vars);
  for (int replica_id = 0; replica_id < num_replicas; ++replica_id) {
    for (int i = 0; i < old_num_per_replica_inputs; ++i) {
      new_input_types.push_back(input_types[i]);
    }
    for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
         ++i) {
      new_input_types.push_back(new_arg_types[i - old_num_per_replica_inputs]);
    }
  }
  const int num_new_per_replica_input_types = new_input_types.size();
  // NOTE(review): `i` is int while input_types.size() is size_t — a
  // signed/unsigned comparison. Harmless for realistic input counts, but
  // worth cleaning up.
  for (int i = input_types.size() - num_distributed_vars;
       i < input_types.size(); i++) {
    new_input_types.push_back(input_types[i]);
  }
  xla_node->ClearAttr("Tinputs");
  xla_node->AddAttr("Tinputs", new_input_types);

  TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices(
      /*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node));

  int new_variable_start_index =
      num_new_per_replica_input_types / num_replicas + num_distributed_vars +
      broadcast_input_types.size();
  if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
    xla_node->ClearAttr("_variable_start_index");
    xla_node->AddAttr("_variable_start_index", new_variable_start_index);
  }
  int new_guaranteed_const_start_index =
      new_variable_start_index + num_variables;
  if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
    xla_node->ClearAttr("_guaranteed_const_start_index");
    xla_node->AddAttr("_guaranteed_const_start_index",
                      new_guaranteed_const_start_index);
  }

  // Move non per-replica input edges. These keep their relative order but
  // shift to higher input indices to make room for the new per-replica args.
  std::vector<const Edge*> new_input_edges(
      num_replicas * new_num_per_replica_inputs + num_other_inputs);
  int end_input_index =
      num_replicas * new_num_per_replica_inputs + num_other_inputs - 1;
  int start_input_index = end_input_index + 1 - num_other_inputs;
  for (int input_index = end_input_index; input_index >= start_input_index;
       input_index--) {
    const Edge* e =
        input_edges.at(input_index - num_replicas * new_arg_types.size());
    Node* src = e->src();
    int src_output = e->src_output();
    g->RemoveEdge(e);
    const Edge* new_input_edge =
        g->AddEdge(src, src_output, xla_node, input_index);
    new_input_edges[input_index] = new_input_edge;
  }

  // Re-order old per-replica inputs edges, and add new per-replica input edges.
  std::vector<std::pair<Node*, int>> per_replica_inputs;
  std::vector<const Edge*> old_per_replica_edges;
  for (int i = 0; i < old_num_per_replica_inputs * num_replicas; i++) {
    const Edge* e = input_edges.at(i);
    per_replica_inputs.push_back(std::make_pair(e->src(), e->src_output()));
    old_per_replica_edges.push_back(e);
  }
  for (const Edge* e : old_per_replica_edges) {
    g->RemoveEdge(e);
  }
  for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
    for (int input_index = 0; input_index < old_num_per_replica_inputs;
         input_index++) {
      Node* src = per_replica_inputs[replica_id * old_num_per_replica_inputs +
                                     input_index]
                      .first;
      int src_output =
          per_replica_inputs[replica_id * old_num_per_replica_inputs +
                             input_index]
              .second;
      const Edge* new_input_edge =
          g->AddEdge(src, src_output, xla_node,
                     replica_id * new_num_per_replica_inputs + input_index);
      // NOTE(review): this stores at slot `input_index` for every replica
      // (no `replica_id * new_num_per_replica_inputs` offset), so replicas
      // overwrite each other. The per-replica slots of `new_input_edges` are
      // never read below (only the tail, non-per-replica slots are), so this
      // appears harmless — but confirm whether the offset was intended.
      new_input_edges[input_index] = new_input_edge;
    }
    for (int input_index = old_num_per_replica_inputs;
         input_index < new_num_per_replica_inputs; input_index++) {
      Node* original_src =
          oc_output_edges[input_index - old_num_per_replica_inputs]->src();
      int original_src_output =
          oc_output_edges[input_index - old_num_per_replica_inputs]
              ->src_output();
      Node* src = node_images[original_src][replica_id];
      const Edge* new_input_edge =
          g->AddEdge(src, original_src_output, xla_node,
                     replica_id * new_num_per_replica_inputs + input_index);
      // NOTE(review): same replica-offset question as the store above.
      new_input_edges[input_index] = new_input_edge;
    }
  }

  // Adjust original _Arg nodes in `xla_graph`: non-per-replica args shift up
  // by the number of newly inserted per-replica args.
  for (Node* n : xla_graph->nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      if (index >= old_num_per_replica_inputs) {
        index += new_arg_types.size();
        n->ClearAttr("index");
        n->AddAttr("index", index);
      }
    }
  }

  // Create new _Arg nodes in `xla_graph`, one per recorded output edge of the
  // moved head nodes, and splice each in place of that edge.
  for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
       i++) {
    NodeDefBuilder arg_builder(absl::StrCat("arg_", i),
                               FunctionLibraryDefinition::kArgOp);
    arg_builder.Attr("T", new_arg_types[i - old_num_per_replica_inputs]);
    arg_builder.Attr("index", i);
    NodeDef arg_def;
    TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
    TF_ASSIGN_OR_RETURN(Node * arg_node, xla_graph->AddNode(arg_def));
    const Edge* original_edge = oc_output_edges[i - old_num_per_replica_inputs];
    Node* dst = original_edge->dst();
    int dst_input = original_edge->dst_input();
    xla_graph->RemoveEdge(original_edge);
    xla_graph->AddEdge(arg_node, 0, dst, dst_input);
  }

  // For lifted arg nodes:
  // 1. Add a Placeholder node in `xla_graph`. When we build host side graph
  //    in ExtractOutsideCompilationPass, we will use this new Placeholder node
  //    instead of lifted arg node here.
  // 2. Add an IdentityN node in `g` to indicate its inputs. We will reconnect
  //    this IdentityN node and this lifted arg node's usage nodes in
  //    DistributedTPURewritePass.
  for (Node* n : oc_nodes_at_head) {
    bool is_lifted_arg;
    string outside_compilation_attr;
    if (!TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) ||
        !TryGetNodeAttr(n->def(), kOutsideCompilationAttr,
                        &outside_compilation_attr)) {
      continue;
    }

    TF_RET_CHECK(n->IsIdentity());
    NodeDefBuilder ph_builder(absl::StrCat("placeholder_", n->name()),
                              "Placeholder");
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    ph_builder.Attr("dtype", dtype);
    ph_builder.Attr(kXlaIsLiftedArgAttrName, true);
    ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
    NodeDef ph_def;
    TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
    Status s;
    xla_graph->AddNode(ph_def, &s);
    TF_RETURN_IF_ERROR(s);

    Node* input_node;
    TF_RETURN_IF_ERROR(n->input_node(0, &input_node));
    TF_RET_CHECK(input_node->type_string() == "_Arg");
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
    // TODO(b/74023706): for now we only support resource input (e.g. summary
    // writer), which is non-replicated input. Support replicated input as
    // well.
    TF_RET_CHECK(index >= new_num_per_replica_inputs + num_distributed_vars);
    const Edge* input_edge =
        new_input_edges.at(num_replicas * new_num_per_replica_inputs + index -
                           new_num_per_replica_inputs);
    NodeDefBuilder id_builder(absl::StrCat("lifted_arg_input_", index),
                              "IdentityN");
    DataType input_dtype =
        input_edge->src()->output_type(input_edge->src_output());
    id_builder.Attr("T", std::vector<DataType>(num_replicas, input_dtype));
    std::vector<NodeDefBuilder::NodeOut> inputs(
        num_replicas,
        NodeDefBuilder::NodeOut{input_edge->src()->name(),
                                input_edge->src_output(), input_dtype});
    id_builder.Attr(kXlaOutsideCompilationInputsAttrName,
                    outside_compilation_attr);
    id_builder.Input(inputs);
    NodeDef id_def;
    TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
    TF_ASSIGN_OR_RETURN(Node * id_node, g->AddNode(id_def));
    for (int i = 0; i < num_replicas; i++) {
      g->AddEdge(input_edge->src(), input_edge->src_output(), id_node, i);
    }
  }

  // Remove `oc_nodes_at_head`.
  for (Node* n : oc_nodes_at_head) {
    xla_graph->RemoveNode(n);
  }

  VLOG(4) << "MoveHeadOutsideCompilationToHost host graph: "
          << DumpGraphToFile(absl::StrCat("move_head_oc_host_", xla_func_name),
                             *g);
  VLOG(4) << "MoveHeadOutsideCompilationToHost XLA graph: "
          << DumpGraphToFile(absl::StrCat("move_head_oc_xla_", xla_func_name),
                             *xla_graph);

  return OkStatus();
}
726 
// If there are any unused _Arg nodes in `xla_graph`, remove them from
// `xla_graph` and remove corresponding input edge in host graph `g`.
// Surviving _Arg nodes are re-indexed to be contiguous, the host-side input
// edges of `xla_node` are re-wired to the compacted positions, and every
// index-dependent attribute on `xla_node` ("Tinputs",
// "num_distributed_variables", mirrored-variable indices,
// "Tbroadcast_inputs", "NumVariables", "Tguaranteed_constants",
// "_variable_start_index", "_guaranteed_const_start_index") is rewritten to
// stay consistent with the new layout.
Status RemoveUnusedXlaInput(const string& xla_func_name, Graph* g,
                            Graph* xla_graph, Node* xla_node) {
  // Find unused _Arg nodes, and remove them.
  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
  std::vector<int> mirrored_variable_indices;
  // The mirrored-variable attribute is optional; only read it when present.
  if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
      nullptr) {
    TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
                                   TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                                   &mirrored_variable_indices));
  }
  std::vector<DataType> broadcast_input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tbroadcast_inputs",
                                 &broadcast_input_types));
  std::vector<DataType> guaranteed_constant_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tguaranteed_constants",
                                 &guaranteed_constant_types));
  int num_variables;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "NumVariables", &num_variables));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  // "Tinputs" holds `num_replicas` copies of the per-replica inputs followed
  // by the distributed variables; recover the per-replica count from that.
  int num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;
  std::set<int> arg_indices_to_remove;
  std::vector<Node*> arg_nodes_to_update, nodes_to_remove;
  int num_args = 0, num_removed_per_replica_inputs = 0,
      num_removed_distributed_vars = 0;
  for (Node* n : xla_graph->nodes()) {
    if (!n->IsArg()) {
      continue;
    }

    // An _Arg counts as used if it has any out edge other than the implicit
    // edge to the sink node.
    bool has_output = false;
    for (const Edge* e : n->out_edges()) {
      if (e->dst() != xla_graph->sink_node()) {
        has_output = true;
        break;
      }
    }

    num_args++;
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    if (has_output) {
      arg_nodes_to_update.push_back(n);
      continue;
    }

    // Unused _Arg: remember its index and which input category it belongs to
    // (per-replica input vs. distributed variable) for the attribute fixups
    // below.
    arg_indices_to_remove.insert(index);
    if (index < num_per_replica_inputs) {
      num_removed_per_replica_inputs++;
    } else if (index < num_per_replica_inputs + num_distributed_vars) {
      num_removed_distributed_vars++;
    }
    nodes_to_remove.push_back(n);
  }
  for (Node* n : nodes_to_remove) {
    xla_graph->RemoveNode(n);
  }

  // Update `index` for other _Arg nodes.
  // Build an old-index -> new-index mapping that compacts out removed args.
  std::map<int, int> arg_index_mapping;
  int new_arg_index = 0;
  for (int i = 0; i < num_args; i++) {
    if (arg_indices_to_remove.find(i) != arg_indices_to_remove.end()) {
      continue;
    } else {
      arg_index_mapping[i] = new_arg_index;
      new_arg_index++;
    }
  }
  for (Node* n : arg_nodes_to_update) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    n->ClearAttr("index");
    n->AddAttr("index", arg_index_mapping[index]);
  }

  // Re-order replicated index edges for `xla_node`.

  // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
  // will become very expensive in this case because it is doing a linear search
  // inside. Create a input_edges vector ahead to make the lookups faster.
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

  const int num_new_per_replica_inputs =
      num_per_replica_inputs - num_removed_per_replica_inputs;
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_per_replica_inputs; j++) {
      auto iter = arg_index_mapping.find(j);
      if (iter != arg_index_mapping.end()) {
        // Kept input: move its host edge to the compacted destination slot.
        const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
        Node* src = e->src();
        int src_output = e->src_output();
        int dst_input = i * num_new_per_replica_inputs + iter->second;

        g->RemoveEdge(e);
        g->AddEdge(src, src_output, xla_node, dst_input);
      } else {
        // Removed input: just drop the host-side edge.
        const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
        g->RemoveEdge(e);
      }
    }
  }

  // Move other data input edges.
  for (int i = num_replicas * num_per_replica_inputs;
       i < xla_node->num_inputs(); i++) {
    // Convert the host input position back to the _Arg index it feeds:
    // non-replicated inputs appear exactly once after all replicated copies.
    int arg_index =
        num_per_replica_inputs + i - num_replicas * num_per_replica_inputs;
    auto iter = arg_index_mapping.find(arg_index);
    if (iter != arg_index_mapping.end()) {
      const Edge* e = input_edges.at(i);
      Node* src = e->src();
      int src_output = e->src_output();
      int dst_input = num_replicas * num_new_per_replica_inputs + iter->second -
                      num_new_per_replica_inputs;

      g->RemoveEdge(e);
      g->AddEdge(src, src_output, xla_node, dst_input);
    } else {
      const Edge* e = input_edges.at(i);
      g->RemoveEdge(e);
    }
  }

  // Update attributes for `xla_node`.
  std::vector<DataType> new_input_types;
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_per_replica_inputs; j++) {
      auto iter = arg_index_mapping.find(j);
      if (iter != arg_index_mapping.end()) {
        // Per-replica input types repeat per replica, so the entry for the
        // first replica (position iter->first == j) is representative.
        new_input_types.push_back(input_types[iter->first]);
      }
    }
  }
  for (int i = 0; i < num_distributed_vars; ++i) {
    auto iter = arg_index_mapping.find(i + num_per_replica_inputs);
    if (iter != arg_index_mapping.end()) {
      // Distributed variable types live after all replicated copies in the
      // original "Tinputs" list.
      new_input_types.push_back(
          input_types[iter->first - num_per_replica_inputs +
                      num_per_replica_inputs * num_replicas]);
    }
  }
  xla_node->ClearAttr("Tinputs");
  xla_node->AddAttr("Tinputs", new_input_types);

  const int num_new_distributed_vars =
      num_distributed_vars - num_removed_distributed_vars;
  xla_node->ClearAttr("num_distributed_variables");
  xla_node->AddAttr("num_distributed_variables", num_new_distributed_vars);

  if (!mirrored_variable_indices.empty()) {
    // Re-map mirrored-variable indices that survived removal onto the new
    // compacted index space.
    std::vector<int> new_mirrored_variable_indices;
    absl::flat_hash_set<int> old_mirrored_variable_indices_set;
    for (int index : mirrored_variable_indices) {
      old_mirrored_variable_indices_set.insert(index);
    }
    for (int i = 0; i < num_per_replica_inputs + num_distributed_vars; i++) {
      auto iter = arg_index_mapping.find(i);
      if (iter != arg_index_mapping.end() &&
          old_mirrored_variable_indices_set.contains(iter->first)) {
        new_mirrored_variable_indices.push_back(iter->second);
      }
    }
    xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
    xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                      new_mirrored_variable_indices);
  }

  // The _Arg index space is laid out as [replicated inputs | broadcast
  // inputs | variables | guaranteed constants]; filter each tail section by
  // whether the corresponding arg survived.
  int num_replicated_inputs = num_per_replica_inputs + num_distributed_vars;
  std::vector<DataType> new_broadcast_input_types;
  for (int i = 0; i < broadcast_input_types.size(); i++) {
    int arg_index = num_replicated_inputs + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_broadcast_input_types.push_back(broadcast_input_types[i]);
    }
  }
  xla_node->ClearAttr("Tbroadcast_inputs");
  xla_node->AddAttr("Tbroadcast_inputs", new_broadcast_input_types);
  int new_num_variables = 0;
  for (int i = 0; i < num_variables; i++) {
    int arg_index = num_replicated_inputs + broadcast_input_types.size() + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_num_variables++;
    }
  }
  xla_node->ClearAttr("NumVariables");
  xla_node->AddAttr("NumVariables", new_num_variables);
  std::vector<DataType> new_guaranteed_constant_types;
  for (int i = 0; i < guaranteed_constant_types.size(); i++) {
    int arg_index = num_replicated_inputs + broadcast_input_types.size() +
                    num_variables + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_guaranteed_constant_types.push_back(guaranteed_constant_types[i]);
    }
  }
  xla_node->ClearAttr("Tguaranteed_constants");
  xla_node->AddAttr("Tguaranteed_constants", new_guaranteed_constant_types);

  // Keep the optional start-index attributes in sync with the new layout.
  int new_variable_start_index = num_new_per_replica_inputs +
                                 num_new_distributed_vars +
                                 new_broadcast_input_types.size();
  if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
    xla_node->ClearAttr("_variable_start_index");
    xla_node->AddAttr("_variable_start_index", new_variable_start_index);
  }
  int new_guaranteed_const_start_index =
      new_variable_start_index + new_num_variables;
  if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
    xla_node->ClearAttr("_guaranteed_const_start_index");
    xla_node->AddAttr("_guaranteed_const_start_index",
                      new_guaranteed_const_start_index);
  }

  VLOG(4) << "RemoveUnusedXlaInput host graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_input_host_", xla_func_name), *g);
  VLOG(4) << "RemoveUnusedXlaInput XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_input_xla_", xla_func_name),
                 *xla_graph);

  return OkStatus();
}
961 
962 // Move outside compilation nodes at the end of XLA computation to host.
963 // For XLA computation graph, we will add new _Retval nodes to replace those
964 // outside compilation nodes.
965 // For host graph, we will move those outside compilation nodes to host,
966 // replicate them, and use them as XLA node's output.
MoveTailOutsideCompilationToHost(const string & outside_compilation_attr_name,const string & xla_func_name,const std::string & cluster_name,Graph * g,Graph * xla_graph,Node * xla_node,Node * pivot_node)967 Status MoveTailOutsideCompilationToHost(
968     const string& outside_compilation_attr_name, const string& xla_func_name,
969     const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
970     Node* pivot_node) {
971   // Find outside compilation nodes that only have _Retval or other outside
972   // compilation nodes as output. These nodes will be moved to host graph.
973   std::vector<Node*> oc_nodes_at_tail;
974   const string kOnlyRetOrOcOutputAttrName = "_xla_only_ret_or_oc_output";
975   DFS(
976       *xla_graph, /*enter=*/nullptr,
977       [&](Node* n) {
978         bool has_non_ret_or_oc_output = false;
979         for (const Edge* e : n->out_edges()) {
980           if (e->dst() == xla_graph->sink_node()) {
981             continue;
982           }
983           if (!e->dst()->IsRetval() &&
984               (!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) ||
985                !HasNodeAttr(e->dst()->def(), kOnlyRetOrOcOutputAttrName))) {
986             has_non_ret_or_oc_output = true;
987             break;
988           }
989         }
990         if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
991             !has_non_ret_or_oc_output) {
992           n->AddAttr(kOnlyRetOrOcOutputAttrName, true);
993           oc_nodes_at_tail.push_back(n);
994         }
995       },
996       NodeComparatorName());
997   if (VLOG_IS_ON(5)) {
998     for (Node* n : oc_nodes_at_tail) {
999       VLOG(5) << "oc_nodes_at_tail: " << n->DebugString();
1000     }
1001   }
1002 
1003   // Record input edges from `oc_nodes_at_tail`. We will create an _Retval node
1004   // for each of these edges. An obvious optimization here is to deduplicate
1005   // these edges by <src, src_output>. But that optimization will complicate
1006   // the code, and in practice we usually do not have input edges with the
1007   // same <src, src_output>.
1008   std::vector<const Edge*> oc_input_edges;
1009   std::vector<DataType> new_ret_types;
1010   for (Node* n : oc_nodes_at_tail) {
1011     for (const Edge* e : n->in_edges()) {
1012       if (!e->IsControlEdge() &&
1013           !HasNodeAttr(e->src()->def(), kOnlyRetOrOcOutputAttrName)) {
1014         VLOG(5) << "oc_input_edges: " << e->DebugString();
1015         oc_input_edges.push_back(e);
1016         new_ret_types.push_back(e->src()->output_type(e->src_output()));
1017       }
1018     }
1019   }
1020   std::vector<DataType> output_types;
1021   TF_RETURN_IF_ERROR(
1022       GetNodeAttr(xla_node->attrs(), "output_types", &output_types));
1023   int num_replicas;
1024   TF_RETURN_IF_ERROR(
1025       GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
1026   int old_num_replicated_outputs = output_types.size() / num_replicas;
1027   int new_num_replicated_outputs =
1028       old_num_replicated_outputs + oc_input_edges.size();
1029   VLOG(5) << "old_num_replicated_outputs: " << old_num_replicated_outputs;
1030   VLOG(5) << "new_num_replicated_outputs: " << new_num_replicated_outputs;
1031 
1032   // Update `output_types` attribute for `xla_node`.
1033   std::vector<DataType> new_output_types;
1034   for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1035     for (int i = 0; i < old_num_replicated_outputs; i++) {
1036       new_output_types.push_back(output_types[i]);
1037     }
1038     for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
1039          i++) {
1040       new_output_types.push_back(new_ret_types[i - old_num_replicated_outputs]);
1041     }
1042   }
1043   xla_node->ClearAttr("output_types");
1044   xla_node->AddAttr("output_types", new_output_types);
1045 
1046   // Re-order old replicated output edges. Since a node could potentially
1047   // connect to multiple nodes, build a vector<vector<pair>> mapping of
1048   // output index to input nodes/index.
1049   // The outer vector represents the output index, the inner vector
1050   // represents the destination node and input index pair with the possibility
1051   // of multiple node/index pairs.
1052   std::vector<std::vector<std::pair<Node*, int>>> replicated_outputs(
1053       old_num_replicated_outputs * num_replicas);
1054   std::vector<const Edge*> old_replicated_edges;
1055   for (const Edge* e : xla_node->out_edges()) {
1056     if (e->src_output() >= 0 &&
1057         e->src_output() < old_num_replicated_outputs * num_replicas) {
1058       replicated_outputs[e->src_output()].push_back(
1059           std::make_pair(e->dst(), e->dst_input()));
1060       old_replicated_edges.push_back(e);
1061     }
1062   }
1063   for (const Edge* e : old_replicated_edges) {
1064     g->RemoveEdge(e);
1065   }
1066   for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1067     for (int output_index = 0; output_index < old_num_replicated_outputs;
1068          output_index++) {
1069       for (const auto& node_input_pair :
1070            replicated_outputs[replica_id * old_num_replicated_outputs +
1071                               output_index]) {
1072         Node* dst = node_input_pair.first;
1073         int dst_input = node_input_pair.second;
1074         g->AddEdge(xla_node,
1075                    replica_id * new_num_replicated_outputs + output_index, dst,
1076                    dst_input);
1077       }
1078     }
1079   }
1080 
1081   // Copy all nodes in `oc_nodes_at_tail` to host graph, and also replicate
1082   // them.
1083   std::map<Node*, std::vector<Node*>> node_images;
1084   for (Node* n : oc_nodes_at_tail) {
1085     for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1086       NodeDef copy_def = n->def();
1087       copy_def.set_name(absl::StrCat(n->name(), "_tail_oc/R", replica_id));
1088       copy_def.clear_device();
1089 
1090       TF_ASSIGN_OR_RETURN(Node * copy_node, g->AddNode(copy_def));
1091 
1092       copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
1093       copy_node->AddAttr(kTPUReplicateAttr, cluster_name);
1094 
1095       for (const Edge* e : n->out_edges()) {
1096         if (e->dst() == xla_graph->sink_node()) {
1097           continue;
1098         }
1099         // Either e->dst() is _Retval, or it's in `node_images`.
1100         if (e->dst()->IsRetval()) {
1101           int index;
1102           TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->attrs(), "index", &index));
1103           for (const auto& output :
1104                replicated_outputs[replica_id * old_num_replicated_outputs +
1105                                   index]) {
1106             // Remove original input edge, if existent.
1107             const Edge* original_edge;
1108             Status s = output.first->input_edge(output.second, &original_edge);
1109             if (s.ok()) {
1110               g->RemoveEdge(original_edge);
1111             }
1112             g->AddEdge(copy_node, e->src_output(), output.first, output.second);
1113           }
1114         } else {
1115           g->AddEdge(copy_node, e->src_output(),
1116                      node_images[e->dst()][replica_id], e->dst_input());
1117         }
1118       }
1119 
1120       // Add attribute "_xla_tail_outside_compilation" to `copy_node`, and add a
1121       // control edge between `xla_node` and `copy_node`. As a result, in later
1122       // rewriting pass, a control edge will be added between `copy_node` and
1123       // "control_after" node for the XLA computation, so `copy_node` will be
1124       // executed before XLA computation's final results.
1125       copy_node->AddAttr("_xla_tail_outside_compilation", true);
1126       g->AddControlEdge(xla_node, copy_node);
1127 
1128       // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
1129       // belongs to same while loop as `xla_node`.
1130       if (pivot_node) {
1131         g->AddControlEdge(pivot_node, copy_node);
1132       }
1133 
1134       node_images[n].push_back(copy_node);
1135     }
1136   }
1137 
1138   // Connect new output values of `xla_node` to dst nodes of `oc_input_edges`.
1139   for (int i = 0; i < new_ret_types.size(); i++) {
1140     const Edge* original_edge = oc_input_edges[i];
1141     for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1142       int src_output = replica_id * new_num_replicated_outputs +
1143                        old_num_replicated_outputs + i;
1144       Node* dst = node_images[original_edge->dst()][replica_id];
1145       g->AddEdge(xla_node, src_output, dst, original_edge->dst_input());
1146     }
1147   }
1148 
1149   // Create new _Retval nodes in `xla_graph`.
1150   for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
1151        i++) {
1152     NodeDefBuilder ret_builder(absl::StrCat("ret_", i),
1153                                FunctionLibraryDefinition::kRetOp);
1154     ret_builder.Attr("T", new_ret_types[i - old_num_replicated_outputs]);
1155     ret_builder.Attr("index", i);
1156     const Edge* original_edge = oc_input_edges[i - old_num_replicated_outputs];
1157     Node* src = original_edge->src();
1158     int src_output = original_edge->src_output();
1159     ret_builder.Input(src->name(), src_output, src->output_type(src_output));
1160     NodeDef ret_def;
1161     TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1162     TF_ASSIGN_OR_RETURN(Node * ret_node, xla_graph->AddNode(ret_def));
1163     xla_graph->RemoveEdge(original_edge);
1164     xla_graph->AddEdge(src, src_output, ret_node, 0);
1165   }
1166 
1167   // Remove `oc_nodes_at_tail`.
1168   for (Node* n : oc_nodes_at_tail) {
1169     xla_graph->RemoveNode(n);
1170   }
1171 
1172   // We cannot leave _Retval with no input. Add a placeholder input, which will
1173   // be removed later with unused _Retval.
1174   std::vector<Node*> unused_rets;
1175   for (Node* n : xla_graph->nodes()) {
1176     if (n->IsRetval() && n->in_edges().empty()) {
1177       unused_rets.push_back(n);
1178     }
1179   }
1180   for (Node* n : unused_rets) {
1181     NodeDefBuilder builder(absl::StrCat("placeholder_", n->name()),
1182                            "Placeholder");
1183     DataType dtype;
1184     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
1185     builder.Attr("dtype", dtype);
1186     builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
1187     NodeDef def;
1188     TF_RETURN_IF_ERROR(builder.Finalize(&def));
1189     TF_ASSIGN_OR_RETURN(Node * placeholder, xla_graph->AddNode(def));
1190     xla_graph->AddEdge(placeholder, 0, n, 0);
1191   }
1192 
1193   VLOG(4) << "MoveTailOutsideCompilationToHost host graph: "
1194           << DumpGraphToFile(absl::StrCat("move_tail_oc_host_", xla_func_name),
1195                              *g);
1196   VLOG(4) << "MoveTaildOutsideCompilationToHost XLA graph: "
1197           << DumpGraphToFile(absl::StrCat("move_tail_oc_xla_", xla_func_name),
1198                              *xla_graph);
1199 
1200   return OkStatus();
1201 }
1202 
// For each resource _Arg node in `xla_graph` that feeds an outside
// compilation node: record its host-side input(s) via a replicated IdentityN
// node in host graph `g`, and replace each _Arg -> outside-compilation edge
// in `xla_graph` with a fresh Placeholder node. The Placeholder carries the
// same `oc_identifier` attribute as the IdentityN so a later pass can match
// the two up (NOTE(review): the actual reconnection happens outside this
// function — confirm against the extract-outside-compilation pass).
Status ReplaceArgUsedByOutsideCompilationWithPlaceholder(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    Graph* g, Graph* xla_graph, Node* xla_node) {
  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  // "Tinputs" = num_replicas copies of per-replica inputs + distributed vars.
  int num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;

  for (Node* n : xla_graph->op_nodes()) {
    if (!n->IsArg()) {
      continue;
    }

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    // TODO(b/74023706): enable moving normal data tensors.
    if (dtype != DT_RESOURCE) {
      continue;
    }

    // Collect data edges from this _Arg to outside compilation nodes.
    std::vector<const Edge*> oc_out_edges;
    for (const Edge* e : n->out_edges()) {
      if (e->IsControlEdge() ||
          !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
        continue;
      }

      oc_out_edges.push_back(e);
    }
    if (oc_out_edges.empty()) {
      continue;
    }

    // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
    // will become very expensive in this case because it is doing a linear
    // search inside. Create an input_edges vector ahead to make the lookups
    // faster.
    std::vector<const Edge*> input_edges;
    TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

    // Build an IdentityN node to record inputs for this _Arg node.
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    string oc_identifier = absl::StrCat("oc_only_arg_", index);
    NodeDefBuilder id_builder(absl::StrCat(oc_identifier, "_inputs"),
                              "IdentityN");
    std::vector<DataType> dtypes(num_replicas, dtype);
    id_builder.Attr("T", dtypes);
    id_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
    std::vector<NodeDefBuilder::NodeOut> inputs(num_replicas);
    if (index >= num_per_replica_inputs) {
      // Non-replicated input (distributed variable): one host edge is shared
      // by all replicas.
      const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
                                     (index - num_per_replica_inputs));
      for (int i = 0; i < num_replicas; i++) {
        inputs[i] =
            NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
                                    e->src()->output_type(e->src_output())};
      }
    } else {
      // Per-replica input: each replica has its own host edge.
      for (int i = 0; i < num_replicas; i++) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
        inputs[i] =
            NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
                                    e->src()->output_type(e->src_output())};
      }
    }
    id_builder.Input(inputs);
    NodeDef id_def;
    TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
    TF_ASSIGN_OR_RETURN(Node * id_node, g->AddNode(id_def));
    // Wire actual graph edges to mirror the NodeDef inputs recorded above.
    if (index >= num_per_replica_inputs) {
      const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
                                     (index - num_per_replica_inputs));
      for (int i = 0; i < num_replicas; i++) {
        g->AddEdge(e->src(), e->src_output(), id_node, i);
      }
    } else {
      for (int i = 0; i < num_replicas; i++) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
        g->AddEdge(e->src(), e->src_output(), id_node, i);
      }
    }

    for (const Edge* e : oc_out_edges) {
      // 'e' will use a new Placeholder node as input.
      NodeDefBuilder ph_builder(xla_graph->NewName("ph_for_arg_in_oc_"),
                                "Placeholder");
      ph_builder.Attr("dtype", dtype);

      // Copy the destination's outside compilation cluster name onto the
      // placeholder so it stays in the same cluster.
      string outside_compilation_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
                                     &outside_compilation_attr));
      ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
      ph_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
      ph_builder.Attr(kXlaIsPlaceholderForArg, true);
      NodeDef ph_def;
      TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
      TF_ASSIGN_OR_RETURN(Node * ph_node, xla_graph->AddNode(ph_def));
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      xla_graph->RemoveEdge(e);
      xla_graph->AddEdge(ph_node, 0, dst, dst_input);
      xla_graph->AddControlEdge(xla_graph->source_node(), ph_node);
    }
  }
  VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder host graph: "
          << DumpGraphToFile(
                 absl::StrCat("replace_oc_only_arg_host_", xla_func_name), *g);
  VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("replace_oc_only_arg_xla_", xla_func_name),
                 *xla_graph);
  return OkStatus();
}
1323 
// If there are any unused _Retval nodes in `xla_graph` (whose input is a
// Placeholder node), remove them from `xla_graph` and remove corresponding
// output edge in host graph `g`. Surviving _Retval nodes are re-indexed to
// be contiguous and `xla_node`'s "output_types" attribute and output edges
// are compacted to match.
Status RemoveUnusedXlaOutput(const string& xla_func_name, Graph* g,
                             Graph* xla_graph, Node* xla_node) {
  // Find unused _Retval nodes, and remove them.
  std::vector<DataType> output_types;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "output_types", &output_types));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
  int num_replicated_outputs = output_types.size() / num_replicas;
  std::set<int> ret_indices_to_remove;
  std::vector<Node*> ret_nodes_to_update, nodes_to_remove;
  int num_rets = 0;
  for (Node* n : xla_graph->nodes()) {
    if (!n->IsRetval()) {
      continue;
    }

    num_rets++;

    // A _Retval is unused iff its input is a dummy Placeholder tagged with
    // kXlaIsPlaceholderForTailOcAttrName (inserted by the tail-OC move).
    const Edge* e;
    TF_RETURN_IF_ERROR(n->input_edge(0, &e));
    if (e->src()->type_string() != "Placeholder" ||
        !HasNodeAttr(e->src()->def(), kXlaIsPlaceholderForTailOcAttrName)) {
      ret_nodes_to_update.push_back(n);
      continue;
    }

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    ret_indices_to_remove.insert(index);
    // Remove the dummy Placeholder together with the _Retval it feeds.
    nodes_to_remove.push_back(e->src());
    nodes_to_remove.push_back(n);
  }
  for (Node* n : nodes_to_remove) {
    xla_graph->RemoveNode(n);
  }

  // Update `index` for other _Retval nodes.
  // Build an old-index -> new-index mapping that compacts out removed rets.
  std::map<int, int> ret_index_mapping;
  int new_ret_index = 0;
  for (int i = 0; i < num_rets; i++) {
    if (ret_indices_to_remove.find(i) != ret_indices_to_remove.end()) {
      continue;
    } else {
      ret_index_mapping[i] = new_ret_index;
      new_ret_index++;
    }
  }
  for (Node* n : ret_nodes_to_update) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    n->ClearAttr("index");
    n->AddAttr("index", ret_index_mapping[index]);
  }

  // Update `output_types` attribute for `xla_node`: keep only the surviving
  // per-replica output types, repeated once per replica (std::map iteration
  // is ordered, so the original relative order is preserved).
  std::vector<DataType> new_output_types;
  for (int i = 0; i < num_replicas; i++) {
    for (const auto& e : ret_index_mapping) {
      new_output_types.push_back(output_types[e.first]);
    }
  }

  xla_node->ClearAttr("output_types");
  xla_node->AddAttr("output_types", new_output_types);

  // Re-order replicated output edges for `xla_node`.
  std::vector<std::vector<const Edge*>> output_edges(num_replicas *
                                                     num_replicated_outputs);
  for (const Edge* e : xla_node->out_edges()) {
    if (e->src_output() >= 0 &&
        e->src_output() < num_replicas * num_replicated_outputs) {
      output_edges[e->src_output()].push_back(e);
    }
  }
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_replicated_outputs; j++) {
      auto iter = ret_index_mapping.find(j);
      if (iter != ret_index_mapping.end()) {
        // Kept output: move its consumers to the compacted source slot.
        for (const Edge* e : output_edges[i * num_replicated_outputs + j]) {
          Node* dst = e->dst();
          int dst_input = e->dst_input();
          int src_output =
              i * (num_replicated_outputs - ret_indices_to_remove.size()) +
              iter->second;
          g->RemoveEdge(e);
          g->AddEdge(xla_node, src_output, dst, dst_input);
        }
      } else {
        // Removed output: a surviving consumer edge means an earlier rewrite
        // step failed to disconnect it.
        TF_RET_CHECK(output_edges[i * num_replicated_outputs + j].empty())
            << "Output edge not removed: "
            << output_edges[i * num_replicated_outputs + j][0]->DebugString();
      }
    }
  }

  VLOG(4) << "RemoveUnusedXlaOutput host graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_output_host_", xla_func_name), *g);
  VLOG(4) << "RemoveUnusedXlaOutput XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_output_xla_", xla_func_name),
                 *xla_graph);

  return OkStatus();
}
1434 
// For data edges between _Arg and _Retval in `xla_graph`, remove them and
// change input/output edges in `g` (host graph). For now, we only consider
// replicated inputs.
Status RemoveEdgesBetweenArgAndRetval(const string& xla_func_name, Graph* g,
                                      Graph* xla_graph, Node* xla_node) {
  // Collect data edges between _Arg and _Retval.
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  // "Tinputs" lists per-replica inputs repeated `num_replicas` times, followed
  // by `num_distributed_vars` distributed-variable inputs that appear once.
  int old_num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;
  std::vector<DataType> output_types;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "output_types", &output_types));
  // "output_types" is the per-replica output list repeated `num_replicas`
  // times.
  int old_num_outputs = output_types.size() / num_replicas;
  // Pass-through pairs: a data edge directly from an _Arg to a _Retval inside
  // the XLA computation.
  std::vector<const Edge*> edges;
  for (const Edge* e : xla_graph->edges()) {
    if (!e->IsControlEdge() && e->src()->IsArg() && e->dst()->IsRetval()) {
      edges.push_back(e);
    }
  }

  // In host graph `g`, remove output edge from `xla_node` and connect input &
  // output directly.
  std::vector<std::vector<const Edge*>> xla_node_out_edges(
      xla_node->num_outputs());
  for (const Edge* e : xla_node->out_edges()) {
    if (!e->IsControlEdge()) {
      xla_node_out_edges[e->src_output()].push_back(e);
    }
  }

  // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
  // will become very expensive in this case because it is doing a linear
  // search inside. Create an input_edges vector ahead to make the lookups
  // faster.
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));
  for (const Edge* e : edges) {
    int arg_index;
    TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "index", &arg_index));
    int ret_index;
    TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), "index", &ret_index));

    // For each replica, reroute consumers of the pass-through output to read
    // directly from the host-side producer of the corresponding input.
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      int input_index;
      if (arg_index < old_num_per_replica_inputs) {
        // Per-replica input: indexed inside this replica's block.
        input_index = replica_id * old_num_per_replica_inputs + arg_index;
      } else {
        // Distributed variable: lives once, after all per-replica blocks.
        input_index = num_replicas * old_num_per_replica_inputs +
                      (arg_index - old_num_per_replica_inputs);
      }
      const Edge* input_edge = input_edges.at(input_index);

      int output_index = replica_id * old_num_outputs + ret_index;
      for (const Edge* output_edge : xla_node_out_edges[output_index]) {
        Node* dst = output_edge->dst();
        int dst_input = output_edge->dst_input();

        g->RemoveEdge(output_edge);
        g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
      }
    }
  }

  // Remove edges from `xla_graph`. Add a Placeholder node for the _Retval node,
  // which will be removed by `RemoveUnusedXlaOutput()` later.
  for (const Edge* e : edges) {
    NodeDefBuilder placeholder_builder(
        absl::StrCat("placeholder_", e->dst()->name()), "Placeholder");
    placeholder_builder.Attr("dtype", e->src()->output_type(e->src_output()));
    placeholder_builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
    NodeDef placeholder_def;
    TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
    TF_ASSIGN_OR_RETURN(Node * placeholder_node,
                        xla_graph->AddNode(placeholder_def));

    // Detach the _Retval from the _Arg and feed it from the placeholder
    // instead, so the _Retval becomes trivially unused.
    Node* dst = e->dst();
    int dst_input = e->dst_input();
    xla_graph->RemoveEdge(e);
    xla_graph->AddEdge(placeholder_node, 0, dst, dst_input);
  }

  VLOG(4) << "RemoveUnusedArgRetvalPair host graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_arg_ret_host_", xla_func_name),
                 *g);
  VLOG(4) << "RemoveUnusedArgRetvalPair XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_arg_ret_xla_", xla_func_name),
                 *xla_graph);

  return OkStatus();
}
1534 
// Remove any TPUReplicatedInput nodes with no output edges. Those nodes are
// usually TPUMirroredVariable handles which are not used by any computations.
void RemoveUnusedTPUReplicatedInputs(Graph* graph) {
  for (Node* n : graph->nodes()) {
    if (n->type_string() == kTPUReplicatedInput) {
      // A node whose only out edges go to the sink is effectively unused.
      bool has_output = false;
      for (const Edge* e : n->out_edges()) {
        if (!e->dst()->IsSink()) {
          has_output = true;
          break;
        }
      }
      if (!has_output) {
        // Remove any TPUPartitionedInput node from the src nodes of the
        // to-be-removed TPUReplicatedInput node
        std::vector<Node*> to_be_removed_src_nodes;
        for (const auto& e_in : n->in_edges()) {
          if (!e_in->IsControlEdge() &&
              e_in->src()->type_string() == kTPUPartitionedInput)
            to_be_removed_src_nodes.push_back(e_in->src());
        }
        // NOTE(review): `n` (and possibly its producers) is removed while the
        // outer loop is still iterating graph->nodes(); this assumes Graph
        // node iteration tolerates removal of the current node -- confirm.
        graph->RemoveNode(n);
        for (Node* node : to_be_removed_src_nodes) {
          graph->RemoveNode(node);
        }
      }
    }
  }
}
1564 
// We might have duplicated cluster names in the graph, e.g. when a tf.function
// containing tpu_strategy.run() is called multiple times with
// the same inputs. Find clusters with duplicated names and rename them.
Status RenameClustersWithDuplicatedNames(Graph* g) {
  // Find all TPU clusters by finding all TPUReplicateMetadata nodes.
  std::unordered_map<string, std::vector<Node*>> cluster_name_to_metadata_nodes;
  std::unordered_set<string> cluster_names;
  for (Node* n : g->nodes()) {
    if (n->type_string() != "TPUReplicateMetadata") {
      continue;
    }
    string cluster_name;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &cluster_name));
    cluster_name_to_metadata_nodes[cluster_name].push_back(n);
    cluster_names.insert(cluster_name);
  }
  // Look for clusters with duplicated name.
  for (const auto& iter : cluster_name_to_metadata_nodes) {
    if (iter.second.size() == 1) {
      continue;
    }

    // Rename clusters. The first cluster (i == 0) keeps the original name;
    // every later duplicate gets a fresh one.
    for (int i = 1; i < iter.second.size(); i++) {
      // Find an available cluster name by appending an increasing numeric
      // suffix until the result is unused.
      string new_cluster_name;
      int cluster_name_suffix = 1;
      while (true) {
        new_cluster_name = absl::StrCat(iter.first, "_", cluster_name_suffix);
        if (cluster_names.find(new_cluster_name) == cluster_names.end()) {
          break;
        }
        cluster_name_suffix++;
      }
      cluster_names.insert(new_cluster_name);

      // Change _tpu_replicate attribute for all nodes in this cluster.
      // Start with outputs of TPUReplicateMetadata and follow output edges.
      // BFS over successors still carrying the old cluster name; `visited`
      // guards against enqueueing a node twice.
      std::queue<Node*> queue;
      queue.push(iter.second.at(i));
      std::unordered_set<Node*> visited;
      while (!queue.empty()) {
        Node* n = queue.front();
        queue.pop();

        visited.insert(n);

        n->ClearAttr(kTPUReplicateAttr);
        n->AddAttr(kTPUReplicateAttr, new_cluster_name);

        string cluster_name;
        for (const Edge* e : n->out_edges()) {
          // Only follow successors that still have the old name; renamed
          // nodes no longer match `iter.first`, which bounds the traversal.
          if (GetNodeAttr(e->dst()->def(), kTPUReplicateAttr, &cluster_name)
                  .ok() &&
              cluster_name == iter.first &&
              visited.find(e->dst()) == visited.end()) {
            queue.push(e->dst());
          }
        }
      }
      // Change "_tpu_compilation_status" attr for TPUCompilationResult node.
      for (const Edge* e : iter.second.at(i)->out_edges()) {
        if (e->dst()->type_string() == "TPUCompilationResult") {
          e->dst()->ClearAttr("_tpu_compilation_status");
          e->dst()->AddAttr("_tpu_compilation_status", new_cluster_name);
        }
      }
    }
  }
  return OkStatus();
}
1636 
1637 // Instantiate a function that is associated with a functional control flow
1638 // node. The function name is found by looking up `function_name_attr` of given
1639 // node.
InstantiateAssociatedFunction(const Node & n,absl::string_view function_name_attr,FunctionLibraryDefinition * fld)1640 xla::StatusOr<std::unique_ptr<FunctionBody>> InstantiateAssociatedFunction(
1641     const Node& n, absl::string_view function_name_attr,
1642     FunctionLibraryDefinition* fld) {
1643   std::unique_ptr<FunctionBody> fbody;
1644   NameAttrList func_attr_list;
1645   TF_RETURN_IF_ERROR(GetNodeAttr(n.def(), function_name_attr, &func_attr_list));
1646   const FunctionDef* fdef = fld->Find(func_attr_list.name());
1647   if (fdef == nullptr) {
1648     return errors::Internal("Cannot find ", function_name_attr, " function",
1649                             "for node ", n.DebugString());
1650   }
1651   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1652       *fdef, AttrSlice(&func_attr_list.attr()), fld, &fbody));
1653   return fbody;
1654 }
1655 
// Find inputs of If node that are only used for outside compilation if used at
// all in both if/else branches
xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForIfNode(
    const Node& if_node, FunctionLibraryDefinition* fld) {
  absl::flat_hash_set<int> args_to_lift_indices;
  std::vector<DataType> dtypes;
  TF_RETURN_IF_ERROR(GetNodeAttr(if_node.def(), "Tin", &dtypes));

  int num_args = dtypes.size();

  // Start with every DT_RESOURCE input as a lifting candidate; the loops
  // below remove the ones that don't qualify.
  for (int i = 0; i < num_args; i++) {
    // TODO(b/74023706): enable non resource inputs as well.
    if (dtypes[i] == DT_RESOURCE) {
      args_to_lift_indices.insert(i);
    }
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<FunctionBody> then_branch_fbody,
      InstantiateAssociatedFunction(if_node, "then_branch", fld));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<FunctionBody> else_branch_fbody,
      InstantiateAssociatedFunction(if_node, "else_branch", fld));

  for (int i = 0; i < num_args; ++i) {
    // `used` records whether arg i has any out edge at all (data or control)
    // in either branch body.
    bool used = false;

    // Disqualify arg i if, in the then-branch, it feeds any consumer that is
    // neither a control edge nor an outside compilation node.
    const Node* then_arg_node = then_branch_fbody->arg_nodes[i];
    for (const Edge* e : then_arg_node->out_edges()) {
      used = true;
      if (e->IsControlEdge() ||
          HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
        continue;

      args_to_lift_indices.erase(i);
      break;
    }

    // Same check for the else-branch.
    const Node* else_arg_node = else_branch_fbody->arg_nodes[i];
    for (const Edge* e : else_arg_node->out_edges()) {
      used = true;
      if (e->IsControlEdge() ||
          HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
        continue;

      args_to_lift_indices.erase(i);
      break;
    }

    // Do not lift arguments that are not used at all. Otherwise, this unused
    // arg would be outside compiled, its output tensor will be forced to
    // transfer to host needlessly.
    if (!used) args_to_lift_indices.erase(i);
  }

  return args_to_lift_indices;
}
1714 
// Find inputs of While node that are:
// 1. not used in cond func,
// 2. only used for outside compilation in body func,
// 3. loop invariant.
// These inputs can be lifted out of the while loop.
xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForWhileNode(
    Node* while_node, FunctionLibraryDefinition* fld) {
  // DT_RESOURCE inputs are candidates.
  absl::flat_hash_set<int> result;
  std::vector<DataType> dtypes;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));
  for (int i = 0; i < dtypes.size(); i++) {
    // TODO(b/74023706): enable non resource inputs as well.
    if (dtypes[i] == DT_RESOURCE) {
      result.insert(i);
    }
  }

  // Remove inputs that are used in cond func. Any data edge from the _Arg
  // counts as a use; control edges do not.
  NameAttrList cond_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "cond", &cond_func));
  const FunctionDef* cond_fdef = fld->Find(cond_func.name());
  if (cond_fdef == nullptr) {
    return errors::Internal("Cannot find cond function ", cond_func.name(),
                            " for while node ", while_node->DebugString());
  }
  std::unique_ptr<FunctionBody> cond_fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *cond_fdef, AttrSlice(&cond_func.attr()), fld, &cond_fbody));
  for (int i = 0; i < cond_fbody->arg_nodes.size(); i++) {
    const Node* arg_node = cond_fbody->arg_nodes[i];
    for (const Edge* e : arg_node->out_edges()) {
      if (!e->IsControlEdge()) {
        result.erase(i);
      }
    }
  }

  // Remove inputs that are not loop invariant: body output i must trace back,
  // through any chain of Identity nodes on input 0, to the matching _Arg i.
  NameAttrList body_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_func));
  const FunctionDef* body_fdef = fld->Find(body_func.name());
  if (body_fdef == nullptr) {
    return errors::Internal("Cannot find body function ", body_func.name(),
                            " for while node ", while_node->DebugString());
  }
  std::unique_ptr<FunctionBody> body_fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *body_fdef, AttrSlice(&body_func.attr()), fld, &body_fbody));
  for (int i = 0; i < body_fbody->ret_nodes.size(); i++) {
    const Node* node = body_fbody->ret_nodes[i];
    do {
      TF_RETURN_IF_ERROR(node->input_node(0, &node));
    } while (node->IsIdentity());
    if (node != body_fbody->arg_nodes[i]) {
      result.erase(i);
    }
  }

  // Remove inputs that only have one output edge (loop invariant, but not used
  // in outside compilation).
  for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
    const Node* arg_node = body_fbody->arg_nodes[i];
    int data_edge_count = std::count_if(
        arg_node->out_edges().begin(), arg_node->out_edges().end(),
        [](const Edge* e) { return !e->IsControlEdge(); });
    if (data_edge_count == 1) {
      result.erase(i);
    }
  }

  // Remove inputs that have non-outside-compilation usage. Edges into _Retval
  // nodes are allowed: they are the loop-invariant pass-through checked above.
  for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
    const Node* arg_node = body_fbody->arg_nodes[i];
    for (const Edge* e : arg_node->out_edges()) {
      if (!e->dst()->IsRetval() &&
          !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
        result.erase(i);
        break;
      }
    }
  }

  return result;
}
1800 
1801 // Find inputs of function call node that are only used for outside compilation.
1802 // These inputs can be lifted out of the function call node.
FindArgsToLiftForCallNode(Node * call_node,const FunctionBody & fbody)1803 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForCallNode(
1804     Node* call_node, const FunctionBody& fbody) {
1805   // DT_RESOURCE inputs are candidates.
1806   absl::flat_hash_set<int> result;
1807   std::vector<DataType> dtypes(call_node->input_types().begin(),
1808                                call_node->input_types().end());
1809   for (int i = 0; i < dtypes.size(); i++) {
1810     // TODO(b/74023706): enable for non resource inputs as well.
1811     if (dtypes[i] == DT_RESOURCE) {
1812       result.insert(i);
1813     }
1814   }
1815 
1816   // Remove inputs that have non-outside-compilation usage, or not used at all.
1817   for (int i = 0; i < fbody.arg_nodes.size(); i++) {
1818     const Node* arg_node = fbody.arg_nodes[i];
1819     if (arg_node->out_edges().empty()) {
1820       result.erase(i);
1821       continue;
1822     }
1823 
1824     for (const Edge* e : arg_node->out_edges()) {
1825       if (!HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1826         result.erase(i);
1827         break;
1828       }
1829     }
1830   }
1831   return result;
1832 }
1833 
1834 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
1835                                       FunctionLibraryDefinition* fld,
1836                                       int* lifted_arg_count, bool* rewritten);
1837 
// Runs LiftOutsideCompilationOnlyArgs (declared above; mutually recursive with
// this helper) over `fbody.graph`, and if anything changed, converts the
// rewritten graph back to a FunctionDef and stores it in `fld`: under
// `new_func_name` when provided, otherwise replacing the original function
// in place. `*rewritten` reports whether any rewrite happened.
Status LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
    const FunctionBody& fbody, FunctionLibraryRuntime* flr,
    FunctionLibraryDefinition* fld, int* lifted_arg_count,
    absl::optional<string> new_func_name, bool* rewritten) {
  *rewritten = false;
  TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
      fbody.graph, flr, fld, lifted_arg_count, rewritten));

  if (*rewritten) {
    // Serialize the mutated graph under the original signature name; the
    // caller-visible name only changes when `new_func_name` is set.
    FunctionDef rewritten_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(
        *(fbody.graph), fbody.fdef.signature().name(), &rewritten_fdef));
    if (new_func_name) {
      rewritten_fdef.mutable_signature()->set_name(*new_func_name);
      TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
    } else {
      TF_RETURN_IF_ERROR(
          fld->ReplaceFunction(fbody.fdef.signature().name(), rewritten_fdef));
    }
  }

  return OkStatus();
}
1861 
// For each arg index in `args_to_lift`, adds an outside compilation Identity
// node in host graph `g` fed by the same producer that feeds `n`'s
// corresponding input. The Identity is tagged with its own name as its
// outside compilation cluster and with kXlaIsLiftedArgAttrName. Generated
// node names are recorded in `lifted_arg_index_to_oc_cluster_name` (keyed by
// arg index), and `*lifted_arg_count` is incremented per node created.
// `arg_to_input_edge_offset` translates an arg index into `n`'s input edge
// index (0 for While nodes, where args map 1:1 to inputs).
Status MakeIdentityNodesForArgsToLift(
    const absl::flat_hash_set<int>& args_to_lift,
    const int arg_to_input_edge_offset, Graph* g, Node* n,
    absl::flat_hash_map<int, string>* lifted_arg_index_to_oc_cluster_name,
    int* lifted_arg_count) {
  int num_input = n->num_inputs();
  for (int arg_index = 0; arg_index < num_input; ++arg_index) {
    if (!args_to_lift.contains(arg_index)) continue;

    int input_edge_index = arg_index + arg_to_input_edge_offset;
    const Edge* arg_edge;
    TF_RETURN_IF_ERROR(n->input_edge(input_edge_index, &arg_edge));

    // Globally-counted name; doubles as the OC cluster id for this lifted arg.
    string node_name =
        g->NewName(absl::StrCat("lifted_arg", *lifted_arg_count));
    (*lifted_arg_count)++;
    (*lifted_arg_index_to_oc_cluster_name)[arg_index] = node_name;
    NodeDefBuilder id_builder(node_name, "Identity");
    id_builder.Attr("T", n->input_type(input_edge_index));
    id_builder.Attr(kOutsideCompilationAttr, id_builder.node_name());
    id_builder.Attr(kXlaIsLiftedArgAttrName, true);
    id_builder.Input(arg_edge->src()->name(), arg_edge->src_output(),
                     n->input_type(input_edge_index));
    NodeDef id_def;
    TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
    TF_ASSIGN_OR_RETURN(Node * id_node, g->AddNode(id_def));
    g->AddEdge(arg_edge->src(), arg_edge->src_output(), id_node, 0);
    // Control edge keeps a dependency between the lifted Identity and `n`.
    g->AddControlEdge(id_node, n);
  }

  return OkStatus();
}
1894 
// Replaces all usages of lifted args with placeholder nodes. Afterwards,
// removing these args should be safe since they no longer have users.
//
// For an arg NOT in `args_to_lift`, only its "index" attr (per `index_mapping`)
// and "T" attr (per `arg_dtypes`) are refreshed. For a lifted arg, every edge
// into an outside compilation consumer is re-pointed at a new Placeholder node
// carrying the consumer's OC cluster and the host-side lifted-arg node name
// (from `lifted_arg_index_to_oc_cluster_name`), and the _Arg node itself is
// then removed from `fbody`'s graph.
Status RemoveArgsToLiftFromFunctionBody(
    const absl::flat_hash_set<int>& args_to_lift,
    const std::vector<DataType>& arg_dtypes,
    const absl::flat_hash_map<int, string>& lifted_arg_index_to_oc_cluster_name,
    const absl::flat_hash_map<int, int>& index_mapping,
    const FunctionBody* fbody) {
  for (int i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg_node = fbody->arg_nodes[i];

    if (!args_to_lift.contains(i)) {
      // Surviving arg: renumber it for the post-lift signature.
      int new_index = index_mapping.at(i);
      arg_node->ClearAttr("index");
      arg_node->AddAttr("index", new_index);
      arg_node->ClearAttr("T");
      arg_node->AddAttr("T", arg_dtypes[i]);
      continue;
    }

    // Snapshot OC consumer edges first; they are removed while rewiring.
    std::vector<const Edge*> out_edges_to_oc;
    for (const Edge* e : arg_node->out_edges()) {
      if (HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
        out_edges_to_oc.push_back(e);
      }
    }

    // One Placeholder per consumer edge, tagged with that consumer's cluster.
    for (const Edge* e : out_edges_to_oc) {
      string outside_compilation_cluster;
      TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
                                     &outside_compilation_cluster));
      NodeDefBuilder ph_builder(fbody->graph->NewName("lifted_arg"),
                                "Placeholder");
      ph_builder.Attr("dtype", arg_dtypes[i]);
      ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_cluster);
      TF_RET_CHECK(lifted_arg_index_to_oc_cluster_name.contains(i));
      ph_builder.Attr(kXlaLiftedArgOutsideCompilationAttrName,
                      lifted_arg_index_to_oc_cluster_name.at(i));

      NodeDef ph_def;
      TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));

      TF_ASSIGN_OR_RETURN(Node * ph_node, fbody->graph->AddNode(ph_def));

      Node* dst = e->dst();
      int dst_input = e->dst_input();
      fbody->graph->RemoveEdge(e);
      fbody->graph->AddEdge(ph_node, 0, dst, dst_input);
    }

    // All users now read from placeholders; the _Arg can go.
    fbody->graph->RemoveNode(arg_node);
  }

  return OkStatus();
}
1950 
CleanUpInEdges(const absl::flat_hash_map<int,int> & index_mapping,const int arg_to_input_edge_offset,Graph * g,Node * n)1951 Status CleanUpInEdges(const absl::flat_hash_map<int, int>& index_mapping,
1952                       const int arg_to_input_edge_offset, Graph* g, Node* n) {
1953   int num_inputs = n->num_inputs();
1954   for (int i = 0; i < num_inputs; ++i) {
1955     if (i < arg_to_input_edge_offset) continue;
1956 
1957     int arg_idx = i - arg_to_input_edge_offset;
1958     const Edge* e;
1959     TF_RETURN_IF_ERROR(n->input_edge(i, &e));
1960 
1961     // If an edge maps to a lifted argument, simply remove that edge from graph.
1962     if (!index_mapping.contains(arg_idx)) {
1963       g->RemoveEdge(e);
1964       continue;
1965     }
1966 
1967     // If an edge maps to same input port, nothing to do.
1968     if (index_mapping.at(arg_idx) == arg_idx) continue;
1969 
1970     g->AddEdge(e->src(), e->src_output(), n,
1971                index_mapping.at(arg_idx) + arg_to_input_edge_offset);
1972     g->RemoveEdge(e);
1973   }
1974 
1975   return OkStatus();
1976 }
1977 
UpdateTypeAttribute(const absl::flat_hash_map<int,int> & index_mapping,const string & type_attr_name,const std::vector<DataType> & dtypes,Node * n)1978 Status UpdateTypeAttribute(const absl::flat_hash_map<int, int>& index_mapping,
1979                            const string& type_attr_name,
1980                            const std::vector<DataType>& dtypes, Node* n) {
1981   std::vector<DataType> new_dtypes;
1982   new_dtypes.reserve(index_mapping.size());
1983   for (int i = 0; i < dtypes.size(); ++i) {
1984     if (index_mapping.contains(i)) {
1985       new_dtypes.emplace_back(dtypes[i]);
1986     }
1987   }
1988 
1989   n->ClearAttr(type_attr_name);
1990   n->AddAttr(type_attr_name, new_dtypes);
1991 
1992   return OkStatus();
1993 }
1994 
// While V2 always creates Identity node for each While node output, which is
// not necessary for XLA computation. Remove those Identity nodes.
void RemoveOutputIdentityNodesForWhileV2(Graph* g, Node* while_node) {
  // Collect candidate edges first; the graph is mutated below, so we do not
  // rewire while iterating out_edges().
  std::vector<const Edge*> edges_to_identity_node;
  for (const Edge* e : while_node->out_edges()) {
    if (!e->IsControlEdge() && e->dst()->IsIdentity()) {
      edges_to_identity_node.push_back(e);
    }
  }
  for (const Edge* e : edges_to_identity_node) {
    Node* identity = e->dst();
    // Snapshot the Identity's out edges: they are removed/re-added below.
    std::vector<const Edge*> out_edges(identity->out_edges().begin(),
                                       identity->out_edges().end());
    for (const Edge* out_edge : out_edges) {
      if (out_edge->IsControlEdge()) {
        // Control consumers now depend directly on the While node.
        g->AddControlEdge(while_node, out_edge->dst());
      } else {
        // Data consumers read the corresponding While output directly;
        // e->src_output() is the While output index feeding this Identity.
        Node* dst = out_edge->dst();
        int dst_input = out_edge->dst_input();
        g->RemoveEdge(out_edge);
        g->AddEdge(while_node, e->src_output(), dst, dst_input);
      }
    }
    g->RemoveNode(identity);
  }
}
2021 
2022 // If corresponding While node output is used, change it to use While node input
2023 // instead.
ReplaceOutputEdgesWithInputEdgeSourceForWhile(const absl::flat_hash_set<int> & args_to_lift,Graph * g,Node * while_node)2024 Status ReplaceOutputEdgesWithInputEdgeSourceForWhile(
2025     const absl::flat_hash_set<int>& args_to_lift, Graph* g, Node* while_node) {
2026   std::vector<const Edge*> edges_to_replace;
2027   for (const Edge* e : while_node->out_edges()) {
2028     if (args_to_lift.contains(e->src_output())) {
2029       edges_to_replace.push_back(e);
2030     }
2031   }
2032   for (const Edge* e : edges_to_replace) {
2033     const Edge* input_edge;
2034     TF_RETURN_IF_ERROR(while_node->input_edge(e->src_output(), &input_edge));
2035     Node* dst = e->dst();
2036     int dst_input = e->dst_input();
2037     g->RemoveEdge(e);
2038     g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
2039   }
2040 
2041   return OkStatus();
2042 }
2043 
2044 // Calculates mapping from argument index before lifting to index afterwards.
ArgIndexMapping(const int num_args,const absl::flat_hash_set<int> & args_to_lift)2045 absl::flat_hash_map<int, int> ArgIndexMapping(
2046     const int num_args, const absl::flat_hash_set<int>& args_to_lift) {
2047   absl::flat_hash_map<int, int> index_mapping;
2048   int new_index = 0;
2049   for (int i = 0; i < num_args; i++) {
2050     if (!args_to_lift.contains(i)) {
2051       index_mapping[i] = new_index;
2052       ++new_index;
2053     }
2054   }
2055 
2056   return index_mapping;
2057 }
2058 
2059 // Remove outputs of While node body function that maps to lifted arguments.
CleanUpRetvalsForWhileBody(const absl::flat_hash_map<int,int> & index_mapping,const std::vector<DataType> & dtypes,FunctionBody * fbody)2060 void CleanUpRetvalsForWhileBody(
2061     const absl::flat_hash_map<int, int>& index_mapping,
2062     const std::vector<DataType>& dtypes, FunctionBody* fbody) {
2063   for (int i = 0; i < fbody->ret_nodes.size(); i++) {
2064     Node* ret_node = fbody->ret_nodes[i];
2065     if (index_mapping.contains(i)) {
2066       int new_index = index_mapping.at(i);
2067       ret_node->ClearAttr("index");
2068       ret_node->AddAttr("index", new_index);
2069       ret_node->ClearAttr("T");
2070       ret_node->AddAttr("T", dtypes[i]);
2071     } else {
2072       fbody->graph->RemoveNode(ret_node);
2073     }
2074   }
2075 }
2076 
// Lifts While-node arguments that are loop invariant, unused by the cond
// function, and only consumed by outside compilation in the body function:
//  1. find liftable args;
//  2. bypass the While node for their outputs;
//  3. send each lifted value to host via an outside compilation Identity;
//  4. strip the matching _Arg/_Retval nodes from cond/body and rewrite both
//     functions in `fld`;
//  5. drop the now-dangling input edges and shrink the While's "T" attr.
// `*rewritten` reports whether anything changed.
Status LiftOutsideCompilationOnlyArgsFromWhileNode(
    Graph* g, Node* while_node, FunctionLibraryDefinition* fld,
    int* lifted_arg_count, bool* rewritten) {
  *rewritten = false;

  TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
                      FindArgsToLiftForWhileNode(while_node, fld));
  if (args_to_lift.empty()) return OkStatus();

  RemoveOutputIdentityNodesForWhileV2(g, while_node);

  // Lifted args are loop invariant, so consumers of the matching While outputs
  // can read the While inputs' sources directly.
  TF_RETURN_IF_ERROR(ReplaceOutputEdgesWithInputEdgeSourceForWhile(
      args_to_lift, g, while_node));

  std::vector<DataType> dtypes;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));

  absl::flat_hash_map<int, int> index_mapping =
      ArgIndexMapping(dtypes.size(), args_to_lift);

  // For each lifted arg, add an outside compilation Identity node to send
  // it to host.
  absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
  TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
      args_to_lift, /*arg_to_input_edge_offset=*/0, g, while_node,
      &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));

  // For cond func, remove _Arg nodes.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> cond_fbody,
                      InstantiateAssociatedFunction(*while_node, "cond", fld));
  TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
      args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
      cond_fbody.get()));

  FunctionDef rewritten_cond_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*(cond_fbody->graph),
                                        cond_fbody->fdef.signature().name(),
                                        &rewritten_cond_fdef));
  TF_RETURN_IF_ERROR(fld->ReplaceFunction(cond_fbody->fdef.signature().name(),
                                          rewritten_cond_fdef));

  // For body func, remove _Retval nodes, and replace _Arg nodes with
  // Placeholder nodes.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
                      InstantiateAssociatedFunction(*while_node, "body", fld));

  TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
      args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
      body_fbody.get()));

  CleanUpRetvalsForWhileBody(index_mapping, dtypes, body_fbody.get());

  FunctionDef rewritten_body_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*(body_fbody->graph),
                                        body_fbody->fdef.signature().name(),
                                        &rewritten_body_fdef));
  TF_RETURN_IF_ERROR(fld->ReplaceFunction(body_fbody->fdef.signature().name(),
                                          rewritten_body_fdef));

  // Remove edges from lifted args to While node, and change "T" attr of the
  // While node.
  TF_RETURN_IF_ERROR(CleanUpInEdges(
      index_mapping, /*arg_to_input_edge_offset=*/0, g, while_node));

  TF_RETURN_IF_ERROR(
      UpdateTypeAttribute(index_mapping, "T", dtypes, while_node));

  *rewritten = true;

  return OkStatus();
}
2148 
LiftOutsideCompilationOnlyArgsFromIfNode(Graph * g,Node * if_node,FunctionLibraryDefinition * fld,int * lifted_arg_count,bool * rewritten)2149 Status LiftOutsideCompilationOnlyArgsFromIfNode(Graph* g, Node* if_node,
2150                                                 FunctionLibraryDefinition* fld,
2151                                                 int* lifted_arg_count,
2152                                                 bool* rewritten) {
2153   *rewritten = false;
2154   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2155                       FindArgsToLiftForIfNode(*if_node, fld));
2156   if (args_to_lift.empty()) return OkStatus();
2157 
2158   std::vector<DataType> dtypes;
2159   TF_RETURN_IF_ERROR(GetNodeAttr(if_node->def(), "Tin", &dtypes));
2160 
2161   absl::flat_hash_map<int, int> index_mapping;
2162   int new_index = 0;
2163   for (int i = 0; i < dtypes.size(); i++) {
2164     if (!args_to_lift.contains(i)) {
2165       index_mapping[i] = new_index;
2166       ++new_index;
2167     }
2168   }
2169 
2170   // For each lifted arg, add an outside compilation Identity node to send
2171   // it to host.
2172   absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2173   TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2174       args_to_lift, /*arg_to_input_edge_offset=*/1, g, if_node,
2175       &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2176 
2177   TF_ASSIGN_OR_RETURN(
2178       std::unique_ptr<FunctionBody> then_branch_fbody,
2179       InstantiateAssociatedFunction(*if_node, "then_branch", fld));
2180 
2181   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2182       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2183       then_branch_fbody.get()));
2184 
2185   FunctionDef rewritten_then_branch_fdef;
2186   TF_RETURN_IF_ERROR(GraphToFunctionDef(
2187       *(then_branch_fbody->graph), then_branch_fbody->fdef.signature().name(),
2188       &rewritten_then_branch_fdef));
2189   TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2190       then_branch_fbody->fdef.signature().name(), rewritten_then_branch_fdef));
2191 
2192   TF_ASSIGN_OR_RETURN(
2193       std::unique_ptr<FunctionBody> else_branch_fbody,
2194       InstantiateAssociatedFunction(*if_node, "else_branch", fld));
2195 
2196   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2197       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2198       else_branch_fbody.get()));
2199 
2200   FunctionDef rewritten_else_branch_fdef;
2201   TF_RETURN_IF_ERROR(GraphToFunctionDef(
2202       *(else_branch_fbody->graph), else_branch_fbody->fdef.signature().name(),
2203       &rewritten_else_branch_fdef));
2204   TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2205       else_branch_fbody->fdef.signature().name(), rewritten_else_branch_fdef));
2206 
2207   // Remove edges from lifted args to If node, and change "Tin" attr of the
2208   // If node.
2209   TF_RETURN_IF_ERROR(CleanUpInEdges(
2210       index_mapping, /*arg_to_input_edge_offset=*/1, g, if_node));
2211   TF_RETURN_IF_ERROR(
2212       UpdateTypeAttribute(index_mapping, "Tin", dtypes, if_node));
2213 
2214   *rewritten = true;
2215 
2216   return OkStatus();
2217 }
2218 
LiftOutsideCompilationOnlyArgsFromCallNode(Graph * g,Node * call_node,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,int * lifted_arg_count,bool * rewritten)2219 Status LiftOutsideCompilationOnlyArgsFromCallNode(
2220     Graph* g, Node* call_node, FunctionLibraryRuntime* flr,
2221     FunctionLibraryDefinition* fld, int* lifted_arg_count, bool* rewritten) {
2222   *rewritten = false;
2223 
2224   // Instantiate the function.
2225   NameAttrList func;
2226   if (fld->Contains(call_node->type_string())) {
2227     func.set_name(call_node->type_string());
2228     *func.mutable_attr() = call_node->def().attr();
2229   } else if (call_node->IsPartitionedCall()) {
2230     TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &func));
2231   } else {
2232     TF_RET_CHECK(call_node->type_string() ==
2233                  FunctionLibraryDefinition::kGradientOp);
2234     func.set_name(FunctionLibraryDefinition::kGradientOp);
2235     *func.mutable_attr() = call_node->def().attr();
2236   }
2237   FunctionLibraryRuntime::Handle handle;
2238   TF_RETURN_IF_ERROR(
2239       flr->Instantiate(func.name(), AttrSlice(&func.attr()), &handle));
2240   auto cleanup_handle = gtl::MakeCleanup(
2241       [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2242   const FunctionBody* fbody = flr->GetFunctionBody(handle);
2243 
2244   // Find _Arg nodes to lift.
2245   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2246                       FindArgsToLiftForCallNode(call_node, *fbody));
2247   if (args_to_lift.empty()) return OkStatus();
2248 
2249   std::vector<DataType> dtypes;
2250   dtypes = std::vector<DataType>(call_node->input_types().begin(),
2251                                  call_node->input_types().end());
2252 
2253   absl::flat_hash_map<int, int> index_mapping =
2254       ArgIndexMapping(dtypes.size(), args_to_lift);
2255 
2256   // For each lifted arg, add an outside compilation Identity node to send
2257   // it to host.
2258   absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2259   TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2260       args_to_lift, /*arg_to_input_edge_offset=*/0, g, call_node,
2261       &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2262 
2263   // Remove _Arg nodes.
2264   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2265       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2266       fbody));
2267 
2268   // Store rewritten function as a new function, because the original function
2269   // might be defined by user and we should not modify it.
2270   FunctionDef rewritten_fdef;
2271   TF_RETURN_IF_ERROR(GraphToFunctionDef(
2272       *(fbody->graph), fbody->fdef.signature().name(), &rewritten_fdef));
2273   string new_func_name =
2274       fld->UniqueFunctionName(fbody->fdef.signature().name());
2275   rewritten_fdef.mutable_signature()->set_name(new_func_name);
2276   TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
2277 
2278   // Remove edges from lifted args to call node.
2279   TF_RETURN_IF_ERROR(CleanUpInEdges(
2280       index_mapping, /*arg_to_input_edge_offset=*/0, g, call_node));
2281 
2282   // Rewrite the call node to use the rewritten function.
2283   NodeDef node_def;
2284   node_def.set_name(g->NewName(call_node->name()));
2285   node_def.set_op(new_func_name);
2286   if (call_node->IsPartitionedCall()) {
2287     NameAttrList f;
2288     TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2289     *node_def.mutable_attr() = f.attr();
2290   } else if (fld->Contains(call_node->type_string())) {
2291     *node_def.mutable_attr() = call_node->def().attr();
2292   } else {
2293     TF_RET_CHECK(call_node->type_string() ==
2294                  FunctionLibraryDefinition::kGradientOp);
2295     *node_def.mutable_attr() = call_node->def().attr();
2296     node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2297   }
2298   TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2299 
2300   *rewritten = true;
2301 
2302   return OkStatus();
2303 }
2304 
2305 // Lifts outside compilation only _Arg nodes out of If/While/function nodes.
LiftOutsideCompilationOnlyArgs(Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,int * lifted_arg_count,bool * rewritten)2306 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
2307                                       FunctionLibraryDefinition* fld,
2308                                       int* lifted_arg_count, bool* rewritten) {
2309   *rewritten = false;
2310 
2311   // Handle deeper functional nodes first.
2312   std::vector<Node*> while_nodes, if_nodes, call_nodes;
2313   for (Node* n : g->op_nodes()) {
2314     if (HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
2315       continue;
2316     }
2317 
2318     if (n->IsWhileNode()) {
2319       TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
2320                           InstantiateAssociatedFunction(*n, "body", fld));
2321       bool func_rewritten = false;
2322       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2323           *body_fbody, flr, fld, lifted_arg_count,
2324           /*new_func_name=*/absl::nullopt, &func_rewritten));
2325       *rewritten = *rewritten || func_rewritten;
2326 
2327       while_nodes.push_back(n);
2328     } else if (n->IsIfNode()) {
2329       TF_ASSIGN_OR_RETURN(
2330           std::unique_ptr<FunctionBody> then_branch_fbody,
2331           InstantiateAssociatedFunction(*n, "then_branch", fld));
2332       bool func_rewritten = false;
2333       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2334           *then_branch_fbody, flr, fld, lifted_arg_count,
2335           /*new_func_name=*/absl::nullopt, &func_rewritten));
2336       *rewritten |= func_rewritten;
2337 
2338       TF_ASSIGN_OR_RETURN(
2339           std::unique_ptr<FunctionBody> else_branch_fbody,
2340           InstantiateAssociatedFunction(*n, "else_branch", fld));
2341       func_rewritten = false;
2342       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2343           *else_branch_fbody, flr, fld, lifted_arg_count,
2344           /*new_func_name=*/absl::nullopt, &func_rewritten));
2345       *rewritten |= func_rewritten;
2346 
2347       if_nodes.push_back(n);
2348     } else if (IsFunctionCall(*fld, *n)) {
2349       // Function call nodes need to be rewritten, so handle them later.
2350       call_nodes.push_back(n);
2351     }
2352   }
2353 
2354   std::vector<Node*> rewritten_call_nodes;
2355   for (Node* call_node : call_nodes) {
2356     if (call_node->IsPartitionedCall()) {
2357       std::unique_ptr<FunctionBody> function_fbody;
2358       TF_ASSIGN_OR_RETURN(function_fbody,
2359                           InstantiateAssociatedFunction(*call_node, "f", fld));
2360       bool func_rewritten = false;
2361       string new_func_name =
2362           fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2363       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2364           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2365           &func_rewritten));
2366       if (func_rewritten) {
2367         NameAttrList f;
2368         TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2369         f.set_name(new_func_name);
2370         call_node->ClearAttr("f");
2371         call_node->AddAttr("f", f);
2372       }
2373 
2374       *rewritten |= func_rewritten;
2375       rewritten_call_nodes.push_back(call_node);
2376     } else if (fld->Contains(call_node->type_string())) {
2377       std::unique_ptr<FunctionBody> function_fbody;
2378       const FunctionDef* fdef = fld->Find(call_node->type_string());
2379       TF_RET_CHECK(fdef);
2380       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, call_node->attrs(), fld,
2381                                                  &function_fbody));
2382       bool func_rewritten = false;
2383       string new_func_name =
2384           fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2385       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2386           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2387           &func_rewritten));
2388       if (func_rewritten) {
2389         NodeDef node_def;
2390         node_def.set_name(g->NewName(call_node->name()));
2391         node_def.set_op(new_func_name);
2392         *node_def.mutable_attr() = call_node->def().attr();
2393         TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2394       }
2395 
2396       *rewritten |= func_rewritten;
2397       rewritten_call_nodes.push_back(call_node);
2398     } else {
2399       TF_RET_CHECK(call_node->type_string() ==
2400                    FunctionLibraryDefinition::kGradientOp);
2401       FunctionLibraryRuntime::Handle handle;
2402       TF_RETURN_IF_ERROR(flr->Instantiate(call_node->type_string(),
2403                                           call_node->attrs(), &handle));
2404       auto cleanup_handle = gtl::MakeCleanup(
2405           [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2406       bool func_rewritten = false;
2407       string new_func_name = fld->UniqueFunctionName(
2408           absl::StrCat(call_node->name(), "_lift_args"));
2409       const FunctionBody* function_fbody = flr->GetFunctionBody(handle);
2410       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2411           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2412           &func_rewritten));
2413       if (func_rewritten) {
2414         NodeDef node_def;
2415         node_def.set_name(g->NewName(call_node->name()));
2416         node_def.set_op(new_func_name);
2417         *node_def.mutable_attr() = call_node->def().attr();
2418         node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2419         TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2420       }
2421 
2422       *rewritten |= func_rewritten;
2423       rewritten_call_nodes.push_back(call_node);
2424     }
2425   }
2426 
2427   for (Node* n : while_nodes) {
2428     bool node_rewritten = false;
2429     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromWhileNode(
2430         g, n, fld, lifted_arg_count, &node_rewritten));
2431     *rewritten = *rewritten || node_rewritten;
2432   }
2433 
2434   for (Node* n : if_nodes) {
2435     bool node_rewritten = false;
2436     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromIfNode(
2437         g, n, fld, lifted_arg_count, &node_rewritten));
2438     *rewritten = *rewritten || node_rewritten;
2439   }
2440 
2441   for (Node* n : rewritten_call_nodes) {
2442     bool node_rewritten = false;
2443     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromCallNode(
2444         g, n, flr, fld, lifted_arg_count, &node_rewritten));
2445     *rewritten = *rewritten || node_rewritten;
2446   }
2447 
2448   if (*rewritten) {
2449     VLOG(4) << DumpGraphToFile("after_lifting_args", *g, fld);
2450   }
2451 
2452   return OkStatus();
2453 }
2454 
2455 }  // namespace
2456 
Encapsulate(std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def)2457 /*static*/ Status EncapsulateTPUComputationsPass::Encapsulate(
2458     std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
2459   // Check for undeclared outputs before Encapsulation, so we can give a better
2460   // error message.
2461   // TODO(phawkins): merge this with the encapsulation code to avoid the extra
2462   // O(n) pass over the edges.
2463   for (const Edge* e : (*graph)->edges()) {
2464     if (!e->IsControlEdge() &&
2465         e->src()->attrs().Find(kTPUReplicateAttr) != nullptr &&
2466         e->src()->attrs().Find(kOutsideCompilationAttr) == nullptr &&
2467         e->dst()->attrs().Find(kTPUReplicateAttr) == nullptr &&
2468         e->dst()->type_string() != kTPUReplicatedOutput) {
2469       return errors::InvalidArgument(
2470           "Undeclared output of TPU computation. A common cause of this error "
2471           "is variable initializers that depend on the TPU computation. Edge: ",
2472           FormatNodeForError(*e->src()), ":", e->src_output(), " -> ",
2473           FormatNodeForError(*e->dst()), ":", e->dst_input());
2474     }
2475   }
2476 
2477   RemoveUnusedTPUReplicatedInputs(graph->get());
2478 
2479   TF_RETURN_IF_ERROR(RenameClustersWithDuplicatedNames(graph->get()));
2480 
2481   TF_RETURN_IF_ERROR(
2482       PerformStaticShapeInferenceBeforeEncapsulation(graph->get()));
2483 
2484   auto output = absl::make_unique<Graph>((*graph)->op_registry());
2485   TF_RETURN_WITH_CONTEXT_IF_ERROR(
2486       EncapsulateSubgraphsInFunctions(
2487           kTPUReplicateAttr, **graph, RewriteSubgraph,
2488           /*reuse_existing_functions=*/true, &output, flib_def),
2489       "EncapsulateTPUComputationsPass failed");
2490   graph->swap(output);
2491 
2492   return OkStatus();
2493 }
2494 
// Replaces each encapsulated replicate function call, together with its
// neighboring TPUReplicatedInput/TPUReplicatedOutput nodes, with a single
// _TPUReplicate op whose inputs/outputs are laid out in replica-major order.
// Also rewrites GuaranteeConst nodes to Identity nodes. NOTE(review): the
// exact input-edge ordering assumptions (replicated inputs, then distributed
// variables, then broadcast inputs, variables, and guaranteed constants) are
// load-bearing throughout; do not reorder the loops below.
/*static*/ Status EncapsulateTPUComputationsPass::BuildTPUReplicateOps(
    Graph* graph) {
  // Finds all of the replicate function calls, to avoid mutating the graph
  // while iterating.
  std::vector<Node*> replicate_nodes;
  std::vector<Node*> guarantee_const_nodes;
  for (Node* n : graph->nodes()) {
    string name;
    // Nodes tagged for outside compilation also carry the replicate attr but
    // must not be folded into the _TPUReplicate node.
    if (TryGetNodeAttr(n->attrs(), kTPUReplicateAttr, &name) &&
        !TryGetNodeAttr(n->attrs(), kOutsideCompilationAttr, &name)) {
      replicate_nodes.push_back(n);
    } else if (n->type_string() == "GuaranteeConst") {
      guarantee_const_nodes.push_back(n);
    }
  }

  // Replace any GuaranteeConst nodes with Identity nodes. These nodes have now
  // served their purpose and have no runtime effect, except increasing
  // inference latency due to executor overhead. Subsequent rewrites will remove
  // the Identity nodes.
  for (Node* n : guarantee_const_nodes) {
    // Snapshot the endpoints before removing the node, since removal destroys
    // the edges.
    std::vector<std::pair<Node*, int>> predecessors;
    for (const Edge* e : n->in_edges()) {
      predecessors.emplace_back(e->src(), e->src_output());
    }
    std::vector<std::pair<Node*, int>> successors;
    for (const Edge* e : n->out_edges()) {
      successors.emplace_back(e->dst(), e->dst_input());
    }
    NodeDef ndef;
    ndef.set_name(n->name());
    ndef.set_op("Identity");
    ndef.set_device(n->requested_device());
    MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
    AddNodeAttr("T", n->output_type(0), &ndef);

    // Remove first so the Identity can reuse the same node name.
    graph->RemoveNode(n);
    TF_ASSIGN_OR_RETURN(Node * id_node, graph->AddNode(ndef));

    // Reconnect the saved edges; a negative slot index marks a control edge.
    for (const auto& pred : predecessors) {
      if (pred.second < 0) {
        graph->AddControlEdge(pred.first, id_node);
      } else {
        graph->AddEdge(pred.first, pred.second, id_node, 0);
      }
    }
    for (const auto& succ : successors) {
      if (succ.second < 0) {
        graph->AddControlEdge(id_node, succ.first);
      } else {
        graph->AddEdge(id_node, 0, succ.first, succ.second);
      }
    }
  }

  // Replaces each replicate function call together with its neighboring
  // TPUReplicatedInput/TPUReplicatedOutput nodes with a TPUReplicate node.
  for (Node* replicate : replicate_nodes) {
    int num_replicas;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(replicate->attrs(), "num_replicas", &num_replicas));
    // Index boundaries that partition the call's inputs into variable and
    // guaranteed-constant sections (set during encapsulation).
    int variable_start_index;
    TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "_variable_start_index",
                                   &variable_start_index));
    int guaranteed_const_start_index;
    TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(),
                                   "_guaranteed_const_start_index",
                                   &guaranteed_const_start_index));

    if (HasNodeAttr(replicate->def(), "use_tpu")) {
      bool use_tpu;
      TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "use_tpu", &use_tpu));
      if (!use_tpu) {
        LOG(WARNING) << "use_tpu=false attr on a TPUReplicate node is ignored.";
      }
    }

    std::vector<const Edge*> in_edges;
    TF_RETURN_IF_ERROR(replicate->input_edges(&in_edges));

    // Counts the number of replicated, non-replicated, and variable inputs.
    int pos = 0;
    std::vector<int> mirrored_variable_indices;
    int distributed_var_start_index = 0;
    // Leading inputs fed by TPUReplicatedInput nodes are either replicated
    // inputs or (packed resource) distributed variables; all replicated
    // inputs must come before any distributed variable.
    while (pos < in_edges.size() &&
           in_edges[pos]->src()->type_string() == kTPUReplicatedInput) {
      // Checks that each TPUReplicatedInput node has the correct number of
      // replicas.
      int input_num_replicas;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(in_edges[pos]->src()->attrs(), "N", &input_num_replicas));

      bool is_mirrored_variable;
      CHECK(GetNodeAttr(in_edges[pos]->src()->attrs(), "is_mirrored_variable",
                        &is_mirrored_variable)
                .ok());
      if (is_mirrored_variable) {
        mirrored_variable_indices.push_back(pos);
      }

      // "is_packed" may be absent on older graphs; default stays false.
      bool is_packed = false;
      GetNodeAttr(in_edges[pos]->src()->attrs(), "is_packed", &is_packed)
          .IgnoreError();

      // A packed DT_RESOURCE input is a distributed variable: one shared
      // handle rather than one input per replica.
      bool is_distributed_variable =
          is_packed && (in_edges[pos]->src()->output_type(
                            in_edges[pos]->src_output()) == DT_RESOURCE);

      if (!is_distributed_variable && input_num_replicas != num_replicas) {
        return errors::InvalidArgument(
            "Mismatched number of replicas. Computation has ", num_replicas,
            " replicas, input '", FormatNodeForError(*in_edges[pos]->src()),
            "' has ", input_num_replicas, " replicas.");
      }

      if (!is_distributed_variable) {
        if (distributed_var_start_index < pos) {
          // A replicated input appearing after a distributed variable
          // violates the required ordering.
          return errors::InvalidArgument(
              "Expect a distributed resource after index ",
              distributed_var_start_index,
              ", but got a replicated resource at index ", pos);
        } else {
          ++distributed_var_start_index;
        }
      }
      ++pos;
    }
    const int num_replicated_inputs = distributed_var_start_index;
    const int num_distributed_vars = pos - num_replicated_inputs;

    const int num_variables =
        std::max(0, guaranteed_const_start_index - variable_start_index);

    const int num_guaranteed_constants =
        in_edges.size() - guaranteed_const_start_index;
    TF_RET_CHECK(num_guaranteed_constants >= 0);

    VLOG(1) << "Replicate node '" << replicate->name() << "'"
            << " input edges: " << in_edges.size()
            << " num_replicated_inputs: " << num_replicated_inputs
            << " num_distributed_vars: " << num_distributed_vars
            << " num_variables: " << num_variables
            << " num_guaranteed_constants: " << num_guaranteed_constants
            << " num_mirrored_variables: " << mirrored_variable_indices.size();

    // Broadcast inputs are whatever remains after accounting for all other
    // input categories.
    const int num_broadcast_inputs =
        in_edges.size() - (num_replicated_inputs + num_distributed_vars +
                           num_variables + num_guaranteed_constants);
    TF_RET_CHECK(num_broadcast_inputs >= 0);

    // Total inputs of the new _TPUReplicate node: replicated inputs are
    // expanded to one per replica.
    const int num_inputs = num_replicated_inputs * num_replicas +
                           num_distributed_vars + num_broadcast_inputs +
                           num_guaranteed_constants + num_variables;

    std::vector<Node*> nodes_to_remove = {replicate};

    // Data and control inputs to the new TPUReplicate node.
    std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
    gtl::FlatSet<Node*> control_inputs;

    AddControlInputs(*replicate, &control_inputs);

    // Replicated inputs. Adds the inputs from the TPUReplicatedInput inputs,
    // in replica-major order. See the comments in
    // distributed_tpu_rewrite_pass.h for a description of the argument order.
    DataTypeVector replicated_input_types(num_replicated_inputs * num_replicas +
                                          num_distributed_vars);

    // Inputs with is_distributed_variable = false.
    for (int i = 0; i < num_replicated_inputs; ++i) {
      std::vector<const Edge*> replica_in_edges;
      TF_RETURN_IF_ERROR(in_edges[i]->src()->input_edges(&replica_in_edges));
      for (int replica = 0; replica < num_replicas; ++replica) {
        // Shadows the outer `pos` intentionally: position in replica-major
        // layout.
        int pos = replica * num_replicated_inputs + i;
        const Edge* edge = replica_in_edges[replica];
        data_inputs[pos] = {edge->src(), edge->src_output()};
        replicated_input_types[pos] = EdgeType(edge);
      }
      AddControlInputs(*in_edges[i]->src(), &control_inputs);
      nodes_to_remove.push_back(in_edges[i]->src());
    }

    // Inputs with is_distributed_variable = true.
    for (int i = 0; i < num_distributed_vars; ++i) {
      int pos = num_replicas * num_replicated_inputs + i;
      std::vector<const Edge*> replica_in_edges;
      TF_RETURN_IF_ERROR(
          in_edges[num_replicated_inputs + i]->src()->input_edges(
              &replica_in_edges));
      // A distributed variable is packed: exactly one input feeds its
      // TPUReplicatedInput node.
      TF_RET_CHECK(replica_in_edges.size() == 1);
      const Edge* edge = replica_in_edges[0];
      data_inputs[pos] = {edge->src(), edge->src_output()};
      replicated_input_types[pos] = EdgeType(edge);
      AddControlInputs(*in_edges[num_replicated_inputs + i]->src(),
                       &control_inputs);
      nodes_to_remove.push_back(in_edges[num_replicated_inputs + i]->src());
    }

    // Appends the broadcast inputs.
    DataTypeVector broadcast_input_types(num_broadcast_inputs);
    for (int i = 0; i < num_broadcast_inputs; ++i) {
      int pos = num_replicas * num_replicated_inputs + num_distributed_vars + i;
      const Edge* edge =
          in_edges[num_replicated_inputs + num_distributed_vars + i];
      data_inputs[pos] = {edge->src(), edge->src_output()};
      broadcast_input_types[i] = EdgeType(edge);
    }

    // Appends the variable inputs.
    for (int i = 0; i < num_variables; ++i) {
      int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
                num_broadcast_inputs + i;
      const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
                                  num_broadcast_inputs + i];
      data_inputs[pos] = {edge->src(), edge->src_output()};
    }

    // Appends the guaranteed-constant inputs (the last input section).
    DataTypeVector guaranteed_constant_types(num_guaranteed_constants);
    for (int i = 0; i < num_guaranteed_constants; ++i) {
      int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
                num_broadcast_inputs + num_variables + i;
      const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
                                  num_broadcast_inputs + num_variables + i];
      data_inputs[pos] = {edge->src(), edge->src_output()};
      guaranteed_constant_types[i] = EdgeType(edge);
    }

    // Outputs. All outputs from a replicated computation are replicated.
    const int num_outputs = replicate->output_types().size();
    gtl::FlatSet<Node*> control_outputs;
    std::vector<Node*> replicated_outputs(num_outputs);
    for (const Edge* e : replicate->out_edges()) {
      if (e->IsControlEdge()) {
        control_outputs.insert(e->dst());
      } else {
        TF_RET_CHECK(e->src_output() < num_outputs);
        TF_RET_CHECK(e->dst()->type_string() == kTPUReplicatedOutput)
            << e->DebugString();
        TF_RET_CHECK(e->dst()->output_types().size() == num_replicas);
        replicated_outputs[e->src_output()] = e->dst();
        nodes_to_remove.push_back(e->dst());

        AddControlOutputs(*e->dst(), &control_outputs);
      }
    }

    // Flattens the edges outgoing from the TPUReplicatedOutput nodes in
    // replica-major order.
    std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_replicas *
                                                                 num_outputs);
    DataTypeVector output_types(num_replicas * num_outputs);
    for (int i = 0; i < num_outputs; ++i) {
      std::vector<std::vector<const Edge*>> replica_out_edges(num_replicas);
      TF_RET_CHECK(replicated_outputs[i] != nullptr);
      for (const Edge* e : replicated_outputs[i]->out_edges()) {
        TF_RET_CHECK(!e->IsControlEdge());
        replica_out_edges[e->src_output()].push_back(e);
      }

      for (int replica = 0; replica < num_replicas; ++replica) {
        const int pos = replica * num_outputs + i;
        for (const Edge* edge : replica_out_edges[replica]) {
          data_outputs[pos].push_back({edge->dst(), edge->dst_input()});
        }
        output_types[pos] = replicated_outputs[i]->input_type(0);
      }
    }

    // TODO(b/79092708): Consolidate and cleanup to avoid TPU specialization.
    NodeDef def;
    def.set_name(replicate->name());
    def.set_op("_TPUReplicate");
    MergeDebugInfo(NodeDebugInfo(replicate->def()), &def);
    NameAttrList computation;
    computation.set_name(replicate->type_string());
    AddNodeAttr("computation", computation, &def);
    // Carry over all attrs from the original call, then add/overwrite the
    // _TPUReplicate-specific ones.
    for (const auto& attr : replicate->attrs()) {
      def.mutable_attr()->insert(attr);
    }
    AddNodeAttr("Tinputs", replicated_input_types, &def);
    AddNodeAttr("Tbroadcast_inputs", broadcast_input_types, &def);
    AddNodeAttr("NumVariables", num_variables, &def);
    AddNodeAttr("Tguaranteed_constants", guaranteed_constant_types, &def);
    AddNodeAttr("output_types", output_types, &def);
    AddNodeAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                mirrored_variable_indices, &def);
    AddNodeAttr("num_distributed_variables", num_distributed_vars, &def);

    for (Node* node : nodes_to_remove) {
      VLOG(2) << "Deleting node " << node->DebugString();
      // Ensure that we do not attempt to add control edges to nodes that are
      // deleted.
      control_inputs.erase(node);
      control_outputs.erase(node);
      graph->RemoveNode(node);
    }

    // Create the replacement node and wire up all saved data/control edges.
    TF_ASSIGN_OR_RETURN(Node * tpu_replicate, graph->AddNode(def));
    for (int i = 0; i < data_inputs.size(); ++i) {
      graph->AddEdge(data_inputs[i].first, data_inputs[i].second, tpu_replicate,
                     i);
    }
    for (Node* n : control_inputs) {
      graph->AddControlEdge(n, tpu_replicate);
    }
    for (int i = 0; i < data_outputs.size(); ++i) {
      for (const auto& successor : data_outputs[i]) {
        graph->AddEdge(tpu_replicate, i, successor.first, successor.second);
      }
    }
    for (Node* n : control_outputs) {
      graph->AddControlEdge(tpu_replicate, n);
    }
  }
  return OkStatus();
}
2811 
Run(const GraphOptimizationPassOptions & options)2812 Status EncapsulateTPUComputationsPass::Run(
2813     const GraphOptimizationPassOptions& options) {
2814   VLOG(1) << "EncapsulateTPUComputations(): "
2815           << DumpGraphToFile("encapsulate_tpu_computations_before",
2816                              **options.graph, options.flib_def);
2817 
2818   TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
2819   VLOG(1) << "EncapsulateTPUComputations() half-way: "
2820           << DumpGraphToFile("encapsulate_tpu_computations_halfway",
2821                              **options.graph, options.flib_def);
2822 
2823   TF_RETURN_IF_ERROR(BuildTPUReplicateOps(options.graph->get()));
2824   VLOG(1) << "EncapsulateTPUComputations() finished: "
2825           << DumpGraphToFile("encapsulate_tpu_computations_after",
2826                              **options.graph, options.flib_def);
2827   return OkStatus();
2828 }
2829 
ProcessHeadTailOutsideCompilation(const string & outside_compilation_attr_name,int * lifted_arg_count,std::unordered_map<string,XlaClusterInfo> * clusters,Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld)2830 Status ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
2831     const string& outside_compilation_attr_name, int* lifted_arg_count,
2832     std::unordered_map<string, XlaClusterInfo>* clusters, Graph* g,
2833     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
2834   // Gather a list of pivots by cluster so we can easily look them up.
2835   absl::node_hash_map<string, Node*> pivots;
2836   string cluster_name;
2837   for (Node* node : g->nodes()) {
2838     if (TryGetNodeAttr(node->attrs(), kPivotForClusterAttr, &cluster_name)) {
2839       pivots[cluster_name] = node;
2840     }
2841   }
2842   for (auto& iter : *clusters) {
2843     // Find pivot node for this XLA cluster.
2844     Node* pivot_node = pivots[iter.first];
2845 
2846     // Instantiate XLA computation function.
2847     string xla_func_name = iter.second.func_name_attrs.name();
2848     std::unique_ptr<FunctionBody> xla_fbody;
2849     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
2850         *fld->Find(xla_func_name),
2851         AttrSlice(&iter.second.func_name_attrs.attr()), fld, &xla_fbody));
2852     Graph* xla_graph = xla_fbody->graph;
2853 
2854     // Make sure all nodes can be traced from sink node.
2855     FixupSourceAndSinkEdges(xla_graph);
2856 
2857     // We create Identity nodes for all _Arg/_Retval nodes in XLA computation.
2858     // Remove those Identity nodes to simplify furthur processing.
2859     TF_RETURN_IF_ERROR(RemoveIdentityNodesForArgRetval(xla_graph));
2860 
2861     bool rewritten;
2862     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
2863         xla_graph, flr, fld, lifted_arg_count, &rewritten));
2864 
2865     // Move head outside compilation to host.
2866     TF_RETURN_IF_ERROR(MoveHeadOutsideCompilationToHost(
2867         outside_compilation_attr_name, iter.second.func_name_attrs.name(),
2868         iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));
2869 
2870     // Move tail outside compilation to host.
2871     TF_RETURN_IF_ERROR(MoveTailOutsideCompilationToHost(
2872         outside_compilation_attr_name, iter.second.func_name_attrs.name(),
2873         iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));
2874 
2875     // Replace outside compilation only _Arg nodes with Placeholder nodes.
2876     TF_RETURN_IF_ERROR(ReplaceArgUsedByOutsideCompilationWithPlaceholder(
2877         outside_compilation_attr_name, xla_func_name, g, xla_graph,
2878         iter.second.node));
2879 
2880     // There might be direct data edges between _Arg node and _Retval node in
2881     // `xla_graph`. Remove those edges to avoid back-and-forth data transfer
2882     // between host and XLA.
2883     TF_RETURN_IF_ERROR(RemoveEdgesBetweenArgAndRetval(
2884         iter.second.func_name_attrs.name(), g, xla_graph, iter.second.node));
2885 
2886     // After `MoveHeadOutsideCompilationToHost`, there might be unused XLA
2887     // inputs. Remove them.
2888     TF_RETURN_IF_ERROR(RemoveUnusedXlaInput(iter.second.func_name_attrs.name(),
2889                                             g, xla_graph, iter.second.node));
2890 
2891     // After `MoveTailOutsideCompilationToHost`, there might be unused XLA
2892     // outputs. Remove them.
2893     TF_RETURN_IF_ERROR(RemoveUnusedXlaOutput(iter.second.func_name_attrs.name(),
2894                                              g, xla_graph, iter.second.node));
2895 
2896     // Replace original function.
2897     FunctionDef replace_fdef;
2898     TF_RETURN_IF_ERROR(
2899         GraphToFunctionDef(*xla_graph, xla_func_name, &replace_fdef));
2900     TF_RETURN_IF_ERROR(fld->ReplaceFunction(xla_func_name, replace_fdef));
2901 
2902     FixupSourceAndSinkEdges(g);
2903   }
2904 
2905   return OkStatus();
2906 }
2907 
Run(const GraphOptimizationPassOptions & options)2908 Status ExtractOutsideCompilationPass::Run(
2909     const GraphOptimizationPassOptions& options) {
2910   const auto* config =
2911       (options.session_options ? &options.session_options->config : nullptr);
2912   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
2913       new ProcessFunctionLibraryRuntime(
2914           /*device_mgr=*/nullptr, options.session_options->env,
2915           /*config=*/config, TF_GRAPH_DEF_VERSION, options.flib_def,
2916           config ? config->graph_options().optimizer_options()
2917                  : OptimizerOptions()));
2918   FunctionLibraryRuntime* flr =
2919       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
2920 
2921   // Find XLA compile ops and their corresponding FunctionDefs.
2922   static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
2923       new std::map<string, string>{
2924           {"_TPUReplicate", "computation"},
2925       };
2926   std::unordered_map<string, XlaClusterInfo> clusters;
2927   int lifted_arg_count = 0;
2928   for (Node* n : (*options.graph)->nodes()) {
2929     auto iter = kNodeTypeToFunctionAttrMapping->find(n->type_string());
2930     if (iter == kNodeTypeToFunctionAttrMapping->end()) {
2931       continue;
2932     }
2933 
2934     string xla_cluster_name = n->name();
2935 
2936     string func_attr = iter->second;
2937     NameAttrList func;
2938     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
2939 
2940     std::vector<string> core_list;
2941     TF_RETURN_IF_ERROR(
2942         GetNodeAttr(n->attrs(), "host_compute_core", &core_list));
2943     std::map<string, int> host_compute_core;
2944     TF_RETURN_IF_ERROR(ParseHostComputeCoreList(core_list, &host_compute_core));
2945 
2946     clusters.emplace(xla_cluster_name, XlaClusterInfo{xla_cluster_name, func, n,
2947                                                       host_compute_core});
2948   }
2949   TF_RETURN_IF_ERROR(ProcessHeadTailOutsideCompilation(
2950       kOutsideCompilationAttr, &lifted_arg_count, &clusters,
2951       options.graph->get(), flr, options.flib_def));
2952   bool modified;
2953   TF_RETURN_IF_ERROR(ExtractOutsideCompilation(
2954       kTPUReplicateAttr, kOutsideCompilationAttr, clusters,
2955       options.graph->get(), flr, options.flib_def, &modified));
2956   if (modified) {
2957     TF_RETURN_IF_ERROR(
2958         PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
2959   }
2960 
2961   return OkStatus();
2962 }
2963 
2964 }  // namespace tensorflow
2965