/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"

#include <queue>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

const char* const kTPUReplicatedInput = "TPUReplicatedInput";
const char* const kTPUReplicatedOutput = "TPUReplicatedOutput";
const char* const kPivotForClusterAttr = "_pivot_for_cluster";
const char* const kTPUPartitionedInput = "TPUPartitionedInput";

// Finds the `index` of an _Arg or _Retval node.
Status GetIndexAttr(const Node& n, int num_args, int* index) {
  TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
  if (*index < 0 || *index >= num_args) {
    return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
                                   *index);
  }
  return OkStatus();
}

// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
// the arguments into the order expected by TPUReplicate computations:
// 1) replicated arguments
// 2) non-replicated (broadcast) arguments
// 3) resource variable arguments
// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
// of the arguments.
Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
                       std::unique_ptr<Graph>* graph_ptr,
                       std::vector<int>* input_permutation,
                       std::vector<int>* output_permutation,
                       NodeDef* call_def) {
  // Replicated inputs have TPUReplicatedInput nodes as predecessors in the
  // input graph.
  auto is_replicated_input = [&](const Node& n, bool* is_packed = nullptr) {
    CHECK_EQ("_Arg", n.type_string());
    int index;
    TF_CHECK_OK(GetIndexAttr(n, arg_source_tensors.size(), &index));
    bool ret =
        arg_source_tensors.at(index).node->type_string() == kTPUReplicatedInput;
    if (is_packed) {
      if (!ret || !GetNodeAttr(arg_source_tensors.at(index).node->attrs(),
                               "is_packed", is_packed)
                       .ok()) {
        *is_packed = false;
      }
    }
    return ret;
  };

  auto is_guaranteed_constant = [&](const Node& n) {
    bool guaranteed_constant = false;
    if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
             .ok()) {
      return false;
    }
    // Replicated input nodes can be marked as guaranteed constants if they are
    // const.
    return guaranteed_constant && !is_replicated_input(n);
  };

  Graph* graph = graph_ptr->get();
  Node* metadata_node = nullptr;
  const int num_args = input_permutation->size();
  const int num_retvals = output_permutation->size();

  std::vector<Node*> args;
  std::vector<Node*> retvals;
  args.reserve(num_args);
  retvals.reserve(num_retvals);
  for (Node* n : graph->nodes()) {
    if (n->type_string() == "_Arg") {
      args.push_back(n);
    } else if (n->type_string() == "_Retval") {
      retvals.push_back(n);
    } else if (n->type_string() == "TPUReplicateMetadata") {
      metadata_node = n;
    } else if (!str_util::StrContains(n->requested_device(),
                                      DEVICE_TPU_REPLICATED_CORE)) {
      // If an operator isn't assigned to a TPU core device, assign it to
      // TPU_REPLICATED_CORE without a specific core ID. For some operators,
      // such as variable reads/writes, the operator may be assigned to non-TPU
      // devices due to colocation.
      n->set_assigned_device_name(
          strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE));
    }
  }

  // Read the metadata node and remove it from the graph.
  if (metadata_node == nullptr) {
    return errors::InvalidArgument("Missing TPUReplicateMetadata node");
  }

  for (const auto& attr : metadata_node->attrs()) {
    if (attr.first == "computation_shape") {
      // Convert the deprecated computation_shape attribute into a
      // num_cores_per_replica value. If a computation_shape is present, it
      // overrides num_cores_per_replica.
      std::vector<int> shape;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(metadata_node->attrs(), "computation_shape", &shape));
      if (!shape.empty()) {
        int64_t num_cores_per_replica = 1LL;
        for (int dim : shape) {
          num_cores_per_replica *= dim;
        }
        call_def->mutable_attr()->erase("num_cores_per_replica");
        AddNodeAttr("num_cores_per_replica", num_cores_per_replica, call_def);
      }
    } else {
      call_def->mutable_attr()->insert(attr);
    }
  }
  MergeDebugInfo(NodeDebugInfo(metadata_node->def()), call_def);
  graph->RemoveNode(metadata_node);

  if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
    return errors::InvalidArgument("Missing or non-consecutive arguments");
  }

  // Reorders the arguments.
  std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
    // Non-constants appear before constants
    bool a_is_guaranteed_constant = is_guaranteed_constant(*a);
    bool b_is_guaranteed_constant = is_guaranteed_constant(*b);
    // Non-packed values appear before packed values.
    bool a_is_packed;
    bool b_is_packed;
    // Replicated values appear before non-replicated values.
    bool a_not_replicated = !is_replicated_input(*a, &a_is_packed);
    bool b_not_replicated = !is_replicated_input(*b, &b_is_packed);
    // Non-resources appear before resources
    bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
    bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
    // Uses the name as a tiebreaker so the output is deterministic.
    StringPiece a_name(a->name());
    StringPiece b_name(b->name());
    return std::tie(a_is_guaranteed_constant, a_not_replicated, a_is_packed,
                    a_is_resource, a_name) <
           std::tie(b_is_guaranteed_constant, b_not_replicated, b_is_packed,
                    b_is_resource, b_name);
  });
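  // Net effect of the comparator: replicated args (unpacked before packed),
  // then broadcast args, then resource args, then guaranteed constants, with
  // names as the final tiebreaker; this matches the 1)/2)/3) ordering
  // documented above.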
  // Sorts the retvals by name so the order is deterministic.
  std::sort(retvals.begin(), retvals.end(),
            [](Node* a, Node* b) { return a->name() < b->name(); });

  // Computes the permutation to produce the correct argument order, and
  // updates the argument indices.
  int variable_start_index = num_args;
  int guaranteed_const_start_index = num_args;
  for (int i = 0; i < num_args; ++i) {
    int index;
    TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
    if (args[i]->output_type(0) == DT_RESOURCE &&
        !is_replicated_input(*args[i]) && variable_start_index == num_args) {
      variable_start_index = i;
    } else if (is_guaranteed_constant(*args[i]) &&
               guaranteed_const_start_index == num_args) {
      guaranteed_const_start_index = i;
    }
    (*input_permutation)[index] = i;
    args[i]->AddAttr("index", i);
  }
  VLOG(4) << "variable_start_index: " << variable_start_index
          << " guaranteed_const_start_index: " << guaranteed_const_start_index;

  // Computes the permutation to produce the correct retval order, and
  // updates the retval indices.
  for (int i = 0; i < num_retvals; ++i) {
    int index;
    TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
    (*output_permutation)[index] = i;
    retvals[i]->AddAttr("index", i);
  }

  AddNodeAttr(kTPUReplicateAttr, call_def->name(), call_def);
  AddNodeAttr("_variable_start_index", variable_start_index, call_def);
  AddNodeAttr("_guaranteed_const_start_index", guaranteed_const_start_index,
              call_def);

  // Uniquify the function name by fingerprinting the function.
  // Nondeterminism in serialization would not lead to incorrect results, but
  // may cause spurious cache misses. DeterministicSerialization is a
  // best-effort deterministic serialization.
  TF_ASSIGN_OR_RETURN(string serialized, SerializeGraphDeterministic(*graph));
  uint64 fingerprint =
      TpuCompileInterface::Get()->FingerprintString(serialized);
  LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
  call_def->set_op(strings::StrCat(call_def->op(), "_", fingerprint));
  return OkStatus();
}

DataType EdgeType(const Edge* edge) {
  return edge->dst()->input_type(edge->dst_input());
}

// Adds the control inputs of `node` to `*deps`.
void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
  for (const Edge* edge : node.in_edges()) {
    if (edge->IsControlEdge()) {
      deps->insert(edge->src());
    }
  }
}

// Adds the control outputs of `node` to `*deps`.
void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
  for (const Edge* edge : node.out_edges()) {
    if (edge->IsControlEdge()) {
      deps->insert(edge->dst());
    }
  }
}

// We add Identity nodes for _Arg/_Retval in XLA computation. Remove those
// Identity nodes to simplify further processing.
Status RemoveIdentityNodesForArgRetval(Graph* g) {
  // Collect Identity nodes for _Arg/_Retval.
  std::vector<Node*> identity_nodes;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "Identity" &&
        (HasNodeAttr(n->def(), "_tpu_input_identity") ||
         HasNodeAttr(n->def(), "_tpu_output_identity"))) {
      identity_nodes.push_back(n);
    }
  }

  // Remove those Identity nodes.
  for (Node* n : identity_nodes) {
    const Edge* input_edge;
    TF_RETURN_IF_ERROR(n->input_edge(0, &input_edge));

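    // Bypass the Identity node: reroute each outgoing edge to the Identity's
    // input, keeping control edges as control edges and data edges on their
    // original destination slot.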
    std::vector<const Edge*> output_edges;
    for (const Edge* e : n->out_edges()) {
      output_edges.push_back(e);
    }
    for (const Edge* e : output_edges) {
      if (e->IsControlEdge()) {
        Node* dst = e->dst();
        g->RemoveEdge(e);
        g->AddControlEdge(input_edge->src(), dst);
      } else {
        Node* dst = e->dst();
        int dst_input = e->dst_input();
        g->RemoveEdge(e);
        g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
      }
    }
    g->RemoveNode(n);
  }

  return OkStatus();
}

// Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when
// `additional_per_replica_inputs` inputs are added to `xla_node`.
Status UpdateMirroredVariableIndices(int additional_per_replica_inputs,
                                     Node* xla_node) {
  std::vector<int> mirrored_variable_indices;
  if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
      nullptr) {
    TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
                                   TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                                   &mirrored_variable_indices));
  }

  if (!mirrored_variable_indices.empty()) {
    for (int i = 0; i < mirrored_variable_indices.size(); ++i)
      mirrored_variable_indices[i] += additional_per_replica_inputs;
    xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
    xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                      mirrored_variable_indices);
  }
  return OkStatus();
}

// Move outside compilation nodes at the beginning of XLA computation to host.
// For XLA computation graph, we will add new _Arg nodes to replace those
// outside compilation nodes.
// For host graph, we will move those outside compilation nodes to host,
// replicate them, and use them as XLA node's input.
Status MoveHeadOutsideCompilationToHost(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
    Node* pivot_node) {
  // Find outside compilation nodes that only have _Arg or other outside
  // compilation nodes as input. These nodes will be moved to host graph.
  std::vector<Node*> oc_nodes_at_head;
  const string kOnlyArgOrOcInputAttrName = "_xla_only_arg_or_oc_input";
  ReverseDFS(
      *xla_graph, /*enter=*/nullptr,
      [&](Node* n) {
        bool has_non_arg_or_oc_input = false;
        for (const Edge* e : n->in_edges()) {
          if (e->src() == xla_graph->source_node()) {
            continue;
          }
          if (!e->src()->IsArg() &&
              (!HasNodeAttr(e->src()->def(), outside_compilation_attr_name) ||
               !HasNodeAttr(e->src()->def(), kOnlyArgOrOcInputAttrName))) {
            has_non_arg_or_oc_input = true;
            break;
          }
        }
        if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
            !has_non_arg_or_oc_input &&
            !HasNodeAttr(n->def(), kXlaIsPlaceholderForArg)) {
          n->AddAttr(kOnlyArgOrOcInputAttrName, true);
          oc_nodes_at_head.push_back(n);
        }
      },
      NodeComparatorName());
  std::vector<Node*> const_nodes_to_remove;
  for (Node* n : oc_nodes_at_head) {
    // If a Const node is in "oc_nodes_at_head" but some of its successors are
    // not, copy this Const node and use the copied node for those successors.
    if (n->type_string() != "Const") {
      continue;
    }

    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() &&
          HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) &&
          !HasNodeAttr(e->dst()->def(), kOnlyArgOrOcInputAttrName)) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }

    Node* const_copy = xla_graph->CopyNode(n);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      xla_graph->RemoveEdge(e);
      xla_graph->AddEdge(const_copy, 0, dst, dst_input);
    }
    // Make sure the copied node can be traced from source node.
    xla_graph->AddControlEdge(xla_graph->source_node(), const_copy);

    // If this Const node no longer has any data output, remove it later.
    bool has_output_edge = false;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge()) {
        has_output_edge = true;
        break;
      }
    }
    if (!has_output_edge) {
      const_nodes_to_remove.push_back(n);
    }
  }
  for (Node* n : const_nodes_to_remove) {
    xla_graph->RemoveNode(n);
    oc_nodes_at_head.erase(
        std::remove(oc_nodes_at_head.begin(), oc_nodes_at_head.end(), n),
        oc_nodes_at_head.end());
  }
  if (VLOG_IS_ON(5)) {
    for (Node* n : oc_nodes_at_head) {
      VLOG(5) << "oc_nodes_at_head: " << n->DebugString();
    }
  }

  // Copy all nodes in `oc_nodes_at_head` to host graph, and also replicate
  // them.

  // Sometimes `xla_node` can have a lot of inputs; calling Node::input_edge
  // then becomes very expensive because it does a linear search internally.
  // Build an input_edges vector up front to make the lookups faster.
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  int old_num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;
  VLOG(5) << "old_num_per_replica_inputs: " << old_num_per_replica_inputs;
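  // Input layout assumed throughout this pass: `Tinputs` holds `num_replicas`
  // consecutive per-replica blocks followed by the distributed variables, so
  // each block has (|Tinputs| - num_distributed_vars) / num_replicas entries.
  // Broadcast inputs, variables, and guaranteed constants follow `Tinputs`.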
  std::map<Node*, std::vector<Node*>> node_images;
  for (Node* n : oc_nodes_at_head) {
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      NodeDef copy_def = n->def();
      copy_def.set_name(absl::StrCat(n->name(), "_head_oc/R", replica_id));
      copy_def.clear_device();

      TF_ASSIGN_OR_RETURN(Node * copy_node, g->AddNode(copy_def));

      copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
      copy_node->AddAttr(kTPUReplicateAttr, cluster_name);

      for (const Edge* e : n->in_edges()) {
        if (e->src() == xla_graph->source_node()) {
          continue;
        }
        // Either e->src() is an _Arg node, or it's in `node_images`.
        if (e->src()->IsArg()) {
          int index;
          TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->attrs(), "index", &index));
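          // Map the _Arg index to the flattened input edge: per-replica args
          // live at `old_num_per_replica_inputs * replica_id + index`;
          // distributed-variable args sit after all per-replica blocks.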
          const int new_index =
              (index < old_num_per_replica_inputs)
                  ? (old_num_per_replica_inputs * replica_id + index)
                  : (old_num_per_replica_inputs * num_replicas +
                     (index - old_num_per_replica_inputs));
          const Edge* original_edge = input_edges.at(new_index);
          g->AddEdge(original_edge->src(), original_edge->src_output(),
                     copy_node, e->dst_input());
        } else {
          g->AddEdge(node_images[e->src()][replica_id], e->src_output(),
                     copy_node, e->dst_input());
        }
      }

      // Add control edge between `copy_node` and `xla_node`, so these outside
      // compilation nodes will be executed before XLA computation happens.
      g->AddControlEdge(copy_node, xla_node);

      // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
      // belongs to same while loop as `xla_node`.
      if (pivot_node) {
        g->AddControlEdge(pivot_node, copy_node);
      }

      node_images[n].push_back(copy_node);
    }
  }

  // Record output edges from `oc_nodes_at_head`. We will create an _Arg node
  // for each of these edges. An obvious optimization here is to deduplicate
  // these edges by <src, src_output>. But that optimization will complicate
  // the code, and in practice we usually do not have output edges with the
  // same <src, src_output>.
  std::vector<const Edge*> oc_output_edges;
  std::vector<DataType> new_arg_types;
  for (Node* n : oc_nodes_at_head) {
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() &&
          node_images.find(e->dst()) == node_images.end()) {
        VLOG(5) << "oc_output_edges: " << e->DebugString();
        oc_output_edges.push_back(e);
        new_arg_types.push_back(e->src()->output_type(e->src_output()));
      }
    }
  }
  int new_num_per_replica_inputs =
      old_num_per_replica_inputs + oc_output_edges.size();
  VLOG(5) << "new_num_per_replica_inputs: " << new_num_per_replica_inputs;

  // Process input edges for XLA node.
  int num_variables;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "NumVariables", &num_variables));
  std::vector<DataType> broadcast_input_types, guaranteed_constant_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tbroadcast_inputs",
                                 &broadcast_input_types));
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tguaranteed_constants",
                                 &guaranteed_constant_types));
  int num_other_inputs = num_distributed_vars + num_variables +
                         broadcast_input_types.size() +
                         guaranteed_constant_types.size();
  VLOG(5) << "num_other_inputs: " << num_other_inputs;

  // Update `Tinputs` attribute for `xla_node`.
  std::vector<DataType> new_input_types;
  // Order of new_input_types: old per-replica inputs -> new per-replica inputs
  // -> distributed variables
  new_input_types.reserve(num_replicas * new_num_per_replica_inputs +
                          num_distributed_vars);
  for (int replica_id = 0; replica_id < num_replicas; ++replica_id) {
    for (int i = 0; i < old_num_per_replica_inputs; ++i) {
      new_input_types.push_back(input_types[i]);
    }
    for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
         ++i) {
      new_input_types.push_back(new_arg_types[i - old_num_per_replica_inputs]);
    }
  }
  const int num_new_per_replica_input_types = new_input_types.size();
  for (int i = input_types.size() - num_distributed_vars;
       i < input_types.size(); i++) {
    new_input_types.push_back(input_types[i]);
  }
  xla_node->ClearAttr("Tinputs");
  xla_node->AddAttr("Tinputs", new_input_types);

  TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices(
      /*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node));

  int new_variable_start_index =
      num_new_per_replica_input_types / num_replicas + num_distributed_vars +
      broadcast_input_types.size();
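  // `_variable_start_index` is an offset into a single replica's argument
  // list: one per-replica block, then the distributed variables, then the
  // broadcast inputs precede the variables.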
  if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
    xla_node->ClearAttr("_variable_start_index");
    xla_node->AddAttr("_variable_start_index", new_variable_start_index);
  }
  int new_guaranteed_const_start_index =
      new_variable_start_index + num_variables;
  if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
    xla_node->ClearAttr("_guaranteed_const_start_index");
    xla_node->AddAttr("_guaranteed_const_start_index",
                      new_guaranteed_const_start_index);
  }

  // Move non per-replica input edges.
  std::vector<const Edge*> new_input_edges(
      num_replicas * new_num_per_replica_inputs + num_other_inputs);
  int end_input_index =
      num_replicas * new_num_per_replica_inputs + num_other_inputs - 1;
  int start_input_index = end_input_index + 1 - num_other_inputs;
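  // Walk from the highest input index downward so that each destination slot
  // has been vacated (or was never occupied) before a shifted edge is
  // re-added to it.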
  for (int input_index = end_input_index; input_index >= start_input_index;
       input_index--) {
    const Edge* e =
        input_edges.at(input_index - num_replicas * new_arg_types.size());
    Node* src = e->src();
    int src_output = e->src_output();
    g->RemoveEdge(e);
    const Edge* new_input_edge =
        g->AddEdge(src, src_output, xla_node, input_index);
    new_input_edges[input_index] = new_input_edge;
  }

  // Re-order old per-replica input edges, and add new per-replica input edges.
  std::vector<std::pair<Node*, int>> per_replica_inputs;
  std::vector<const Edge*> old_per_replica_edges;
  for (int i = 0; i < old_num_per_replica_inputs * num_replicas; i++) {
    const Edge* e = input_edges.at(i);
    per_replica_inputs.push_back(std::make_pair(e->src(), e->src_output()));
    old_per_replica_edges.push_back(e);
  }
  for (const Edge* e : old_per_replica_edges) {
    g->RemoveEdge(e);
  }
  for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
    for (int input_index = 0; input_index < old_num_per_replica_inputs;
         input_index++) {
      Node* src = per_replica_inputs[replica_id * old_num_per_replica_inputs +
                                     input_index]
                      .first;
      int src_output =
          per_replica_inputs[replica_id * old_num_per_replica_inputs +
                             input_index]
              .second;
      const Edge* new_input_edge =
          g->AddEdge(src, src_output, xla_node,
                     replica_id * new_num_per_replica_inputs + input_index);
      new_input_edges[input_index] = new_input_edge;
    }
    for (int input_index = old_num_per_replica_inputs;
         input_index < new_num_per_replica_inputs; input_index++) {
      Node* original_src =
          oc_output_edges[input_index - old_num_per_replica_inputs]->src();
      int original_src_output =
          oc_output_edges[input_index - old_num_per_replica_inputs]
              ->src_output();
      Node* src = node_images[original_src][replica_id];
      const Edge* new_input_edge =
          g->AddEdge(src, original_src_output, xla_node,
                     replica_id * new_num_per_replica_inputs + input_index);
      new_input_edges[input_index] = new_input_edge;
    }
  }

  // Adjust original _Arg nodes in `xla_graph`.
  for (Node* n : xla_graph->nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      if (index >= old_num_per_replica_inputs) {
        index += new_arg_types.size();
        n->ClearAttr("index");
        n->AddAttr("index", index);
      }
    }
  }

  // Create new _Arg nodes in `xla_graph`.
  for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
       i++) {
    NodeDefBuilder arg_builder(absl::StrCat("arg_", i),
                               FunctionLibraryDefinition::kArgOp);
    arg_builder.Attr("T", new_arg_types[i - old_num_per_replica_inputs]);
    arg_builder.Attr("index", i);
    NodeDef arg_def;
    TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
    TF_ASSIGN_OR_RETURN(Node * arg_node, xla_graph->AddNode(arg_def));
    const Edge* original_edge = oc_output_edges[i - old_num_per_replica_inputs];
    Node* dst = original_edge->dst();
    int dst_input = original_edge->dst_input();
    xla_graph->RemoveEdge(original_edge);
    xla_graph->AddEdge(arg_node, 0, dst, dst_input);
  }

  // For lifted arg nodes:
  // 1. Add a Placeholder node in `xla_graph`. When we build host side graph
  //    in ExtractOutsideCompilationPass, we will use this new Placeholder node
  //    instead of lifted arg node here.
  // 2. Add an IdentityN node in `g` to indicate its inputs. We will reconnect
  //    this IdentityN node and this lifted arg node's usage nodes in
  //    DistributedTPURewritePass.
  for (Node* n : oc_nodes_at_head) {
    bool is_lifted_arg;
    string outside_compilation_attr;
    if (!TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) ||
        !TryGetNodeAttr(n->def(), kOutsideCompilationAttr,
                        &outside_compilation_attr)) {
      continue;
    }

    TF_RET_CHECK(n->IsIdentity());
    NodeDefBuilder ph_builder(absl::StrCat("placeholder_", n->name()),
                              "Placeholder");
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    ph_builder.Attr("dtype", dtype);
    ph_builder.Attr(kXlaIsLiftedArgAttrName, true);
    ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
    NodeDef ph_def;
    TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
    Status s;
    xla_graph->AddNode(ph_def, &s);
    TF_RETURN_IF_ERROR(s);

    Node* input_node;
    TF_RETURN_IF_ERROR(n->input_node(0, &input_node));
    TF_RET_CHECK(input_node->type_string() == "_Arg");
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
    // TODO(b/74023706): for now we only support resource input (e.g. summary
    // writer), which is non-replicated input. Support replicated input as
    // well.
    TF_RET_CHECK(index >= new_num_per_replica_inputs + num_distributed_vars);
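    // `index` counts within a single replica's argument list (one per-replica
    // block plus the non-per-replica inputs), while the flattened edge list
    // has `num_replicas` per-replica blocks; hence the offset below.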
    const Edge* input_edge =
        new_input_edges.at(num_replicas * new_num_per_replica_inputs + index -
                           new_num_per_replica_inputs);
    NodeDefBuilder id_builder(absl::StrCat("lifted_arg_input_", index),
                              "IdentityN");
    DataType input_dtype =
        input_edge->src()->output_type(input_edge->src_output());
    id_builder.Attr("T", std::vector<DataType>(num_replicas, input_dtype));
    std::vector<NodeDefBuilder::NodeOut> inputs(
        num_replicas,
        NodeDefBuilder::NodeOut{input_edge->src()->name(),
                                input_edge->src_output(), input_dtype});
    id_builder.Attr(kXlaOutsideCompilationInputsAttrName,
                    outside_compilation_attr);
    id_builder.Input(inputs);
    NodeDef id_def;
    TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
    TF_ASSIGN_OR_RETURN(Node * id_node, g->AddNode(id_def));
    for (int i = 0; i < num_replicas; i++) {
      g->AddEdge(input_edge->src(), input_edge->src_output(), id_node, i);
    }
  }

  // Remove `oc_nodes_at_head`.
  for (Node* n : oc_nodes_at_head) {
    xla_graph->RemoveNode(n);
  }

  VLOG(4) << "MoveHeadOutsideCompilationToHost host graph: "
          << DumpGraphToFile(absl::StrCat("move_head_oc_host_", xla_func_name),
                             *g);
  VLOG(4) << "MoveHeadOutsideCompilationToHost XLA graph: "
          << DumpGraphToFile(absl::StrCat("move_head_oc_xla_", xla_func_name),
                             *xla_graph);

  return OkStatus();
}

// If there are any unused _Arg nodes in `xla_graph`, remove them from
// `xla_graph` and remove the corresponding input edges in host graph `g`.
Status RemoveUnusedXlaInput(const string& xla_func_name, Graph* g,
                            Graph* xla_graph, Node* xla_node) {
  // Find unused _Arg nodes, and remove them.
  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
  std::vector<int> mirrored_variable_indices;
  if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
      nullptr) {
    TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
                                   TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                                   &mirrored_variable_indices));
  }
  std::vector<DataType> broadcast_input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tbroadcast_inputs",
                                 &broadcast_input_types));
  std::vector<DataType> guaranteed_constant_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tguaranteed_constants",
                                 &guaranteed_constant_types));
  int num_variables;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "NumVariables", &num_variables));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;
  std::set<int> arg_indices_to_remove;
  std::vector<Node*> arg_nodes_to_update, nodes_to_remove;
  int num_args = 0, num_removed_per_replica_inputs = 0,
      num_removed_distributed_vars = 0;
  for (Node* n : xla_graph->nodes()) {
    if (!n->IsArg()) {
      continue;
    }

    bool has_output = false;
    for (const Edge* e : n->out_edges()) {
      if (e->dst() != xla_graph->sink_node()) {
        has_output = true;
        break;
      }
    }

    num_args++;
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    if (has_output) {
      arg_nodes_to_update.push_back(n);
      continue;
    }

    arg_indices_to_remove.insert(index);
    if (index < num_per_replica_inputs) {
      num_removed_per_replica_inputs++;
    } else if (index < num_per_replica_inputs + num_distributed_vars) {
      num_removed_distributed_vars++;
    }
    nodes_to_remove.push_back(n);
  }
  for (Node* n : nodes_to_remove) {
    xla_graph->RemoveNode(n);
  }

  // Update `index` for other _Arg nodes.
  std::map<int, int> arg_index_mapping;
  int new_arg_index = 0;
  for (int i = 0; i < num_args; i++) {
    if (arg_indices_to_remove.find(i) != arg_indices_to_remove.end()) {
      continue;
    } else {
      arg_index_mapping[i] = new_arg_index;
      new_arg_index++;
    }
  }
  for (Node* n : arg_nodes_to_update) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    n->ClearAttr("index");
    n->AddAttr("index", arg_index_mapping[index]);
  }

  // Re-order replicated input edges for `xla_node`.

  // Sometimes `xla_node` can have a lot of inputs; calling Node::input_edge
  // then becomes very expensive because it does a linear search internally.
  // Build an input_edges vector up front to make the lookups faster.
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

  const int num_new_per_replica_inputs =
      num_per_replica_inputs - num_removed_per_replica_inputs;
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_per_replica_inputs; j++) {
      auto iter = arg_index_mapping.find(j);
      if (iter != arg_index_mapping.end()) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
        Node* src = e->src();
        int src_output = e->src_output();
        int dst_input = i * num_new_per_replica_inputs + iter->second;

        g->RemoveEdge(e);
        g->AddEdge(src, src_output, xla_node, dst_input);
      } else {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
        g->RemoveEdge(e);
      }
    }
  }

  // Move other data input edges.
  for (int i = num_replicas * num_per_replica_inputs;
       i < xla_node->num_inputs(); i++) {
    int arg_index =
        num_per_replica_inputs + i - num_replicas * num_per_replica_inputs;
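    // Inputs past the per-replica blocks map to arg indices starting at
    // `num_per_replica_inputs` (a single block's worth) in `xla_graph`.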
    auto iter = arg_index_mapping.find(arg_index);
    if (iter != arg_index_mapping.end()) {
      const Edge* e = input_edges.at(i);
      Node* src = e->src();
      int src_output = e->src_output();
      int dst_input = num_replicas * num_new_per_replica_inputs + iter->second -
                      num_new_per_replica_inputs;

      g->RemoveEdge(e);
      g->AddEdge(src, src_output, xla_node, dst_input);
    } else {
      const Edge* e = input_edges.at(i);
      g->RemoveEdge(e);
    }
  }

  // Update attributes for `xla_node`.
  std::vector<DataType> new_input_types;
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_per_replica_inputs; j++) {
      auto iter = arg_index_mapping.find(j);
      if (iter != arg_index_mapping.end()) {
        new_input_types.push_back(input_types[iter->first]);
      }
    }
  }
  for (int i = 0; i < num_distributed_vars; ++i) {
    auto iter = arg_index_mapping.find(i + num_per_replica_inputs);
    if (iter != arg_index_mapping.end()) {
      new_input_types.push_back(
          input_types[iter->first - num_per_replica_inputs +
                      num_per_replica_inputs * num_replicas]);
    }
  }
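  // (Distributed-variable types sit after all `num_replicas` per-replica
  // blocks in the old `Tinputs`, hence the index shift above.)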
  xla_node->ClearAttr("Tinputs");
  xla_node->AddAttr("Tinputs", new_input_types);

  const int num_new_distributed_vars =
      num_distributed_vars - num_removed_distributed_vars;
  xla_node->ClearAttr("num_distributed_variables");
  xla_node->AddAttr("num_distributed_variables", num_new_distributed_vars);

  if (!mirrored_variable_indices.empty()) {
    std::vector<int> new_mirrored_variable_indices;
    absl::flat_hash_set<int> old_mirrored_variable_indices_set;
    for (int index : mirrored_variable_indices) {
      old_mirrored_variable_indices_set.insert(index);
    }
    for (int i = 0; i < num_per_replica_inputs + num_distributed_vars; i++) {
      auto iter = arg_index_mapping.find(i);
      if (iter != arg_index_mapping.end() &&
          old_mirrored_variable_indices_set.contains(iter->first)) {
        new_mirrored_variable_indices.push_back(iter->second);
      }
    }
    xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
    xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                      new_mirrored_variable_indices);
  }

  int num_replicated_inputs = num_per_replica_inputs + num_distributed_vars;
  std::vector<DataType> new_broadcast_input_types;
  for (int i = 0; i < broadcast_input_types.size(); i++) {
    int arg_index = num_replicated_inputs + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_broadcast_input_types.push_back(broadcast_input_types[i]);
    }
  }
  xla_node->ClearAttr("Tbroadcast_inputs");
  xla_node->AddAttr("Tbroadcast_inputs", new_broadcast_input_types);
  int new_num_variables = 0;
  for (int i = 0; i < num_variables; i++) {
    int arg_index = num_replicated_inputs + broadcast_input_types.size() + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_num_variables++;
    }
  }
  xla_node->ClearAttr("NumVariables");
  xla_node->AddAttr("NumVariables", new_num_variables);
  std::vector<DataType> new_guaranteed_constant_types;
  for (int i = 0; i < guaranteed_constant_types.size(); i++) {
    int arg_index = num_replicated_inputs + broadcast_input_types.size() +
                    num_variables + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_guaranteed_constant_types.push_back(guaranteed_constant_types[i]);
    }
  }
  xla_node->ClearAttr("Tguaranteed_constants");
  xla_node->AddAttr("Tguaranteed_constants", new_guaranteed_constant_types);

  int new_variable_start_index = num_new_per_replica_inputs +
                                 num_new_distributed_vars +
                                 new_broadcast_input_types.size();
  if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
    xla_node->ClearAttr("_variable_start_index");
    xla_node->AddAttr("_variable_start_index", new_variable_start_index);
  }
  int new_guaranteed_const_start_index =
      new_variable_start_index + new_num_variables;
  if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
    xla_node->ClearAttr("_guaranteed_const_start_index");
    xla_node->AddAttr("_guaranteed_const_start_index",
                      new_guaranteed_const_start_index);
  }

  VLOG(4) << "RemoveUnusedXlaInput host graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_input_host_", xla_func_name), *g);
  VLOG(4) << "RemoveUnusedXlaInput XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_input_xla_", xla_func_name),
                 *xla_graph);

  return OkStatus();
}

// Move outside compilation nodes at the end of XLA computation to host.
// For XLA computation graph, we will add new _Retval nodes to replace those
// outside compilation nodes.
// For host graph, we will move those outside compilation nodes to host,
// replicate them, and use them as XLA node's output.
Status MoveTailOutsideCompilationToHost(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
    Node* pivot_node) {
  // Find outside compilation nodes that only have _Retval or other outside
  // compilation nodes as output. These nodes will be moved to host graph.
  std::vector<Node*> oc_nodes_at_tail;
  const string kOnlyRetOrOcOutputAttrName = "_xla_only_ret_or_oc_output";
  DFS(
      *xla_graph, /*enter=*/nullptr,
      [&](Node* n) {
        bool has_non_ret_or_oc_output = false;
        for (const Edge* e : n->out_edges()) {
          if (e->dst() == xla_graph->sink_node()) {
            continue;
          }
          if (!e->dst()->IsRetval() &&
              (!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) ||
               !HasNodeAttr(e->dst()->def(), kOnlyRetOrOcOutputAttrName))) {
            has_non_ret_or_oc_output = true;
            break;
          }
        }
        if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
            !has_non_ret_or_oc_output) {
          n->AddAttr(kOnlyRetOrOcOutputAttrName, true);
          oc_nodes_at_tail.push_back(n);
        }
      },
      NodeComparatorName());
  if (VLOG_IS_ON(5)) {
    for (Node* n : oc_nodes_at_tail) {
      VLOG(5) << "oc_nodes_at_tail: " << n->DebugString();
    }
  }

  // Record input edges from `oc_nodes_at_tail`. We will create a _Retval node
  // for each of these edges. An obvious optimization here is to deduplicate
  // these edges by <src, src_output>. But that optimization will complicate
  // the code, and in practice we usually do not have input edges with the
  // same <src, src_output>.
  std::vector<const Edge*> oc_input_edges;
  std::vector<DataType> new_ret_types;
  for (Node* n : oc_nodes_at_tail) {
    for (const Edge* e : n->in_edges()) {
      if (!e->IsControlEdge() &&
          !HasNodeAttr(e->src()->def(), kOnlyRetOrOcOutputAttrName)) {
        VLOG(5) << "oc_input_edges: " << e->DebugString();
        oc_input_edges.push_back(e);
        new_ret_types.push_back(e->src()->output_type(e->src_output()));
      }
    }
  }
  std::vector<DataType> output_types;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "output_types", &output_types));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  int old_num_replicated_outputs = output_types.size() / num_replicas;
  int new_num_replicated_outputs =
      old_num_replicated_outputs + oc_input_edges.size();
  VLOG(5) << "old_num_replicated_outputs: " << old_num_replicated_outputs;
  VLOG(5) << "new_num_replicated_outputs: " << new_num_replicated_outputs;

  // Update `output_types` attribute for `xla_node`.
  std::vector<DataType> new_output_types;
  for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
    for (int i = 0; i < old_num_replicated_outputs; i++) {
      new_output_types.push_back(output_types[i]);
    }
    for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
         i++) {
      new_output_types.push_back(new_ret_types[i - old_num_replicated_outputs]);
    }
  }
  xla_node->ClearAttr("output_types");
  xla_node->AddAttr("output_types", new_output_types);

  // Re-order old replicated output edges. Since a node could potentially
  // connect to multiple nodes, build a vector<vector<pair>> mapping of
  // output index to input nodes/index.
  // The outer vector represents the output index, the inner vector
  // represents the destination node and input index pair with the possibility
  // of multiple node/index pairs.
  std::vector<std::vector<std::pair<Node*, int>>> replicated_outputs(
      old_num_replicated_outputs * num_replicas);
  std::vector<const Edge*> old_replicated_edges;
  for (const Edge* e : xla_node->out_edges()) {
    if (e->src_output() >= 0 &&
        e->src_output() < old_num_replicated_outputs * num_replicas) {
      replicated_outputs[e->src_output()].push_back(
          std::make_pair(e->dst(), e->dst_input()));
      old_replicated_edges.push_back(e);
    }
  }
  for (const Edge* e : old_replicated_edges) {
    g->RemoveEdge(e);
  }
  for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
    for (int output_index = 0; output_index < old_num_replicated_outputs;
         output_index++) {
      for (const auto& node_input_pair :
           replicated_outputs[replica_id * old_num_replicated_outputs +
                              output_index]) {
        Node* dst = node_input_pair.first;
        int dst_input = node_input_pair.second;
        g->AddEdge(xla_node,
                   replica_id * new_num_replicated_outputs + output_index, dst,
                   dst_input);
      }
    }
  }

  // Copy all nodes in `oc_nodes_at_tail` to host graph, and also replicate
  // them.
  std::map<Node*, std::vector<Node*>> node_images;
  for (Node* n : oc_nodes_at_tail) {
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      NodeDef copy_def = n->def();
      copy_def.set_name(absl::StrCat(n->name(), "_tail_oc/R", replica_id));
      copy_def.clear_device();

      TF_ASSIGN_OR_RETURN(Node * copy_node, g->AddNode(copy_def));

      copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
      copy_node->AddAttr(kTPUReplicateAttr, cluster_name);

      for (const Edge* e : n->out_edges()) {
        if (e->dst() == xla_graph->sink_node()) {
          continue;
        }
        // Either e->dst() is a _Retval node, or it's in `node_images`.
        if (e->dst()->IsRetval()) {
          int index;
          TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->attrs(), "index", &index));
          for (const auto& output :
               replicated_outputs[replica_id * old_num_replicated_outputs +
                                  index]) {
            // Remove the original input edge, if it exists.
            const Edge* original_edge;
            Status s = output.first->input_edge(output.second, &original_edge);
            if (s.ok()) {
              g->RemoveEdge(original_edge);
            }
            g->AddEdge(copy_node, e->src_output(), output.first, output.second);
          }
        } else {
          g->AddEdge(copy_node, e->src_output(),
                     node_images[e->dst()][replica_id], e->dst_input());
        }
      }

      // Add attribute "_xla_tail_outside_compilation" to `copy_node`, and add
      // a control edge between `xla_node` and `copy_node`. As a result, in a
      // later rewriting pass, a control edge will be added between `copy_node`
      // and the "control_after" node for the XLA computation, so `copy_node`
      // will be executed before the XLA computation's final results.
      copy_node->AddAttr("_xla_tail_outside_compilation", true);
      g->AddControlEdge(xla_node, copy_node);

      // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
      // belongs to same while loop as `xla_node`.
      if (pivot_node) {
        g->AddControlEdge(pivot_node, copy_node);
      }

      node_images[n].push_back(copy_node);
    }
  }

  // Connect new output values of `xla_node` to dst nodes of `oc_input_edges`.
  for (int i = 0; i < new_ret_types.size(); i++) {
    const Edge* original_edge = oc_input_edges[i];
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      int src_output = replica_id * new_num_replicated_outputs +
                       old_num_replicated_outputs + i;
      Node* dst = node_images[original_edge->dst()][replica_id];
      g->AddEdge(xla_node, src_output, dst, original_edge->dst_input());
    }
  }

  // Create new _Retval nodes in `xla_graph`.
  for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
       i++) {
    NodeDefBuilder ret_builder(absl::StrCat("ret_", i),
                               FunctionLibraryDefinition::kRetOp);
    ret_builder.Attr("T", new_ret_types[i - old_num_replicated_outputs]);
    ret_builder.Attr("index", i);
    const Edge* original_edge = oc_input_edges[i - old_num_replicated_outputs];
    Node* src = original_edge->src();
    int src_output = original_edge->src_output();
    ret_builder.Input(src->name(), src_output, src->output_type(src_output));
    NodeDef ret_def;
    TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
    TF_ASSIGN_OR_RETURN(Node * ret_node, xla_graph->AddNode(ret_def));
    xla_graph->RemoveEdge(original_edge);
    xla_graph->AddEdge(src, src_output, ret_node, 0);
  }

  // Remove `oc_nodes_at_tail`.
  for (Node* n : oc_nodes_at_tail) {
    xla_graph->RemoveNode(n);
  }

  // We cannot leave a _Retval node with no input. Add a placeholder input,
  // which will be removed later along with the unused _Retval.
  std::vector<Node*> unused_rets;
  for (Node* n : xla_graph->nodes()) {
    if (n->IsRetval() && n->in_edges().empty()) {
      unused_rets.push_back(n);
    }
  }
  for (Node* n : unused_rets) {
    NodeDefBuilder builder(absl::StrCat("placeholder_", n->name()),
                           "Placeholder");
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    builder.Attr("dtype", dtype);
    builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
    NodeDef def;
    TF_RETURN_IF_ERROR(builder.Finalize(&def));
    TF_ASSIGN_OR_RETURN(Node * placeholder, xla_graph->AddNode(def));
    xla_graph->AddEdge(placeholder, 0, n, 0);
  }

  VLOG(4) << "MoveTailOutsideCompilationToHost host graph: "
          << DumpGraphToFile(absl::StrCat("move_tail_oc_host_", xla_func_name),
                             *g);
  VLOG(4) << "MoveTailOutsideCompilationToHost XLA graph: "
          << DumpGraphToFile(absl::StrCat("move_tail_oc_xla_", xla_func_name),
                             *xla_graph);

  return OkStatus();
}

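// For each resource _Arg in `xla_graph` that is consumed by an outside
// compilation node, replace those uses with a Placeholder node (reconnected
// later in ExtractOutsideCompilationPass), and record the arg's host-side
// inputs with an IdentityN node in `g`.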
Status ReplaceArgUsedByOutsideCompilationWithPlaceholder(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    Graph* g, Graph* xla_graph, Node* xla_node) {
  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  int num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;

  for (Node* n : xla_graph->op_nodes()) {
    if (!n->IsArg()) {
      continue;
    }

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    // TODO(b/74023706): enable moving normal data tensors.
    if (dtype != DT_RESOURCE) {
      continue;
    }

    std::vector<const Edge*> oc_out_edges;
    for (const Edge* e : n->out_edges()) {
      if (e->IsControlEdge() ||
          !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
        continue;
      }

      oc_out_edges.push_back(e);
    }
    if (oc_out_edges.empty()) {
      continue;
    }

    // Sometimes `xla_node` can have a lot of inputs; calling Node::input_edge
    // then becomes very expensive because it does a linear search internally.
    // Build an input_edges vector up front to make the lookups faster.
    std::vector<const Edge*> input_edges;
    TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

    // Build an IdentityN node to record inputs for this _Arg node.
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    string oc_identifier = absl::StrCat("oc_only_arg_", index);
    NodeDefBuilder id_builder(absl::StrCat(oc_identifier, "_inputs"),
                              "IdentityN");
    std::vector<DataType> dtypes(num_replicas, dtype);
    id_builder.Attr("T", dtypes);
    id_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
    std::vector<NodeDefBuilder::NodeOut> inputs(num_replicas);
    if (index >= num_per_replica_inputs) {
      const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
                                     (index - num_per_replica_inputs));
      for (int i = 0; i < num_replicas; i++) {
        inputs[i] =
            NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
                                    e->src()->output_type(e->src_output())};
      }
    } else {
      for (int i = 0; i < num_replicas; i++) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
        inputs[i] =
            NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
                                    e->src()->output_type(e->src_output())};
      }
    }
    id_builder.Input(inputs);
    NodeDef id_def;
    TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
    TF_ASSIGN_OR_RETURN(Node * id_node, g->AddNode(id_def));
    if (index >= num_per_replica_inputs) {
      const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
                                     (index - num_per_replica_inputs));
      for (int i = 0; i < num_replicas; i++) {
        g->AddEdge(e->src(), e->src_output(), id_node, i);
      }
    } else {
      for (int i = 0; i < num_replicas; i++) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
        g->AddEdge(e->src(), e->src_output(), id_node, i);
      }
    }

    for (const Edge* e : oc_out_edges) {
      // 'e' will use a new Placeholder node as input.
      NodeDefBuilder ph_builder(xla_graph->NewName("ph_for_arg_in_oc_"),
                                "Placeholder");
      ph_builder.Attr("dtype", dtype);

      string outside_compilation_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
                                     &outside_compilation_attr));
      ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
      ph_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
      ph_builder.Attr(kXlaIsPlaceholderForArg, true);
      NodeDef ph_def;
      TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
      TF_ASSIGN_OR_RETURN(Node * ph_node, xla_graph->AddNode(ph_def));
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      xla_graph->RemoveEdge(e);
      xla_graph->AddEdge(ph_node, 0, dst, dst_input);
      xla_graph->AddControlEdge(xla_graph->source_node(), ph_node);
    }
  }
  VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder host graph: "
          << DumpGraphToFile(
                 absl::StrCat("replace_oc_only_arg_host_", xla_func_name), *g);
  VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("replace_oc_only_arg_xla_", xla_func_name),
                 *xla_graph);
  return OkStatus();
}

// If there are any unused _Retval nodes in `xla_graph` (whose input is a
// Placeholder node), remove them from `xla_graph` and remove the
// corresponding output edges in host graph `g`.
Status RemoveUnusedXlaOutput(const string& xla_func_name, Graph* g,
                             Graph* xla_graph, Node* xla_node) {
  // Find unused _Retval nodes, and remove them.
  std::vector<DataType> output_types;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "output_types", &output_types));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
  int num_replicated_outputs = output_types.size() / num_replicas;
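  // `output_types` is laid out as `num_replicas` consecutive blocks of
  // replicated outputs, so each block holds
  // output_types.size() / num_replicas entries.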
  std::set<int> ret_indices_to_remove;
  std::vector<Node*> ret_nodes_to_update, nodes_to_remove;
  int num_rets = 0;
  for (Node* n : xla_graph->nodes()) {
    if (!n->IsRetval()) {
      continue;
    }

    num_rets++;

    const Edge* e;
    TF_RETURN_IF_ERROR(n->input_edge(0, &e));
    if (e->src()->type_string() != "Placeholder" ||
        !HasNodeAttr(e->src()->def(), kXlaIsPlaceholderForTailOcAttrName)) {
      ret_nodes_to_update.push_back(n);
      continue;
    }

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    ret_indices_to_remove.insert(index);
    nodes_to_remove.push_back(e->src());
    nodes_to_remove.push_back(n);
  }
  for (Node* n : nodes_to_remove) {
    xla_graph->RemoveNode(n);
  }

  // Update `index` for other _Retval nodes.
  std::map<int, int> ret_index_mapping;
  int new_ret_index = 0;
  for (int i = 0; i < num_rets; i++) {
    if (ret_indices_to_remove.find(i) != ret_indices_to_remove.end()) {
      continue;
    } else {
      ret_index_mapping[i] = new_ret_index;
      new_ret_index++;
    }
  }
  for (Node* n : ret_nodes_to_update) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    n->ClearAttr("index");
    n->AddAttr("index", ret_index_mapping[index]);
  }

  // Update `output_types` attribute for `xla_node`.
  std::vector<DataType> new_output_types;
  for (int i = 0; i < num_replicas; i++) {
    for (const auto& e : ret_index_mapping) {
      new_output_types.push_back(output_types[e.first]);
    }
  }

  xla_node->ClearAttr("output_types");
  xla_node->AddAttr("output_types", new_output_types);

  // Re-order replicated output edges for `xla_node`.
  std::vector<std::vector<const Edge*>> output_edges(num_replicas *
                                                     num_replicated_outputs);
  for (const Edge* e : xla_node->out_edges()) {
    if (e->src_output() >= 0 &&
        e->src_output() < num_replicas * num_replicated_outputs) {
      output_edges[e->src_output()].push_back(e);
    }
  }
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_replicated_outputs; j++) {
      auto iter = ret_index_mapping.find(j);
      if (iter != ret_index_mapping.end()) {
        for (const Edge* e : output_edges[i * num_replicated_outputs + j]) {
          Node* dst = e->dst();
          int dst_input = e->dst_input();
          int src_output =
              i * (num_replicated_outputs - ret_indices_to_remove.size()) +
              iter->second;
          g->RemoveEdge(e);
          g->AddEdge(xla_node, src_output, dst, dst_input);
        }
      } else {
        TF_RET_CHECK(output_edges[i * num_replicated_outputs + j].empty())
            << "Output edge not removed: "
            << output_edges[i * num_replicated_outputs + j][0]->DebugString();
      }
    }
  }

  VLOG(4) << "RemoveUnusedXlaOutput host graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_output_host_", xla_func_name), *g);
  VLOG(4) << "RemoveUnusedXlaOutput XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_output_xla_", xla_func_name),
                 *xla_graph);

  return OkStatus();
}

1435 // Removes data edges between _Arg and _Retval nodes in `xla_graph`, and
1436 // rewires the corresponding input/output edges in the host graph `g`. For
1437 // now, we only consider replicated inputs.
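// As an illustrative sketch: if _Arg 0 feeds _Retval 2 directly inside
// `xla_graph`, then for each replica r the host-side consumers of `xla_node`
// output r * old_num_outputs + 2 are rewired to the source of replicated
// input r * old_num_per_replica_inputs + 0, bypassing the XLA computation.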
1438 Status RemoveEdgesBetweenArgAndRetval(const string& xla_func_name, Graph* g,
1439 Graph* xla_graph, Node* xla_node) {
1440 // Collect data edges between _Arg and _Retval.
1441 int num_replicas;
1442 TF_RETURN_IF_ERROR(
1443 GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
1444 std::vector<DataType> input_types;
1445 TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
1446 int num_distributed_vars;
1447 TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
1448 &num_distributed_vars));
1449 int old_num_per_replica_inputs =
1450 (input_types.size() - num_distributed_vars) / num_replicas;
1451 std::vector<DataType> output_types;
1452 TF_RETURN_IF_ERROR(
1453 GetNodeAttr(xla_node->def(), "output_types", &output_types));
1454 int old_num_outputs = output_types.size() / num_replicas;
1455 std::vector<const Edge*> edges;
1456 for (const Edge* e : xla_graph->edges()) {
1457 if (!e->IsControlEdge() && e->src()->IsArg() && e->dst()->IsRetval()) {
1458 edges.push_back(e);
1459 }
1460 }
1461
1462 // In host graph `g`, remove output edge from `xla_node` and connect input &
1463 // output directly.
1464 std::vector<std::vector<const Edge*>> xla_node_out_edges(
1465 xla_node->num_outputs());
1466 for (const Edge* e : xla_node->out_edges()) {
1467 if (!e->IsControlEdge()) {
1468 xla_node_out_edges[e->src_output()].push_back(e);
1469 }
1470 }
1471
1472   // Sometimes `xla_node` has a large number of inputs, and calling
1473   // Node::input_edge for each of them becomes very expensive because it does
1474   // a linear search internally. Build an input_edges vector up front to make
1475   // the lookups faster.
1476 std::vector<const Edge*> input_edges;
1477 TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));
1478 for (const Edge* e : edges) {
1479 int arg_index;
1480 TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "index", &arg_index));
1481 int ret_index;
1482 TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), "index", &ret_index));
1483
1484 for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1485 int input_index;
1486 if (arg_index < old_num_per_replica_inputs) {
1487 input_index = replica_id * old_num_per_replica_inputs + arg_index;
1488 } else {
1489 input_index = num_replicas * old_num_per_replica_inputs +
1490 (arg_index - old_num_per_replica_inputs);
1491 }
1492 const Edge* input_edge = input_edges.at(input_index);
1493
1494 int output_index = replica_id * old_num_outputs + ret_index;
1495 for (const Edge* output_edge : xla_node_out_edges[output_index]) {
1496 Node* dst = output_edge->dst();
1497 int dst_input = output_edge->dst_input();
1498
1499 g->RemoveEdge(output_edge);
1500 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
1501 }
1502 }
1503 }
1504
1505   // Remove these edges from `xla_graph`, feeding each affected _Retval node
1506   // from a Placeholder node that `RemoveUnusedXlaOutput()` will remove later.
1507 for (const Edge* e : edges) {
1508 NodeDefBuilder placeholder_builder(
1509 absl::StrCat("placeholder_", e->dst()->name()), "Placeholder");
1510 placeholder_builder.Attr("dtype", e->src()->output_type(e->src_output()));
1511 placeholder_builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
1512 NodeDef placeholder_def;
1513 TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
1514 TF_ASSIGN_OR_RETURN(Node * placeholder_node,
1515 xla_graph->AddNode(placeholder_def));
1516
1517 Node* dst = e->dst();
1518 int dst_input = e->dst_input();
1519 xla_graph->RemoveEdge(e);
1520 xla_graph->AddEdge(placeholder_node, 0, dst, dst_input);
1521 }
1522
1523   VLOG(4) << "RemoveEdgesBetweenArgAndRetval host graph: "
1524 << DumpGraphToFile(
1525 absl::StrCat("remove_unused_arg_ret_host_", xla_func_name),
1526 *g);
1527   VLOG(4) << "RemoveEdgesBetweenArgAndRetval XLA graph: "
1528 << DumpGraphToFile(
1529 absl::StrCat("remove_unused_arg_ret_xla_", xla_func_name),
1530 *xla_graph);
1531
1532 return OkStatus();
1533 }
1534
1535 // Remove any TPUReplicatedInput nodes with no output edges. Those nodes are
1536 // usually TPUMirroredVariable handles that are not used by any computation.
1537 void RemoveUnusedTPUReplicatedInputs(Graph* graph) {
1538 for (Node* n : graph->nodes()) {
1539 if (n->type_string() == kTPUReplicatedInput) {
1540 bool has_output = false;
1541 for (const Edge* e : n->out_edges()) {
1542 if (!e->dst()->IsSink()) {
1543 has_output = true;
1544 break;
1545 }
1546 }
1547 if (!has_output) {
1548         // Remove any TPUPartitionedInput nodes among the src nodes of the
1549         // to-be-removed TPUReplicatedInput node.
1550 std::vector<Node*> to_be_removed_src_nodes;
1551 for (const auto& e_in : n->in_edges()) {
1552 if (!e_in->IsControlEdge() &&
1553 e_in->src()->type_string() == kTPUPartitionedInput)
1554 to_be_removed_src_nodes.push_back(e_in->src());
1555 }
1556 graph->RemoveNode(n);
1557 for (Node* node : to_be_removed_src_nodes) {
1558 graph->RemoveNode(node);
1559 }
1560 }
1561 }
1562 }
1563 }
1564
1565 // We might have duplicated cluster names in the graph, e.g. when a tf.function
1566 // containing tpu_strategy.run() is called multiple times with
1567 // the same inputs. Find clusters with duplicated names and rename them.
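// For example (hypothetical names): two clusters both named "cluster" end up
// as "cluster" and "cluster_1"; if "cluster_1" were already taken, the suffix
// is incremented until an unused name such as "cluster_2" is found.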
1568 Status RenameClustersWithDuplicatedNames(Graph* g) {
1569 // Find all TPU clusters by finding all TPUReplicateMetadata nodes.
1570 std::unordered_map<string, std::vector<Node*>> cluster_name_to_metadata_nodes;
1571 std::unordered_set<string> cluster_names;
1572 for (Node* n : g->nodes()) {
1573 if (n->type_string() != "TPUReplicateMetadata") {
1574 continue;
1575 }
1576 string cluster_name;
1577 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &cluster_name));
1578 cluster_name_to_metadata_nodes[cluster_name].push_back(n);
1579 cluster_names.insert(cluster_name);
1580 }
1581   // Look for clusters with duplicated names.
1582 for (const auto& iter : cluster_name_to_metadata_nodes) {
1583 if (iter.second.size() == 1) {
1584 continue;
1585 }
1586
1587 // Rename clusters.
1588 for (int i = 1; i < iter.second.size(); i++) {
1589 // Find an available cluster name.
1590 string new_cluster_name;
1591 int cluster_name_suffix = 1;
1592 while (true) {
1593 new_cluster_name = absl::StrCat(iter.first, "_", cluster_name_suffix);
1594 if (cluster_names.find(new_cluster_name) == cluster_names.end()) {
1595 break;
1596 }
1597 cluster_name_suffix++;
1598 }
1599 cluster_names.insert(new_cluster_name);
1600
1601 // Change _tpu_replicate attribute for all nodes in this cluster.
1602 // Start with outputs of TPUReplicateMetadata and follow output edges.
1603 std::queue<Node*> queue;
1604 queue.push(iter.second.at(i));
1605 std::unordered_set<Node*> visited;
1606 while (!queue.empty()) {
1607 Node* n = queue.front();
1608 queue.pop();
1609
1610 visited.insert(n);
1611
1612 n->ClearAttr(kTPUReplicateAttr);
1613 n->AddAttr(kTPUReplicateAttr, new_cluster_name);
1614
1615 string cluster_name;
1616 for (const Edge* e : n->out_edges()) {
1617 if (GetNodeAttr(e->dst()->def(), kTPUReplicateAttr, &cluster_name)
1618 .ok() &&
1619 cluster_name == iter.first &&
1620 visited.find(e->dst()) == visited.end()) {
1621 queue.push(e->dst());
1622 }
1623 }
1624 }
1625 // Change "_tpu_compilation_status" attr for TPUCompilationResult node.
1626 for (const Edge* e : iter.second.at(i)->out_edges()) {
1627 if (e->dst()->type_string() == "TPUCompilationResult") {
1628 e->dst()->ClearAttr("_tpu_compilation_status");
1629 e->dst()->AddAttr("_tpu_compilation_status", new_cluster_name);
1630 }
1631 }
1632 }
1633 }
1634 return OkStatus();
1635 }
1636
1637 // Instantiate a function that is associated with a functional control flow
1638 // node. The function name is found by looking up `function_name_attr` of given
1639 // node.
1640 xla::StatusOr<std::unique_ptr<FunctionBody>> InstantiateAssociatedFunction(
1641 const Node& n, absl::string_view function_name_attr,
1642 FunctionLibraryDefinition* fld) {
1643 std::unique_ptr<FunctionBody> fbody;
1644 NameAttrList func_attr_list;
1645 TF_RETURN_IF_ERROR(GetNodeAttr(n.def(), function_name_attr, &func_attr_list));
1646 const FunctionDef* fdef = fld->Find(func_attr_list.name());
1647 if (fdef == nullptr) {
1648     return errors::Internal("Cannot find ", function_name_attr, " function ",
1649                             "for node ", n.DebugString());
1650 }
1651 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1652 *fdef, AttrSlice(&func_attr_list.attr()), fld, &fbody));
1653 return fbody;
1654 }
1655
1656 // Find inputs of an If node that, in both the then/else branches, are used
1657 // only for outside compilation (if they are used at all).
1658 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForIfNode(
1659 const Node& if_node, FunctionLibraryDefinition* fld) {
1660 absl::flat_hash_set<int> args_to_lift_indices;
1661 std::vector<DataType> dtypes;
1662 TF_RETURN_IF_ERROR(GetNodeAttr(if_node.def(), "Tin", &dtypes));
1663
1664 int num_args = dtypes.size();
1665
1666 for (int i = 0; i < num_args; i++) {
1667 // TODO(b/74023706): enable non resource inputs as well.
1668 if (dtypes[i] == DT_RESOURCE) {
1669 args_to_lift_indices.insert(i);
1670 }
1671 }
1672
1673 TF_ASSIGN_OR_RETURN(
1674 std::unique_ptr<FunctionBody> then_branch_fbody,
1675 InstantiateAssociatedFunction(if_node, "then_branch", fld));
1676
1677 TF_ASSIGN_OR_RETURN(
1678 std::unique_ptr<FunctionBody> else_branch_fbody,
1679 InstantiateAssociatedFunction(if_node, "else_branch", fld));
1680
1681 for (int i = 0; i < num_args; ++i) {
1682 bool used = false;
1683
1684 const Node* then_arg_node = then_branch_fbody->arg_nodes[i];
1685 for (const Edge* e : then_arg_node->out_edges()) {
1686 used = true;
1687 if (e->IsControlEdge() ||
1688 HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
1689 continue;
1690
1691 args_to_lift_indices.erase(i);
1692 break;
1693 }
1694
1695 const Node* else_arg_node = else_branch_fbody->arg_nodes[i];
1696 for (const Edge* e : else_arg_node->out_edges()) {
1697 used = true;
1698 if (e->IsControlEdge() ||
1699 HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
1700 continue;
1701
1702 args_to_lift_indices.erase(i);
1703 break;
1704 }
1705
1706     // Do not lift arguments that are not used at all. Otherwise, such an
1707     // unused arg would be outside compiled, and its output tensor would be
1708     // needlessly transferred to the host.
1709 if (!used) args_to_lift_indices.erase(i);
1710 }
1711
1712 return args_to_lift_indices;
1713 }
1714
1715 // Find inputs of While node that are:
1716 // 1. not used in cond func,
1717 // 2. only used for outside compilation in body func,
1718 // 3. loop invariant.
1719 // These inputs can be lifted out of the while loop.
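// E.g. (illustrative): a DT_RESOURCE input that the cond func never reads,
// that the body passes through unchanged to its matching _Retval, and whose
// only other users in the body carry kOutsideCompilationAttr can be lifted;
// a DT_FLOAT input never qualifies under the current DT_RESOURCE restriction.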
1720 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForWhileNode(
1721 Node* while_node, FunctionLibraryDefinition* fld) {
1722 // DT_RESOURCE inputs are candidates.
1723 absl::flat_hash_set<int> result;
1724 std::vector<DataType> dtypes;
1725 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));
1726 for (int i = 0; i < dtypes.size(); i++) {
1727 // TODO(b/74023706): enable non resource inputs as well.
1728 if (dtypes[i] == DT_RESOURCE) {
1729 result.insert(i);
1730 }
1731 }
1732
1733 // Remove inputs that are used in cond func.
1734 NameAttrList cond_func;
1735 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "cond", &cond_func));
1736 const FunctionDef* cond_fdef = fld->Find(cond_func.name());
1737 if (cond_fdef == nullptr) {
1738 return errors::Internal("Cannot find cond function ", cond_func.name(),
1739 " for while node ", while_node->DebugString());
1740 }
1741 std::unique_ptr<FunctionBody> cond_fbody;
1742 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1743 *cond_fdef, AttrSlice(&cond_func.attr()), fld, &cond_fbody));
1744 for (int i = 0; i < cond_fbody->arg_nodes.size(); i++) {
1745 const Node* arg_node = cond_fbody->arg_nodes[i];
1746 for (const Edge* e : arg_node->out_edges()) {
1747 if (!e->IsControlEdge()) {
1748 result.erase(i);
1749 }
1750 }
1751 }
1752
1753 // Remove inputs that are not loop invariant.
1754 NameAttrList body_func;
1755 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_func));
1756 const FunctionDef* body_fdef = fld->Find(body_func.name());
1757 if (body_fdef == nullptr) {
1758 return errors::Internal("Cannot find body function ", body_func.name(),
1759 " for while node ", while_node->DebugString());
1760 }
1761 std::unique_ptr<FunctionBody> body_fbody;
1762 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1763 *body_fdef, AttrSlice(&body_func.attr()), fld, &body_fbody));
1764 for (int i = 0; i < body_fbody->ret_nodes.size(); i++) {
1765 const Node* node = body_fbody->ret_nodes[i];
1766 do {
1767 TF_RETURN_IF_ERROR(node->input_node(0, &node));
1768 } while (node->IsIdentity());
1769 if (node != body_fbody->arg_nodes[i]) {
1770 result.erase(i);
1771 }
1772 }
1773
1774   // Remove inputs whose only data output edge is the loop-invariant
1775   // pass-through to _Retval (i.e. not used by outside compilation at all).
1776 for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
1777 const Node* arg_node = body_fbody->arg_nodes[i];
1778 int data_edge_count = std::count_if(
1779 arg_node->out_edges().begin(), arg_node->out_edges().end(),
1780 [](const Edge* e) { return !e->IsControlEdge(); });
1781 if (data_edge_count == 1) {
1782 result.erase(i);
1783 }
1784 }
1785
1786 // Remove inputs that have non-outside-compilation usage.
1787 for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
1788 const Node* arg_node = body_fbody->arg_nodes[i];
1789 for (const Edge* e : arg_node->out_edges()) {
1790 if (!e->dst()->IsRetval() &&
1791 !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1792 result.erase(i);
1793 break;
1794 }
1795 }
1796 }
1797
1798 return result;
1799 }
1800
1801 // Find inputs of a function call node that are used only for outside
1802 // compilation. These inputs can be lifted out of the function call node.
1803 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForCallNode(
1804 Node* call_node, const FunctionBody& fbody) {
1805 // DT_RESOURCE inputs are candidates.
1806 absl::flat_hash_set<int> result;
1807 std::vector<DataType> dtypes(call_node->input_types().begin(),
1808 call_node->input_types().end());
1809 for (int i = 0; i < dtypes.size(); i++) {
1810 // TODO(b/74023706): enable for non resource inputs as well.
1811 if (dtypes[i] == DT_RESOURCE) {
1812 result.insert(i);
1813 }
1814 }
1815
1816 // Remove inputs that have non-outside-compilation usage, or not used at all.
1817 for (int i = 0; i < fbody.arg_nodes.size(); i++) {
1818 const Node* arg_node = fbody.arg_nodes[i];
1819 if (arg_node->out_edges().empty()) {
1820 result.erase(i);
1821 continue;
1822 }
1823
1824 for (const Edge* e : arg_node->out_edges()) {
1825 if (!HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1826 result.erase(i);
1827 break;
1828 }
1829 }
1830 }
1831 return result;
1832 }
1833
1834 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
1835 FunctionLibraryDefinition* fld,
1836 int* lifted_arg_count, bool* rewritten);
1837
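// Lifts outside compilation only args in `fbody`'s graph. If anything was
// rewritten, converts the graph back to a FunctionDef and either adds it to
// `fld` under `new_func_name` (if provided) or replaces the original
// function definition in place.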
1838 Status LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
1839 const FunctionBody& fbody, FunctionLibraryRuntime* flr,
1840 FunctionLibraryDefinition* fld, int* lifted_arg_count,
1841 absl::optional<string> new_func_name, bool* rewritten) {
1842 *rewritten = false;
1843 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
1844 fbody.graph, flr, fld, lifted_arg_count, rewritten));
1845
1846 if (*rewritten) {
1847 FunctionDef rewritten_fdef;
1848 TF_RETURN_IF_ERROR(GraphToFunctionDef(
1849 *(fbody.graph), fbody.fdef.signature().name(), &rewritten_fdef));
1850 if (new_func_name) {
1851 rewritten_fdef.mutable_signature()->set_name(*new_func_name);
1852 TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
1853 } else {
1854 TF_RETURN_IF_ERROR(
1855 fld->ReplaceFunction(fbody.fdef.signature().name(), rewritten_fdef));
1856 }
1857 }
1858
1859 return OkStatus();
1860 }
1861
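// For each arg in `args_to_lift`, adds an outside compiled Identity node in
// host graph `g` that is fed by the arg's input edge to `n`, records the
// Identity node's name in `lifted_arg_index_to_oc_cluster_name`, and adds a
// control edge from the Identity node to `n`. `arg_to_input_edge_offset`
// accounts for nodes such as If, whose args start after the condition input.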
1862 Status MakeIdentityNodesForArgsToLift(
1863 const absl::flat_hash_set<int>& args_to_lift,
1864 const int arg_to_input_edge_offset, Graph* g, Node* n,
1865 absl::flat_hash_map<int, string>* lifted_arg_index_to_oc_cluster_name,
1866 int* lifted_arg_count) {
1867 int num_input = n->num_inputs();
1868 for (int arg_index = 0; arg_index < num_input; ++arg_index) {
1869 if (!args_to_lift.contains(arg_index)) continue;
1870
1871 int input_edge_index = arg_index + arg_to_input_edge_offset;
1872 const Edge* arg_edge;
1873 TF_RETURN_IF_ERROR(n->input_edge(input_edge_index, &arg_edge));
1874
1875 string node_name =
1876 g->NewName(absl::StrCat("lifted_arg", *lifted_arg_count));
1877 (*lifted_arg_count)++;
1878 (*lifted_arg_index_to_oc_cluster_name)[arg_index] = node_name;
1879 NodeDefBuilder id_builder(node_name, "Identity");
1880 id_builder.Attr("T", n->input_type(input_edge_index));
1881 id_builder.Attr(kOutsideCompilationAttr, id_builder.node_name());
1882 id_builder.Attr(kXlaIsLiftedArgAttrName, true);
1883 id_builder.Input(arg_edge->src()->name(), arg_edge->src_output(),
1884 n->input_type(input_edge_index));
1885 NodeDef id_def;
1886 TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
1887 TF_ASSIGN_OR_RETURN(Node * id_node, g->AddNode(id_def));
1888 g->AddEdge(arg_edge->src(), arg_edge->src_output(), id_node, 0);
1889 g->AddControlEdge(id_node, n);
1890 }
1891
1892 return OkStatus();
1893 }
1894
1895 // Replaces all usages of lifted args with placeholder nodes. Afterwards,
1896 // removing these args should be safe since they no longer have users.
1897 Status RemoveArgsToLiftFromFunctionBody(
1898 const absl::flat_hash_set<int>& args_to_lift,
1899 const std::vector<DataType>& arg_dtypes,
1900 const absl::flat_hash_map<int, string>& lifted_arg_index_to_oc_cluster_name,
1901 const absl::flat_hash_map<int, int>& index_mapping,
1902 const FunctionBody* fbody) {
1903 for (int i = 0; i < fbody->arg_nodes.size(); ++i) {
1904 Node* arg_node = fbody->arg_nodes[i];
1905
1906 if (!args_to_lift.contains(i)) {
1907 int new_index = index_mapping.at(i);
1908 arg_node->ClearAttr("index");
1909 arg_node->AddAttr("index", new_index);
1910 arg_node->ClearAttr("T");
1911 arg_node->AddAttr("T", arg_dtypes[i]);
1912 continue;
1913 }
1914
1915 std::vector<const Edge*> out_edges_to_oc;
1916 for (const Edge* e : arg_node->out_edges()) {
1917 if (HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1918 out_edges_to_oc.push_back(e);
1919 }
1920 }
1921
1922 for (const Edge* e : out_edges_to_oc) {
1923 string outside_compilation_cluster;
1924 TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
1925 &outside_compilation_cluster));
1926 NodeDefBuilder ph_builder(fbody->graph->NewName("lifted_arg"),
1927 "Placeholder");
1928 ph_builder.Attr("dtype", arg_dtypes[i]);
1929 ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_cluster);
1930 TF_RET_CHECK(lifted_arg_index_to_oc_cluster_name.contains(i));
1931 ph_builder.Attr(kXlaLiftedArgOutsideCompilationAttrName,
1932 lifted_arg_index_to_oc_cluster_name.at(i));
1933
1934 NodeDef ph_def;
1935 TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
1936
1937 TF_ASSIGN_OR_RETURN(Node * ph_node, fbody->graph->AddNode(ph_def));
1938
1939 Node* dst = e->dst();
1940 int dst_input = e->dst_input();
1941 fbody->graph->RemoveEdge(e);
1942 fbody->graph->AddEdge(ph_node, 0, dst, dst_input);
1943 }
1944
1945 fbody->graph->RemoveNode(arg_node);
1946 }
1947
1948 return OkStatus();
1949 }
1950
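// Removes input edges of `n` that correspond to lifted args, and re-attaches
// the surviving input edges at their remapped positions according to
// `index_mapping` (shifted by `arg_to_input_edge_offset`).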
1951 Status CleanUpInEdges(const absl::flat_hash_map<int, int>& index_mapping,
1952 const int arg_to_input_edge_offset, Graph* g, Node* n) {
1953 int num_inputs = n->num_inputs();
1954 for (int i = 0; i < num_inputs; ++i) {
1955 if (i < arg_to_input_edge_offset) continue;
1956
1957 int arg_idx = i - arg_to_input_edge_offset;
1958 const Edge* e;
1959 TF_RETURN_IF_ERROR(n->input_edge(i, &e));
1960
1961     // If an edge maps to a lifted argument, simply remove it from the graph.
1962 if (!index_mapping.contains(arg_idx)) {
1963 g->RemoveEdge(e);
1964 continue;
1965 }
1966
1967     // If an edge maps to the same input port, there is nothing to do.
1968 if (index_mapping.at(arg_idx) == arg_idx) continue;
1969
1970 g->AddEdge(e->src(), e->src_output(), n,
1971 index_mapping.at(arg_idx) + arg_to_input_edge_offset);
1972 g->RemoveEdge(e);
1973 }
1974
1975 return OkStatus();
1976 }
1977
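// Rewrites the `type_attr_name` list attr of `n` (e.g. "T" for While, "Tin"
// for If), keeping only the dtypes whose arg indices survive in
// `index_mapping`.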
1978 Status UpdateTypeAttribute(const absl::flat_hash_map<int, int>& index_mapping,
1979 const string& type_attr_name,
1980 const std::vector<DataType>& dtypes, Node* n) {
1981 std::vector<DataType> new_dtypes;
1982 new_dtypes.reserve(index_mapping.size());
1983 for (int i = 0; i < dtypes.size(); ++i) {
1984 if (index_mapping.contains(i)) {
1985 new_dtypes.emplace_back(dtypes[i]);
1986 }
1987 }
1988
1989 n->ClearAttr(type_attr_name);
1990 n->AddAttr(type_attr_name, new_dtypes);
1991
1992 return OkStatus();
1993 }
1994
1995 // While V2 always creates an Identity node for each While node output, which
1996 // is not necessary for the XLA computation. Remove those Identity nodes.
1997 void RemoveOutputIdentityNodesForWhileV2(Graph* g, Node* while_node) {
1998 std::vector<const Edge*> edges_to_identity_node;
1999 for (const Edge* e : while_node->out_edges()) {
2000 if (!e->IsControlEdge() && e->dst()->IsIdentity()) {
2001 edges_to_identity_node.push_back(e);
2002 }
2003 }
2004 for (const Edge* e : edges_to_identity_node) {
2005 Node* identity = e->dst();
2006 std::vector<const Edge*> out_edges(identity->out_edges().begin(),
2007 identity->out_edges().end());
2008 for (const Edge* out_edge : out_edges) {
2009 if (out_edge->IsControlEdge()) {
2010 g->AddControlEdge(while_node, out_edge->dst());
2011 } else {
2012 Node* dst = out_edge->dst();
2013 int dst_input = out_edge->dst_input();
2014 g->RemoveEdge(out_edge);
2015 g->AddEdge(while_node, e->src_output(), dst, dst_input);
2016 }
2017 }
2018 g->RemoveNode(identity);
2019 }
2020 }
2021
2022 // If a corresponding While node output is used, rewrite its consumers to use
2023 // the matching While node input instead.
2024 Status ReplaceOutputEdgesWithInputEdgeSourceForWhile(
2025 const absl::flat_hash_set<int>& args_to_lift, Graph* g, Node* while_node) {
2026 std::vector<const Edge*> edges_to_replace;
2027 for (const Edge* e : while_node->out_edges()) {
2028 if (args_to_lift.contains(e->src_output())) {
2029 edges_to_replace.push_back(e);
2030 }
2031 }
2032 for (const Edge* e : edges_to_replace) {
2033 const Edge* input_edge;
2034 TF_RETURN_IF_ERROR(while_node->input_edge(e->src_output(), &input_edge));
2035 Node* dst = e->dst();
2036 int dst_input = e->dst_input();
2037 g->RemoveEdge(e);
2038 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
2039 }
2040
2041 return OkStatus();
2042 }
2043
2044 // Calculates mapping from argument index before lifting to index afterwards.
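// For example, with num_args = 4 and args_to_lift = {1}, the resulting
// mapping is {0 -> 0, 2 -> 1, 3 -> 2}.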
2045 absl::flat_hash_map<int, int> ArgIndexMapping(
2046 const int num_args, const absl::flat_hash_set<int>& args_to_lift) {
2047 absl::flat_hash_map<int, int> index_mapping;
2048 int new_index = 0;
2049 for (int i = 0; i < num_args; i++) {
2050 if (!args_to_lift.contains(i)) {
2051 index_mapping[i] = new_index;
2052 ++new_index;
2053 }
2054 }
2055
2056 return index_mapping;
2057 }
2058
2059 // Remove outputs of the While node body function that map to lifted arguments.
2060 void CleanUpRetvalsForWhileBody(
2061 const absl::flat_hash_map<int, int>& index_mapping,
2062 const std::vector<DataType>& dtypes, FunctionBody* fbody) {
2063 for (int i = 0; i < fbody->ret_nodes.size(); i++) {
2064 Node* ret_node = fbody->ret_nodes[i];
2065 if (index_mapping.contains(i)) {
2066 int new_index = index_mapping.at(i);
2067 ret_node->ClearAttr("index");
2068 ret_node->AddAttr("index", new_index);
2069 ret_node->ClearAttr("T");
2070 ret_node->AddAttr("T", dtypes[i]);
2071 } else {
2072 fbody->graph->RemoveNode(ret_node);
2073 }
2074 }
2075 }
2076
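// Lifts qualifying args out of a While node: consumers of the corresponding
// While outputs are rewired to the node's inputs, each lifted arg is sent to
// the host through an outside compiled Identity node, the matching _Arg (and
// body _Retval) nodes are removed from the cond/body functions, and the While
// node's input edges and "T" attr are compacted accordingly.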
2077 Status LiftOutsideCompilationOnlyArgsFromWhileNode(
2078 Graph* g, Node* while_node, FunctionLibraryDefinition* fld,
2079 int* lifted_arg_count, bool* rewritten) {
2080 *rewritten = false;
2081
2082 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2083 FindArgsToLiftForWhileNode(while_node, fld));
2084 if (args_to_lift.empty()) return OkStatus();
2085
2086 RemoveOutputIdentityNodesForWhileV2(g, while_node);
2087
2088 TF_RETURN_IF_ERROR(ReplaceOutputEdgesWithInputEdgeSourceForWhile(
2089 args_to_lift, g, while_node));
2090
2091 std::vector<DataType> dtypes;
2092 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));
2093
2094 absl::flat_hash_map<int, int> index_mapping =
2095 ArgIndexMapping(dtypes.size(), args_to_lift);
2096
2097 // For each lifted arg, add an outside compilation Identity node to send
2098 // it to host.
2099 absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2100 TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2101 args_to_lift, /*arg_to_input_edge_offset=*/0, g, while_node,
2102 &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2103
2104 // For cond func, remove _Arg nodes.
2105 TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> cond_fbody,
2106 InstantiateAssociatedFunction(*while_node, "cond", fld));
2107 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2108 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2109 cond_fbody.get()));
2110
2111 FunctionDef rewritten_cond_fdef;
2112 TF_RETURN_IF_ERROR(GraphToFunctionDef(*(cond_fbody->graph),
2113 cond_fbody->fdef.signature().name(),
2114 &rewritten_cond_fdef));
2115 TF_RETURN_IF_ERROR(fld->ReplaceFunction(cond_fbody->fdef.signature().name(),
2116 rewritten_cond_fdef));
2117
2118 // For body func, remove _Retval nodes, and replace _Arg nodes with
2119 // Placeholder nodes.
2120 TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
2121 InstantiateAssociatedFunction(*while_node, "body", fld));
2122
2123 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2124 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2125 body_fbody.get()));
2126
2127 CleanUpRetvalsForWhileBody(index_mapping, dtypes, body_fbody.get());
2128
2129 FunctionDef rewritten_body_fdef;
2130 TF_RETURN_IF_ERROR(GraphToFunctionDef(*(body_fbody->graph),
2131 body_fbody->fdef.signature().name(),
2132 &rewritten_body_fdef));
2133 TF_RETURN_IF_ERROR(fld->ReplaceFunction(body_fbody->fdef.signature().name(),
2134 rewritten_body_fdef));
2135
2136 // Remove edges from lifted args to While node, and change "T" attr of the
2137 // While node.
2138 TF_RETURN_IF_ERROR(CleanUpInEdges(
2139 index_mapping, /*arg_to_input_edge_offset=*/0, g, while_node));
2140
2141 TF_RETURN_IF_ERROR(
2142 UpdateTypeAttribute(index_mapping, "T", dtypes, while_node));
2143
2144 *rewritten = true;
2145
2146 return OkStatus();
2147 }
2148
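// Same as LiftOutsideCompilationOnlyArgsFromWhileNode, but for an If node:
// lifted args are removed from both branch functions, and an input edge
// offset of 1 skips the If node's condition input.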
2149 Status LiftOutsideCompilationOnlyArgsFromIfNode(Graph* g, Node* if_node,
2150 FunctionLibraryDefinition* fld,
2151 int* lifted_arg_count,
2152 bool* rewritten) {
2153 *rewritten = false;
2154 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2155 FindArgsToLiftForIfNode(*if_node, fld));
2156 if (args_to_lift.empty()) return OkStatus();
2157
2158 std::vector<DataType> dtypes;
2159 TF_RETURN_IF_ERROR(GetNodeAttr(if_node->def(), "Tin", &dtypes));
2160
2161 absl::flat_hash_map<int, int> index_mapping;
2162 int new_index = 0;
2163 for (int i = 0; i < dtypes.size(); i++) {
2164 if (!args_to_lift.contains(i)) {
2165 index_mapping[i] = new_index;
2166 ++new_index;
2167 }
2168 }
2169
2170 // For each lifted arg, add an outside compilation Identity node to send
2171 // it to host.
2172 absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2173 TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2174 args_to_lift, /*arg_to_input_edge_offset=*/1, g, if_node,
2175 &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2176
2177 TF_ASSIGN_OR_RETURN(
2178 std::unique_ptr<FunctionBody> then_branch_fbody,
2179 InstantiateAssociatedFunction(*if_node, "then_branch", fld));
2180
2181 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2182 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2183 then_branch_fbody.get()));
2184
2185 FunctionDef rewritten_then_branch_fdef;
2186 TF_RETURN_IF_ERROR(GraphToFunctionDef(
2187 *(then_branch_fbody->graph), then_branch_fbody->fdef.signature().name(),
2188 &rewritten_then_branch_fdef));
2189 TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2190 then_branch_fbody->fdef.signature().name(), rewritten_then_branch_fdef));
2191
2192 TF_ASSIGN_OR_RETURN(
2193 std::unique_ptr<FunctionBody> else_branch_fbody,
2194 InstantiateAssociatedFunction(*if_node, "else_branch", fld));
2195
2196 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2197 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2198 else_branch_fbody.get()));
2199
2200 FunctionDef rewritten_else_branch_fdef;
2201 TF_RETURN_IF_ERROR(GraphToFunctionDef(
2202 *(else_branch_fbody->graph), else_branch_fbody->fdef.signature().name(),
2203 &rewritten_else_branch_fdef));
2204 TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2205 else_branch_fbody->fdef.signature().name(), rewritten_else_branch_fdef));
2206
2207 // Remove edges from lifted args to If node, and change "Tin" attr of the
2208 // If node.
2209 TF_RETURN_IF_ERROR(CleanUpInEdges(
2210 index_mapping, /*arg_to_input_edge_offset=*/1, g, if_node));
2211 TF_RETURN_IF_ERROR(
2212 UpdateTypeAttribute(index_mapping, "Tin", dtypes, if_node));
2213
2214 *rewritten = true;
2215
2216 return OkStatus();
2217 }
2218
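// Lifts qualifying args out of a function call node (a PartitionedCall, a
// direct function call, or a SymbolicGradient node). The rewritten function
// body is stored as a new function, and the call node is replaced by one
// that invokes it.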
2219 Status LiftOutsideCompilationOnlyArgsFromCallNode(
2220 Graph* g, Node* call_node, FunctionLibraryRuntime* flr,
2221 FunctionLibraryDefinition* fld, int* lifted_arg_count, bool* rewritten) {
2222 *rewritten = false;
2223
2224 // Instantiate the function.
2225 NameAttrList func;
2226 if (fld->Contains(call_node->type_string())) {
2227 func.set_name(call_node->type_string());
2228 *func.mutable_attr() = call_node->def().attr();
2229 } else if (call_node->IsPartitionedCall()) {
2230 TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &func));
2231 } else {
2232 TF_RET_CHECK(call_node->type_string() ==
2233 FunctionLibraryDefinition::kGradientOp);
2234 func.set_name(FunctionLibraryDefinition::kGradientOp);
2235 *func.mutable_attr() = call_node->def().attr();
2236 }
2237 FunctionLibraryRuntime::Handle handle;
2238 TF_RETURN_IF_ERROR(
2239 flr->Instantiate(func.name(), AttrSlice(&func.attr()), &handle));
2240 auto cleanup_handle = gtl::MakeCleanup(
2241 [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2242 const FunctionBody* fbody = flr->GetFunctionBody(handle);
2243
2244 // Find _Arg nodes to lift.
2245 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2246 FindArgsToLiftForCallNode(call_node, *fbody));
2247 if (args_to_lift.empty()) return OkStatus();
2248
2249 std::vector<DataType> dtypes;
2250 dtypes = std::vector<DataType>(call_node->input_types().begin(),
2251 call_node->input_types().end());
2252
2253 absl::flat_hash_map<int, int> index_mapping =
2254 ArgIndexMapping(dtypes.size(), args_to_lift);
2255
2256 // For each lifted arg, add an outside compilation Identity node to send
2257 // it to host.
2258 absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2259 TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2260 args_to_lift, /*arg_to_input_edge_offset=*/0, g, call_node,
2261 &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2262
2263 // Remove _Arg nodes.
2264 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2265 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2266 fbody));
2267
2268   // Store the rewritten function as a new function, because the original
2269   // function might be defined by the user and we should not modify it.
2270 FunctionDef rewritten_fdef;
2271 TF_RETURN_IF_ERROR(GraphToFunctionDef(
2272 *(fbody->graph), fbody->fdef.signature().name(), &rewritten_fdef));
2273 string new_func_name =
2274 fld->UniqueFunctionName(fbody->fdef.signature().name());
2275 rewritten_fdef.mutable_signature()->set_name(new_func_name);
2276 TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
2277
2278 // Remove edges from lifted args to call node.
2279 TF_RETURN_IF_ERROR(CleanUpInEdges(
2280 index_mapping, /*arg_to_input_edge_offset=*/0, g, call_node));
2281
2282 // Rewrite the call node to use the rewritten function.
2283 NodeDef node_def;
2284 node_def.set_name(g->NewName(call_node->name()));
2285 node_def.set_op(new_func_name);
2286 if (call_node->IsPartitionedCall()) {
2287 NameAttrList f;
2288 TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2289 *node_def.mutable_attr() = f.attr();
2290 } else if (fld->Contains(call_node->type_string())) {
2291 *node_def.mutable_attr() = call_node->def().attr();
2292 } else {
2293 TF_RET_CHECK(call_node->type_string() ==
2294 FunctionLibraryDefinition::kGradientOp);
2295 *node_def.mutable_attr() = call_node->def().attr();
2296 node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2297 }
2298 TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2299
2300 *rewritten = true;
2301
2302 return OkStatus();
2303 }
2304
2305 // Lifts outside-compilation-only _Arg nodes out of If/While/function call nodes.
2306 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
2307 FunctionLibraryDefinition* fld,
2308 int* lifted_arg_count, bool* rewritten) {
2309 *rewritten = false;
2310
2311 // Handle deeper functional nodes first.
2312 std::vector<Node*> while_nodes, if_nodes, call_nodes;
2313 for (Node* n : g->op_nodes()) {
2314 if (HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
2315 continue;
2316 }
2317
2318 if (n->IsWhileNode()) {
2319 TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
2320 InstantiateAssociatedFunction(*n, "body", fld));
2321 bool func_rewritten = false;
2322 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2323 *body_fbody, flr, fld, lifted_arg_count,
2324 /*new_func_name=*/absl::nullopt, &func_rewritten));
2325 *rewritten = *rewritten || func_rewritten;
2326
2327 while_nodes.push_back(n);
2328 } else if (n->IsIfNode()) {
2329 TF_ASSIGN_OR_RETURN(
2330 std::unique_ptr<FunctionBody> then_branch_fbody,
2331 InstantiateAssociatedFunction(*n, "then_branch", fld));
2332 bool func_rewritten = false;
2333 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2334 *then_branch_fbody, flr, fld, lifted_arg_count,
2335 /*new_func_name=*/absl::nullopt, &func_rewritten));
2336 *rewritten |= func_rewritten;
2337
2338 TF_ASSIGN_OR_RETURN(
2339 std::unique_ptr<FunctionBody> else_branch_fbody,
2340 InstantiateAssociatedFunction(*n, "else_branch", fld));
2341 func_rewritten = false;
2342 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2343 *else_branch_fbody, flr, fld, lifted_arg_count,
2344 /*new_func_name=*/absl::nullopt, &func_rewritten));
2345 *rewritten |= func_rewritten;
2346
2347 if_nodes.push_back(n);
2348 } else if (IsFunctionCall(*fld, *n)) {
2349 // Function call nodes need to be rewritten, so handle them later.
2350 call_nodes.push_back(n);
2351 }
2352 }
2353
2354 std::vector<Node*> rewritten_call_nodes;
2355 for (Node* call_node : call_nodes) {
2356 if (call_node->IsPartitionedCall()) {
2357 std::unique_ptr<FunctionBody> function_fbody;
2358 TF_ASSIGN_OR_RETURN(function_fbody,
2359 InstantiateAssociatedFunction(*call_node, "f", fld));
2360 bool func_rewritten = false;
2361 string new_func_name =
2362 fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2363 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2364 *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2365 &func_rewritten));
2366 if (func_rewritten) {
2367 NameAttrList f;
2368 TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2369 f.set_name(new_func_name);
2370 call_node->ClearAttr("f");
2371 call_node->AddAttr("f", f);
2372 }
2373
2374 *rewritten |= func_rewritten;
2375 rewritten_call_nodes.push_back(call_node);
2376 } else if (fld->Contains(call_node->type_string())) {
2377 std::unique_ptr<FunctionBody> function_fbody;
2378 const FunctionDef* fdef = fld->Find(call_node->type_string());
2379 TF_RET_CHECK(fdef);
2380 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, call_node->attrs(), fld,
2381 &function_fbody));
2382 bool func_rewritten = false;
2383 string new_func_name =
2384 fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2385 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2386 *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2387 &func_rewritten));
2388 if (func_rewritten) {
2389 NodeDef node_def;
2390 node_def.set_name(g->NewName(call_node->name()));
2391 node_def.set_op(new_func_name);
2392 *node_def.mutable_attr() = call_node->def().attr();
2393 TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2394 }
2395
2396 *rewritten |= func_rewritten;
2397 rewritten_call_nodes.push_back(call_node);
2398 } else {
2399 TF_RET_CHECK(call_node->type_string() ==
2400 FunctionLibraryDefinition::kGradientOp);
2401 FunctionLibraryRuntime::Handle handle;
2402 TF_RETURN_IF_ERROR(flr->Instantiate(call_node->type_string(),
2403 call_node->attrs(), &handle));
2404 auto cleanup_handle = gtl::MakeCleanup(
2405 [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2406 bool func_rewritten = false;
2407 string new_func_name = fld->UniqueFunctionName(
2408 absl::StrCat(call_node->name(), "_lift_args"));
2409 const FunctionBody* function_fbody = flr->GetFunctionBody(handle);
2410 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2411 *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2412 &func_rewritten));
2413 if (func_rewritten) {
2414 NodeDef node_def;
2415 node_def.set_name(g->NewName(call_node->name()));
2416 node_def.set_op(new_func_name);
2417 *node_def.mutable_attr() = call_node->def().attr();
2418 node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2419 TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2420 }
2421
2422 *rewritten |= func_rewritten;
2423 rewritten_call_nodes.push_back(call_node);
2424 }
2425 }
2426
2427 for (Node* n : while_nodes) {
2428 bool node_rewritten = false;
2429 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromWhileNode(
2430 g, n, fld, lifted_arg_count, &node_rewritten));
2431 *rewritten = *rewritten || node_rewritten;
2432 }
2433
2434 for (Node* n : if_nodes) {
2435 bool node_rewritten = false;
2436 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromIfNode(
2437 g, n, fld, lifted_arg_count, &node_rewritten));
2438 *rewritten = *rewritten || node_rewritten;
2439 }
2440
2441 for (Node* n : rewritten_call_nodes) {
2442 bool node_rewritten = false;
2443 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromCallNode(
2444 g, n, flr, fld, lifted_arg_count, &node_rewritten));
2445 *rewritten = *rewritten || node_rewritten;
2446 }
2447
2448 if (*rewritten) {
2449 VLOG(4) << DumpGraphToFile("after_lifting_args", *g, fld);
2450 }
2451
2452 return OkStatus();
2453 }
2454
2455 } // namespace
2456
2457 /*static*/ Status EncapsulateTPUComputationsPass::Encapsulate(
2458 std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
2459 // Check for undeclared outputs before Encapsulation, so we can give a better
2460 // error message.
2461 // TODO(phawkins): merge this with the encapsulation code to avoid the extra
2462 // O(n) pass over the edges.
2463 for (const Edge* e : (*graph)->edges()) {
2464 if (!e->IsControlEdge() &&
2465 e->src()->attrs().Find(kTPUReplicateAttr) != nullptr &&
2466 e->src()->attrs().Find(kOutsideCompilationAttr) == nullptr &&
2467 e->dst()->attrs().Find(kTPUReplicateAttr) == nullptr &&
2468 e->dst()->type_string() != kTPUReplicatedOutput) {
2469 return errors::InvalidArgument(
2470 "Undeclared output of TPU computation. A common cause of this error "
2471 "is variable initializers that depend on the TPU computation. Edge: ",
2472 FormatNodeForError(*e->src()), ":", e->src_output(), " -> ",
2473 FormatNodeForError(*e->dst()), ":", e->dst_input());
2474 }
2475 }
2476
2477 RemoveUnusedTPUReplicatedInputs(graph->get());
2478
2479 TF_RETURN_IF_ERROR(RenameClustersWithDuplicatedNames(graph->get()));
2480
2481 TF_RETURN_IF_ERROR(
2482 PerformStaticShapeInferenceBeforeEncapsulation(graph->get()));
2483
2484 auto output = absl::make_unique<Graph>((*graph)->op_registry());
2485 TF_RETURN_WITH_CONTEXT_IF_ERROR(
2486 EncapsulateSubgraphsInFunctions(
2487 kTPUReplicateAttr, **graph, RewriteSubgraph,
2488 /*reuse_existing_functions=*/true, &output, flib_def),
2489 "EncapsulateTPUComputationsPass failed");
2490 graph->swap(output);
2491
2492 return OkStatus();
2493 }
2494
2495 /*static*/ Status EncapsulateTPUComputationsPass::BuildTPUReplicateOps(
2496 Graph* graph) {
2497 // Finds all of the replicate function calls, to avoid mutating the graph
2498 // while iterating.
2499 std::vector<Node*> replicate_nodes;
2500 std::vector<Node*> guarantee_const_nodes;
2501 for (Node* n : graph->nodes()) {
2502 string name;
2503 if (TryGetNodeAttr(n->attrs(), kTPUReplicateAttr, &name) &&
2504 !TryGetNodeAttr(n->attrs(), kOutsideCompilationAttr, &name)) {
2505 replicate_nodes.push_back(n);
2506 } else if (n->type_string() == "GuaranteeConst") {
2507 guarantee_const_nodes.push_back(n);
2508 }
2509 }
2510
2511 // Replace any GuaranteeConst nodes with Identity nodes. These nodes have now
2512 // served their purpose and have no runtime effect, except increasing
2513 // inference latency due to executor overhead. Subsequent rewrites will remove
2514 // the Identity nodes.
2515 for (Node* n : guarantee_const_nodes) {
2516 std::vector<std::pair<Node*, int>> predecessors;
2517 for (const Edge* e : n->in_edges()) {
2518 predecessors.emplace_back(e->src(), e->src_output());
2519 }
2520 std::vector<std::pair<Node*, int>> successors;
2521 for (const Edge* e : n->out_edges()) {
2522 successors.emplace_back(e->dst(), e->dst_input());
2523 }
2524 NodeDef ndef;
2525 ndef.set_name(n->name());
2526 ndef.set_op("Identity");
2527 ndef.set_device(n->requested_device());
2528 MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
2529 AddNodeAttr("T", n->output_type(0), &ndef);
2530
2531 graph->RemoveNode(n);
2532 TF_ASSIGN_OR_RETURN(Node * id_node, graph->AddNode(ndef));
2533
2534 for (const auto& pred : predecessors) {
2535 if (pred.second < 0) {
2536 graph->AddControlEdge(pred.first, id_node);
2537 } else {
2538 graph->AddEdge(pred.first, pred.second, id_node, 0);
2539 }
2540 }
2541 for (const auto& succ : successors) {
2542 if (succ.second < 0) {
2543 graph->AddControlEdge(id_node, succ.first);
2544 } else {
2545 graph->AddEdge(id_node, 0, succ.first, succ.second);
2546 }
2547 }
2548 }
2549
2550 // Replaces each replicate function call together with its neighboring
2551 // TPUReplicatedInput/TPUReplicatedOutput nodes with a TPUReplicate node.
2552 for (Node* replicate : replicate_nodes) {
2553 int num_replicas;
2554 TF_RETURN_IF_ERROR(
2555 GetNodeAttr(replicate->attrs(), "num_replicas", &num_replicas));
2556 int variable_start_index;
2557 TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "_variable_start_index",
2558 &variable_start_index));
2559 int guaranteed_const_start_index;
2560 TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(),
2561 "_guaranteed_const_start_index",
2562 &guaranteed_const_start_index));
2563
2564 if (HasNodeAttr(replicate->def(), "use_tpu")) {
2565 bool use_tpu;
2566 TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "use_tpu", &use_tpu));
2567 if (!use_tpu) {
2568 LOG(WARNING) << "use_tpu=false attr on a TPUReplicate node is ignored.";
2569 }
2570 }
2571
2572 std::vector<const Edge*> in_edges;
2573 TF_RETURN_IF_ERROR(replicate->input_edges(&in_edges));
2574
2575 // Counts the number of replicated, non-replicated, and variable inputs.
2576 int pos = 0;
2577 std::vector<int> mirrored_variable_indices;
2578 int distributed_var_start_index = 0;
2579 while (pos < in_edges.size() &&
2580 in_edges[pos]->src()->type_string() == kTPUReplicatedInput) {
2581 // Checks that each TPUReplicatedInput node has the correct number of
2582 // replicas.
2583 int input_num_replicas;
2584 TF_RETURN_IF_ERROR(
2585 GetNodeAttr(in_edges[pos]->src()->attrs(), "N", &input_num_replicas));
2586
2587 bool is_mirrored_variable;
2588 CHECK(GetNodeAttr(in_edges[pos]->src()->attrs(), "is_mirrored_variable",
2589 &is_mirrored_variable)
2590 .ok());
2591 if (is_mirrored_variable) {
2592 mirrored_variable_indices.push_back(pos);
2593 }
2594
2595 bool is_packed = false;
2596 GetNodeAttr(in_edges[pos]->src()->attrs(), "is_packed", &is_packed)
2597 .IgnoreError();
2598
2599 bool is_distributed_variable =
2600 is_packed && (in_edges[pos]->src()->output_type(
2601 in_edges[pos]->src_output()) == DT_RESOURCE);
2602
2603 if (!is_distributed_variable && input_num_replicas != num_replicas) {
2604 return errors::InvalidArgument(
2605 "Mismatched number of replicas. Computation has ", num_replicas,
2606 " replicas, input '", FormatNodeForError(*in_edges[pos]->src()),
2607 "' has ", input_num_replicas, " replicas.");
2608 }
2609
2610 if (!is_distributed_variable) {
2611 if (distributed_var_start_index < pos) {
2612 return errors::InvalidArgument(
2613 "Expect a distributed resource after index ",
2614 distributed_var_start_index,
2615 ", but got a replicated resource at index ", pos);
2616 } else {
2617 ++distributed_var_start_index;
2618 }
2619 }
2620 ++pos;
2621 }
2622 const int num_replicated_inputs = distributed_var_start_index;
2623 const int num_distributed_vars = pos - num_replicated_inputs;
2624
2625 const int num_variables =
2626 std::max(0, guaranteed_const_start_index - variable_start_index);
2627
2628 const int num_guaranteed_constants =
2629 in_edges.size() - guaranteed_const_start_index;
2630 TF_RET_CHECK(num_guaranteed_constants >= 0);
2631
2632 VLOG(1) << "Replicate node '" << replicate->name() << "'"
2633 << " input edges: " << in_edges.size()
2634 << " num_replicated_inputs: " << num_replicated_inputs
2635 << " num_distributed_vars: " << num_distributed_vars
2636 << " num_variables: " << num_variables
2637 << " num_guaranteed_constants: " << num_guaranteed_constants
2638 << " num_mirrored_variables: " << mirrored_variable_indices.size();
2639
2640 const int num_broadcast_inputs =
2641 in_edges.size() - (num_replicated_inputs + num_distributed_vars +
2642 num_variables + num_guaranteed_constants);
2643 TF_RET_CHECK(num_broadcast_inputs >= 0);
2644
2645 const int num_inputs = num_replicated_inputs * num_replicas +
2646 num_distributed_vars + num_broadcast_inputs +
2647 num_guaranteed_constants + num_variables;
2648
2649 std::vector<Node*> nodes_to_remove = {replicate};
2650
2651 // Data and control inputs to the new TPUReplicate node.
2652 std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
2653 gtl::FlatSet<Node*> control_inputs;
2654
2655 AddControlInputs(*replicate, &control_inputs);
2656
2657 // Replicated inputs. Adds the inputs from the TPUReplicatedInput inputs,
2658 // in replica-major order. See the comments in
2659 // distributed_tpu_rewrite_pass.h for a description of the argument order.
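    // For instance (illustrative): with num_replicas = 2 and
    // num_replicated_inputs = 3, replica r's copy of input i lands at
    // position r * 3 + i, i.e. [r0i0, r0i1, r0i2, r1i0, r1i1, r1i2].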
2660 DataTypeVector replicated_input_types(num_replicated_inputs * num_replicas +
2661 num_distributed_vars);
2662
2663 // Inputs with is_distributed_variable = false.
2664 for (int i = 0; i < num_replicated_inputs; ++i) {
2665 std::vector<const Edge*> replica_in_edges;
2666 TF_RETURN_IF_ERROR(in_edges[i]->src()->input_edges(&replica_in_edges));
2667 for (int replica = 0; replica < num_replicas; ++replica) {
2668 int pos = replica * num_replicated_inputs + i;
2669 const Edge* edge = replica_in_edges[replica];
2670 data_inputs[pos] = {edge->src(), edge->src_output()};
2671 replicated_input_types[pos] = EdgeType(edge);
2672 }
2673 AddControlInputs(*in_edges[i]->src(), &control_inputs);
2674 nodes_to_remove.push_back(in_edges[i]->src());
2675 }
2676
2677 // Inputs with is_distributed_variable = true.
2678 for (int i = 0; i < num_distributed_vars; ++i) {
2679 int pos = num_replicas * num_replicated_inputs + i;
2680 std::vector<const Edge*> replica_in_edges;
2681 TF_RETURN_IF_ERROR(
2682 in_edges[num_replicated_inputs + i]->src()->input_edges(
2683 &replica_in_edges));
2684 TF_RET_CHECK(replica_in_edges.size() == 1);
2685 const Edge* edge = replica_in_edges[0];
2686 data_inputs[pos] = {edge->src(), edge->src_output()};
2687 replicated_input_types[pos] = EdgeType(edge);
2688 AddControlInputs(*in_edges[num_replicated_inputs + i]->src(),
2689 &control_inputs);
2690 nodes_to_remove.push_back(in_edges[num_replicated_inputs + i]->src());
2691 }
2692
2693 // Appends the broadcast inputs.
2694 DataTypeVector broadcast_input_types(num_broadcast_inputs);
2695 for (int i = 0; i < num_broadcast_inputs; ++i) {
2696 int pos = num_replicas * num_replicated_inputs + num_distributed_vars + i;
2697 const Edge* edge =
2698 in_edges[num_replicated_inputs + num_distributed_vars + i];
2699 data_inputs[pos] = {edge->src(), edge->src_output()};
2700 broadcast_input_types[i] = EdgeType(edge);
2701 }
2702
2703 // Appends the variable inputs.
2704 for (int i = 0; i < num_variables; ++i) {
2705 int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
2706 num_broadcast_inputs + i;
2707 const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
2708 num_broadcast_inputs + i];
2709 data_inputs[pos] = {edge->src(), edge->src_output()};
2710 }
2711
2712 DataTypeVector guaranteed_constant_types(num_guaranteed_constants);
2713 for (int i = 0; i < num_guaranteed_constants; ++i) {
2714 int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
2715 num_broadcast_inputs + num_variables + i;
2716 const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
2717 num_broadcast_inputs + num_variables + i];
2718 data_inputs[pos] = {edge->src(), edge->src_output()};
2719 guaranteed_constant_types[i] = EdgeType(edge);
2720 }
2721
2722 // Outputs. All outputs from a replicated computation are replicated.
2723 const int num_outputs = replicate->output_types().size();
2724 gtl::FlatSet<Node*> control_outputs;
2725 std::vector<Node*> replicated_outputs(num_outputs);
2726 for (const Edge* e : replicate->out_edges()) {
2727 if (e->IsControlEdge()) {
2728 control_outputs.insert(e->dst());
2729 } else {
2730 TF_RET_CHECK(e->src_output() < num_outputs);
2731 TF_RET_CHECK(e->dst()->type_string() == kTPUReplicatedOutput)
2732 << e->DebugString();
2733 TF_RET_CHECK(e->dst()->output_types().size() == num_replicas);
2734 replicated_outputs[e->src_output()] = e->dst();
2735 nodes_to_remove.push_back(e->dst());
2736
2737 AddControlOutputs(*e->dst(), &control_outputs);
2738 }
2739 }
2740
2741 // Flattens the edges outgoing from the TPUReplicatedOutput nodes in
2742 // replica-major order.
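    // Analogously to the inputs above, replica r's copy of output i is
    // placed at position r * num_outputs + i.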
2743 std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_replicas *
2744 num_outputs);
2745 DataTypeVector output_types(num_replicas * num_outputs);
2746 for (int i = 0; i < num_outputs; ++i) {
2747 std::vector<std::vector<const Edge*>> replica_out_edges(num_replicas);
2748 TF_RET_CHECK(replicated_outputs[i] != nullptr);
2749 for (const Edge* e : replicated_outputs[i]->out_edges()) {
2750 TF_RET_CHECK(!e->IsControlEdge());
2751 replica_out_edges[e->src_output()].push_back(e);
2752 }
2753
2754 for (int replica = 0; replica < num_replicas; ++replica) {
2755 const int pos = replica * num_outputs + i;
2756 for (const Edge* edge : replica_out_edges[replica]) {
2757 data_outputs[pos].push_back({edge->dst(), edge->dst_input()});
2758 }
2759 output_types[pos] = replicated_outputs[i]->input_type(0);
2760 }
2761 }
2762
2763 // TODO(b/79092708): Consolidate and cleanup to avoid TPU specialization.
2764 NodeDef def;
2765 def.set_name(replicate->name());
2766 def.set_op("_TPUReplicate");
2767 MergeDebugInfo(NodeDebugInfo(replicate->def()), &def);
2768 NameAttrList computation;
2769 computation.set_name(replicate->type_string());
2770 AddNodeAttr("computation", computation, &def);
2771 for (const auto& attr : replicate->attrs()) {
2772 def.mutable_attr()->insert(attr);
2773 }
2774 AddNodeAttr("Tinputs", replicated_input_types, &def);
2775 AddNodeAttr("Tbroadcast_inputs", broadcast_input_types, &def);
2776 AddNodeAttr("NumVariables", num_variables, &def);
2777 AddNodeAttr("Tguaranteed_constants", guaranteed_constant_types, &def);
2778 AddNodeAttr("output_types", output_types, &def);
2779 AddNodeAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
2780 mirrored_variable_indices, &def);
2781 AddNodeAttr("num_distributed_variables", num_distributed_vars, &def);
2782
2783 for (Node* node : nodes_to_remove) {
2784 VLOG(2) << "Deleting node " << node->DebugString();
2785 // Ensure that we do not attempt to add control edges to nodes that are
2786 // deleted.
2787 control_inputs.erase(node);
2788 control_outputs.erase(node);
2789 graph->RemoveNode(node);
2790 }
2791
2792 TF_ASSIGN_OR_RETURN(Node * tpu_replicate, graph->AddNode(def));
2793 for (int i = 0; i < data_inputs.size(); ++i) {
2794 graph->AddEdge(data_inputs[i].first, data_inputs[i].second, tpu_replicate,
2795 i);
2796 }
2797 for (Node* n : control_inputs) {
2798 graph->AddControlEdge(n, tpu_replicate);
2799 }
2800 for (int i = 0; i < data_outputs.size(); ++i) {
2801 for (const auto& successor : data_outputs[i]) {
2802 graph->AddEdge(tpu_replicate, i, successor.first, successor.second);
2803 }
2804 }
2805 for (Node* n : control_outputs) {
2806 graph->AddControlEdge(tpu_replicate, n);
2807 }
2808 }
2809 return OkStatus();
2810 }
2811
2812 Status EncapsulateTPUComputationsPass::Run(
2813 const GraphOptimizationPassOptions& options) {
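  // The intermediate graphs dumped below are only written when graph dumping
  // is enabled (e.g. via the TF_DUMP_GRAPH_PREFIX environment variable).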
  VLOG(1) << "EncapsulateTPUComputations(): "
          << DumpGraphToFile("encapsulate_tpu_computations_before",
                             **options.graph, options.flib_def);

  TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
  VLOG(1) << "EncapsulateTPUComputations() half-way: "
          << DumpGraphToFile("encapsulate_tpu_computations_halfway",
                             **options.graph, options.flib_def);

  TF_RETURN_IF_ERROR(BuildTPUReplicateOps(options.graph->get()));
  VLOG(1) << "EncapsulateTPUComputations() finished: "
          << DumpGraphToFile("encapsulate_tpu_computations_after",
                             **options.graph, options.flib_def);
  return OkStatus();
}

Status ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
    const string& outside_compilation_attr_name, int* lifted_arg_count,
    std::unordered_map<string, XlaClusterInfo>* clusters, Graph* g,
    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
  // Gather a list of pivots by cluster so we can easily look them up.
  absl::node_hash_map<string, Node*> pivots;
  string cluster_name;
  for (Node* node : g->nodes()) {
    if (TryGetNodeAttr(node->attrs(), kPivotForClusterAttr, &cluster_name)) {
      pivots[cluster_name] = node;
    }
  }
  for (auto& iter : *clusters) {
    // Find pivot node for this XLA cluster.
    Node* pivot_node = pivots[iter.first];

    // Instantiate XLA computation function.
    string xla_func_name = iter.second.func_name_attrs.name();
    std::unique_ptr<FunctionBody> xla_fbody;
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
        *fld->Find(xla_func_name),
        AttrSlice(&iter.second.func_name_attrs.attr()), fld, &xla_fbody));
    Graph* xla_graph = xla_fbody->graph;

    // Make sure all nodes can be traced from sink node.
    FixupSourceAndSinkEdges(xla_graph);

    // We create Identity nodes for all _Arg/_Retval nodes in the XLA
    // computation. Remove those Identity nodes to simplify further processing.
    TF_RETURN_IF_ERROR(RemoveIdentityNodesForArgRetval(xla_graph));

    bool rewritten;
    TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
        xla_graph, flr, fld, lifted_arg_count, &rewritten));

    // Move head outside compilation to host.
    TF_RETURN_IF_ERROR(MoveHeadOutsideCompilationToHost(
        outside_compilation_attr_name, iter.second.func_name_attrs.name(),
        iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));

    // Move tail outside compilation to host.
    TF_RETURN_IF_ERROR(MoveTailOutsideCompilationToHost(
        outside_compilation_attr_name, iter.second.func_name_attrs.name(),
        iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));

    // Replace outside-compilation-only _Arg nodes with Placeholder nodes.
    TF_RETURN_IF_ERROR(ReplaceArgUsedByOutsideCompilationWithPlaceholder(
        outside_compilation_attr_name, xla_func_name, g, xla_graph,
        iter.second.node));

    // There might be direct data edges between an _Arg node and a _Retval
    // node in `xla_graph`. Remove those edges to avoid back-and-forth data
    // transfer between host and XLA.
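    // (For example, an input that the computation returns unchanged would
    // otherwise be copied to the device only to be echoed back to the host.)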
    TF_RETURN_IF_ERROR(RemoveEdgesBetweenArgAndRetval(
        iter.second.func_name_attrs.name(), g, xla_graph, iter.second.node));

    // After `MoveHeadOutsideCompilationToHost`, there might be unused XLA
    // inputs. Remove them.
    TF_RETURN_IF_ERROR(RemoveUnusedXlaInput(iter.second.func_name_attrs.name(),
                                            g, xla_graph, iter.second.node));

    // After `MoveTailOutsideCompilationToHost`, there might be unused XLA
    // outputs. Remove them.
    TF_RETURN_IF_ERROR(RemoveUnusedXlaOutput(iter.second.func_name_attrs.name(),
                                             g, xla_graph, iter.second.node));

    // Replace original function.
    FunctionDef replace_fdef;
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*xla_graph, xla_func_name, &replace_fdef));
    TF_RETURN_IF_ERROR(fld->ReplaceFunction(xla_func_name, replace_fdef));

    FixupSourceAndSinkEdges(g);
  }

  return OkStatus();
}

Status ExtractOutsideCompilationPass::Run(
    const GraphOptimizationPassOptions& options) {
  const auto* config =
      (options.session_options ? &options.session_options->config : nullptr);
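  // Note that, unlike `config`, `env` is read from `session_options` below
  // without a null check, so this pass assumes `session_options` is set.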
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(
          /*device_mgr=*/nullptr, options.session_options->env,
          /*config=*/config, TF_GRAPH_DEF_VERSION, options.flib_def,
          config ? config->graph_options().optimizer_options()
                 : OptimizerOptions()));
  FunctionLibraryRuntime* flr =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  // Find XLA compile ops and their corresponding FunctionDefs.
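  // Each entry maps a compile-op type to the node attribute that names its
  // computation function; only _TPUReplicate is handled by this pass.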
  static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
      new std::map<string, string>{
          {"_TPUReplicate", "computation"},
      };
  std::unordered_map<string, XlaClusterInfo> clusters;
  int lifted_arg_count = 0;
  for (Node* n : (*options.graph)->nodes()) {
    auto iter = kNodeTypeToFunctionAttrMapping->find(n->type_string());
    if (iter == kNodeTypeToFunctionAttrMapping->end()) {
      continue;
    }

    string xla_cluster_name = n->name();

    string func_attr = iter->second;
    NameAttrList func;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));

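    // Each "host_compute_core" entry is expected to have the form
    // "<outside_compilation_cluster>:<core_number>" (e.g. "cluster1:0"),
    // assigning that cluster's host compute ops to a TPU core; see
    // ParseHostComputeCoreList.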
    std::vector<string> core_list;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(n->attrs(), "host_compute_core", &core_list));
    std::map<string, int> host_compute_core;
    TF_RETURN_IF_ERROR(ParseHostComputeCoreList(core_list, &host_compute_core));

    clusters.emplace(xla_cluster_name, XlaClusterInfo{xla_cluster_name, func, n,
                                                      host_compute_core});
  }
  TF_RETURN_IF_ERROR(ProcessHeadTailOutsideCompilation(
      kOutsideCompilationAttr, &lifted_arg_count, &clusters,
      options.graph->get(), flr, options.flib_def));
  bool modified;
  TF_RETURN_IF_ERROR(ExtractOutsideCompilation(
      kTPUReplicateAttr, kOutsideCompilationAttr, clusters,
      options.graph->get(), flr, options.flib_def, &modified));
  if (modified) {
    TF_RETURN_IF_ERROR(
        PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
  }

  return OkStatus();
}

}  // namespace tensorflow