/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Compilation for distributed TPU (TPU_REPLICATED_CORE devices).

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

#include <algorithm>
#include <queue>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/escaping.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/common_runtime/device_propagation.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_while_op.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/error_payloads.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

namespace {

// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology.
constexpr int kTPUTopologyRank = 4;
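// Illustrative example: on a 2x2x1 topology with two cores per chip, the
// coordinate (1, 0, 0, 1) denotes core 1 of the chip at position x=1, y=0,
// z=0.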

// An upper bound on how many cores may be present in the topology.
static constexpr int kTPUMaxTopologySize = 4096;

// Attribute containing the serialized xla::OpSharding to be passed to the
// corresponding XLA HLO operation, which represents how a shape is distributed
// across logical cores, e.g., replication, single-device, or partitioning.
const char kShardingAttribute[] = "_XlaSharding";

const char kTPUPartitionedInput[] = "TPUPartitionedInput";
const char kTPUPartitionedOutput[] = "TPUPartitionedOutput";

const char kVarHandleOp[] = "VarHandleOp";

static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status";
static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite";

using NodeAndId = std::pair<const Node*, int>;

struct NodeAndPort {
  explicit NodeAndPort(Node* node, int port) : node(node), port(port) {}

  Node* node;
  // Port of the node, e.g. this can be the `src_output` index of an Edge.
  int port;
};

class IntrusiveHeapLink {
 public:
  using size_type = size_t;
  static constexpr size_type kNotMember = -1;

  IntrusiveHeapLink() = default;

  // Only IntrusiveHeap and LinkAccess objects should make these objects.
  explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {}

  // Only IntrusiveHeap and LinkAccess should get the value.
  size_type get() const { return pos_; }

 private:
  size_type pos_{kNotMember};
};

template <typename T, IntrusiveHeapLink T::*M>
struct IntrusiveHeapDataMemberLinkAccess {
  IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
  void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
};

template <typename T>
struct DefaultIntrusiveHeapLinkAccess {
  IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
  void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
};

template <typename T, typename PtrCompare,
          typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
          typename Alloc = std::allocator<T*>>
class IntrusiveHeap {
 public:
  typedef typename IntrusiveHeapLink::size_type size_type;
  typedef T value_type;
  typedef T* pointer;
  typedef const T* const_pointer;
  typedef PtrCompare pointer_compare_type;
  typedef LinkAccess link_access_type;
  typedef Alloc allocator_type;

  explicit IntrusiveHeap(
      const pointer_compare_type& comp = pointer_compare_type(),
      const link_access_type& link_access = link_access_type(),
      const allocator_type& alloc = allocator_type())
      : rep_(comp, link_access, alloc) {}

  size_type size() const { return heap().size(); }

  bool empty() const { return heap().empty(); }

  // Return the top element, but don't remove it.
  pointer top() const {
    DCHECK(!empty());
    return heap()[0];
  }

  // Remove the top() pointer from the heap and return it.
  pointer Pop() {
    pointer t = top();
    Remove(t);
    return t;
  }

  // Insert 't' into the heap.
  void Push(pointer t) {
    SetPositionOf(t, heap().size());
    heap().push_back(t);
    FixHeapUp(t);
  }

  // Adjust the heap to accommodate changes in '*t'.
  void Adjust(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
      FixHeapUp(t);
    } else {
      FixHeapDown(t);
    }
  }

  // Remove the specified pointer from the heap.
  void Remove(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    SetPositionOf(t, IntrusiveHeapLink::kNotMember);
    if (h == heap().size() - 1) {
      // Fast path for removing from back of heap.
      heap().pop_back();
      return;
    }
    // Move the element from the back of the heap to overwrite 't'.
    pointer& elem = heap()[h];
    elem = heap().back();
    SetPositionOf(elem, h);  // Element has moved, so update its link.
    heap().pop_back();
    Adjust(elem);  // Restore the heap invariant.
  }

  void Clear() { heap().clear(); }

  bool Contains(const_pointer t) const {
    size_type h = GetPositionOf(t);
    return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
           heap()[h] == t;
  }

  void reserve(size_type n) { heap().reserve(n); }

  size_type capacity() const { return heap().capacity(); }

  allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }

 private:
  typedef std::vector<pointer, allocator_type> heap_type;

  // Empty base class optimization for pointer_compare and link_access.
  // The heap_ data member retains a copy of the allocator, so it is not
  // stored explicitly.
  struct Rep : pointer_compare_type, link_access_type {
    explicit Rep(const pointer_compare_type& cmp,
                 const link_access_type& link_access,
                 const allocator_type& alloc)
        : pointer_compare_type(cmp),
          link_access_type(link_access),
          heap_(alloc) {}
    heap_type heap_;  // NOLINT
  };

  const pointer_compare_type& compare() const { return rep_; }

  const link_access_type& link_access() const { return rep_; }

  const heap_type& heap() const { return rep_.heap_; }
  heap_type& heap() { return rep_.heap_; }

  size_type GetPositionOf(const_pointer t) const {
    return link_access().Get(t).get();
  }

  void SetPositionOf(pointer t, size_type pos) const {
    return link_access().Set(t, IntrusiveHeapLink(pos));
  }

  void FixHeapUp(pointer t) {
    size_type h = GetPositionOf(t);
    while (h != 0) {
      size_type parent = (h - 1) >> 1;
      if (compare()(heap()[parent], t)) {
        break;
      }
      heap()[h] = heap()[parent];
      SetPositionOf(heap()[h], h);
      h = parent;
    }
    heap()[h] = t;
    SetPositionOf(t, h);
  }

  void FixHeapDown(pointer t) {
    size_type h = GetPositionOf(t);
    for (;;) {
      size_type kid = (h << 1) + 1;
      if (kid >= heap().size()) {
        break;
      }
      if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
        ++kid;
      }
      if (compare()(t, heap()[kid])) {
        break;
      }
      heap()[h] = heap()[kid];
      SetPositionOf(heap()[h], h);
      h = kid;
    }

    heap()[h] = t;
    SetPositionOf(t, h);
  }

  Rep rep_;
};
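
// Illustrative usage sketch (not part of the pass; `Item` is a hypothetical
// element type): an element participates in the heap by embedding an
// IntrusiveHeapLink member named `heap`, which DefaultIntrusiveHeapLinkAccess
// uses to record the element's current position. This mirrors the DeviceNode
// usage in TensorDevicePlacer below.
//
//   struct Item {
//     struct Compare {
//       // Min-heap on `priority`.
//       bool operator()(const Item* lhs, const Item* rhs) const {
//         return lhs->priority < rhs->priority;
//       }
//     };
//     IntrusiveHeapLink heap;
//     int priority = 0;
//   };
//
//   IntrusiveHeap<Item, Item::Compare> heap;
//   Item a, b;
//   a.priority = 2;
//   b.priority = 1;
//   heap.Push(&a);
//   heap.Push(&b);
//   Item* top = heap.top();  // &b, the item with the smallest priority.
//   b.priority = 3;
//   heap.Adjust(&b);         // Restore the heap invariant after the change.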

string CoreDeviceLabel(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}
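
// For example, assuming DEVICE_TPU_REPLICATED_CORE is "TPU_REPLICATED_CORE",
// CoreDeviceLabel(1) returns "/device:TPU_REPLICATED_CORE:1".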

// Creates a unique node name with a particular prefix.
string UniqueNodeName(const StringPiece prefix, Graph* graph) {
  return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
}

Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
                                        const string& target_device_type,
                                        Node* node) {
  TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
  TF_RET_CHECK(device.has_id);
  TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));

  // Store the device instance as an attr on the Node.
  TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));

  // Place the execute Op on the TPU_SYSTEM device so it can access the cache of
  // compiled protos in the resource manager.
  device.type = target_device_type;
  device.id = 0;

  node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
  return OkStatus();
}

// Iterates over the nodes in the original graph and finds all the TPUReplicate
// nodes, and all the nodes that are part of outside_compilation clusters.
Status FindTaggedNodes(
    Graph* graph, std::vector<Node*>* replicate_nodes,
    std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
        outside_compilation_nodes,
    std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "_TPUReplicate") {
      replicate_nodes->push_back(node);
      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
      if (cluster_attr == nullptr) {
        return errors::Internal("TPUReplicate node ", node->name(), " has no ",
                                kTPUReplicateAttr, " attr.");
      } else {
        const string& cluster = cluster_attr->s();
        if (cluster.empty()) {
          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
                                  node->name(), " has no string value.");
        }
        if (outside_compilation_nodes->find(cluster) !=
            outside_compilation_nodes->end()) {
          return errors::Internal(
              "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
              " attr value '", cluster,
              "' which is a duplicate of another TPUReplicate node in the "
              "graph.");
        }
        (*outside_compilation_nodes)[cluster] =
            DistributedTPURewritePass::OutsideCompilationNodeMap();
        (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
      }
    }
  }
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() != "_TPUReplicate") {
      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
      const AttrValue* outside_compilation_attr =
          node->attrs().Find(kOutsideCompilationAttr);
      if (cluster_attr == nullptr) {
        if (outside_compilation_attr != nullptr) {
          return errors::Internal("Node ", node->name(), " has ",
                                  kOutsideCompilationAttr, " attr but no ",
                                  kTPUReplicateAttr, " attr.");
        }
      } else {
        const string& cluster = cluster_attr->s();
        if (cluster.empty()) {
          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
                                  node->name(), " has no string value.");
        }
        const auto iter = outside_compilation_nodes->find(cluster);
        if (iter == outside_compilation_nodes->end()) {
          return errors::Internal(
              "Attr ", kTPUReplicateAttr, " on node ", node->name(),
              " does not correspond to a TPUReplicate node.");
        }
        if (outside_compilation_attr == nullptr) {
          return errors::Internal("Node ", node->name(), " has ",
                                  kTPUReplicateAttr, " attr but no ",
                                  kOutsideCompilationAttr, " attr.");
        }
        const string& oc_cluster = outside_compilation_attr->s();
        if (oc_cluster.empty()) {
          return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
                                  node->name(), " has no string value.");
        }

        // Outside compilation clusters at the head and tail of the TPU
        // computation have already been moved to the host and are already
        // replicated. As such, do not replicate outside compilation nodes
        // that carry a replica id attribute.
        int replica_id;
        if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
          const AttrValue* head_attr =
              node->attrs().Find("_xla_only_arg_or_oc_input");
          const AttrValue* tail_attr =
              node->attrs().Find("_xla_only_ret_or_oc_output");
          if (((head_attr != nullptr) && (head_attr->b())) ||
              ((tail_attr != nullptr) && (tail_attr->b()))) {
            // This is safe because head_tail_outside_compilation_nodes has the
            // same keys as outside_compilation_nodes, which we already know
            // contains this key.
            (*head_tail_outside_compilation_nodes)[cluster].push_back(node);
          }
          continue;
        }
        iter->second[oc_cluster].push_back(node);
      }
    }
  }
  return OkStatus();
}

// Helper class to spread TPU computation arguments and return values
// across cores.
// If all shapes are fully defined, balance by their size.
// If some of them are not fully defined, the sizes of the undefined shapes
// are estimated as the average size of the fully defined ones.
// If none are defined, fall back to round-robin.
class TensorDevicePlacer {
 public:
  // Creates a TensorDevicePlacer object to distribute arguments or
  // return values to a set of num_devices devices, where the types and
  // the inferred shapes of the inputs (arguments or return values) are
  // passed in types and shapes.
  TensorDevicePlacer(int64_t num_devices, const DataTypeVector& types,
                     const std::vector<InferredShape>& shapes)
      : index_nodes_(num_devices), sizes_(types.size()) {
    int64_t total_size = 0;
    int64_t num_defined = 0;
    for (int64_t i = 0; i < types.size(); ++i) {
      sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
      if (sizes_[i] >= 0) {
        total_size += sizes_[i];
        ++num_defined;
      }
    }
    // If a shape is undefined, select a size for it which is the average
    // of the defined shapes. If no shapes are defined, assign 1 so that we
    // get round-robin behavior.
    int64_t undefined_shape_size =
        (num_defined > 0) ? total_size / num_defined : 1;
    for (int64_t i = 0; i < sizes_.size(); ++i) {
      if (sizes_[i] < 0) {
        sizes_[i] = undefined_shape_size;
      }
    }

    for (int64_t i = 0; i < num_devices; ++i) {
      heap_.Push(&index_nodes_[i]);
    }
  }

  // Reports that the argument/return-value at index has been assigned
  // by the user to a given device.
  void ReportDeviceAssigned(int64_t device, int64_t index) {
    if (device >= index_nodes_.size()) {
      LOG(FATAL) << "Sharding assignment is out of bounds. "  // Crash OK
                    "Check that the number of nodes is properly set.";
    }
    DeviceNode* node = &index_nodes_.at(device);
    node->size += sizes_.at(index);
    heap_.Adjust(node);
  }

  // Retrieves the device to which the argument/return-value at index
  // should be assigned.
  int64_t RetrieveAssignment(int64_t index) {
    DeviceNode* node = heap_.top();
    int64_t device = node - index_nodes_.data();
    node->size += sizes_.at(index);
    heap_.Adjust(node);
    return device;
  }

 private:
  struct DeviceNode {
    struct Compare {
      // Compare functor to implement a min heap using the ::gtl::IntrusiveHeap
      // infrastructure.
      bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
        return lhs->size < rhs->size;
      }
    };

    IntrusiveHeapLink heap;
    int64_t size = 0;
  };

  static int64_t GetInferredShapeSize(const InferredShape& ishape,
                                      DataType dtype) {
    return ishape.shape.IsFullyDefined()
               ? ishape.shape.num_elements() * DataTypeSize(dtype)
               : -1;
  }

  std::vector<DeviceNode> index_nodes_;
  IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
  std::vector<int64_t> sizes_;
};
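
// Illustrative example of the greedy balancing above: with 2 devices and
// argument sizes of {8, 8, 4, 4} bytes assigned in index order via
// RetrieveAssignment(), each argument goes to the currently least-loaded
// device, yielding loads of 12 and 12 (e.g., {8, 4} and {8, 4}) rather than
// 16 and 8.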

Status ValidateCoreNumber(int64_t core, int64_t num_cores_per_replica) {
  if (core < 0 || core >= num_cores_per_replica) {
    return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
                                               ". The valid core IDs are [0..",
                                               num_cores_per_replica, ")");
  }
  return OkStatus();
}

Status FindHostComputeKeyPlaceholderNodes(
    const Graph* graph, const std::vector<Node*>& replicate_nodes,
    std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
  host_compute_key_placeholder_map->clear();
  for (const auto node : replicate_nodes) {
    (*host_compute_key_placeholder_map)[node->name()] = nullptr;
  }

  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "Placeholder" &&
        str_util::EndsWith(node->name(), "_key_placeholder")) {
      const AttrValue* call_node_attr =
          node->attrs().Find("_host_compute_call_node");
      if (call_node_attr != nullptr) {
        auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
        if (iter == host_compute_key_placeholder_map->end()) {
          return errors::InvalidArgument(
              "Node ", node->name(), " has _host_compute_call_node attribute '",
              call_node_attr->s(), "' that doesn't correspond to a call node");
        }
        if (iter->second != nullptr) {
          return errors::InvalidArgument(
              "Key placeholder node ", iter->second->name(), " for call node ",
              call_node_attr->s(), " previously found as ",
              iter->second->name());
        }
        iter->second = node;
      }
    }
  }

  return OkStatus();
}

Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
  Node* old_node = *node;
  // We want to replace the node with an identity node with the same name.
  const string& node_name = old_node->name();

  // Create identity node.
  TF_ASSIGN_OR_RETURN(
      Node * id_node,
      BuildIdentityNode(graph, node_name, DT_STRING,
                        /*input=*/nullptr, /*requested_device=*/""));

  // No incoming edges are copied as a new one will be added from compile node
  // to id_node.

  // Copy outgoing edges to the id node.
  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                     old_node->out_edges().end());
  for (const Edge* edge : out_edges) {
    Node* dst = edge->dst();
    int src_output = edge->src_output();
    int dst_input = edge->dst_input();

    if (src_output == Graph::kControlSlot) {
      graph->AddControlEdge(id_node, dst);
    } else {
      graph->AddEdge(id_node, src_output, dst, dst_input);
    }
    graph->RemoveEdge(edge);
  }
  graph->RemoveNode(old_node);

  *node = id_node;
  return OkStatus();
}

Status GetStepMarkerLocation(const Node& replicate_node,
                             xla::DebugOptions::StepMarkerLocation* location) {
  string step_marker_location_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
                                 &step_marker_location_attr));
  if (step_marker_location_attr.empty()) {
    *location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
  } else {
    if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
                                                     location)) {
      return errors::InvalidArgument("Malformed step_marker_location: ",
                                     step_marker_location_attr);
    }
  }
  return OkStatus();
}

// Extracts a map of dimension and number of splits for tiled input from xla
// sharding attribute.
Status GetDimensionIndicesAndNumSplitsFromSharding(
    const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
  int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size();
  if (sharding.replicate_on_last_tile_dim()) {
    tensor_tile_rank--;
  }
  for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
    if (sharding.tile_assignment_dimensions(dim_index) > 1) {
      split_dimension_map->emplace(
          dim_index, sharding.tile_assignment_dimensions(dim_index));
    }
  }

  if (split_dimension_map->empty()) {
    return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
                                   sharding.DebugString());
  }
  return OkStatus();
}
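
// For example, a tiled sharding with tile_assignment_dimensions [2, 1, 4]
// (and replicate_on_last_tile_dim unset) yields the map {0: 2, 2: 4}:
// dimension 0 is split 2 ways and dimension 2 is split 4 ways, while
// dimension 1 is unsplit and therefore omitted.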

// Updates contents of the function with `function_name` in function library
// definition `flib_def` to `new_graph`. This is required when graph
// transformation happens inside a function call body.
Status UpdateFunctionLibDefinition(const Graph& new_graph,
                                   const std::string& function_name,
                                   FunctionLibraryDefinition* flib_def) {
  FunctionDef graph_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
  TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
  return OkStatus();
}

struct NodeOut {
  Node* node;
  int index;
};

struct ShardedInputIndex {
  int replica_id;
  int argument_index;

  bool operator<(const ShardedInputIndex& rhs) const {
    return std::tie(replica_id, argument_index) <
           std::tie(rhs.replica_id, rhs.argument_index);
  }
};

struct ShardedPerHostInputIndex {
  string host_device;
  int argument_index;
  bool operator<(const ShardedPerHostInputIndex& rhs) const {
    return std::tie(host_device, argument_index) <
           std::tie(rhs.host_device, rhs.argument_index);
  }
  bool operator==(const ShardedPerHostInputIndex& rhs) const {
    return (argument_index == rhs.argument_index) &&
           (host_device == rhs.host_device);
  }
};

struct ShardedInputInfo {
  // Split node that would be connected to tiled input Node.
  Node* split_node;
  // List of split nodes and output indices of the split node from which
  // sharded input will be connected to the TPUExecute node. The inputs are
  // ordered by logical core ids.
  std::vector<NodeOut> sharded_inputs;
};

// Adds a pad node after the split node to the graph for unevenly sharded
// tiled inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
                                   const int split_dim, DataType dtype,
                                   Node* control_predecessor, Node* split_node,
                                   const int split_index, Graph* graph) {
  // Add paddings node.
  Status s;
  NodeDef paddings_def;
  paddings_def.set_name(
      graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
  paddings_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &paddings_def);
  paddings_def.set_device(split_node->assigned_device_name());
  TensorProto sizes_tensor_proto;
  sizes_tensor_proto.set_dtype(DT_INT32);
  for (int i = 0; i < num_dims; ++i) {
    sizes_tensor_proto.add_int_val(0);
    if (i == split_dim) {
      sizes_tensor_proto.add_int_val(padding);
    } else {
      sizes_tensor_proto.add_int_val(0);
    }
  }
  TensorShape sizes_shape({num_dims, 2});
  sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
  TF_ASSIGN_OR_RETURN(Node * paddings_node, graph->AddNode(paddings_def));

  // Add Pad node.
  NodeDef pad_def;
  pad_def.set_name(graph->NewName(
      absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
  pad_def.set_op("Pad");
  pad_def.set_device(split_node->assigned_device_name());
  AddNodeAttr("T", dtype, &pad_def);
  AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
  pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
  pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
  TF_ASSIGN_OR_RETURN(Node * pad_node, graph->AddNode(pad_def));
  pad_node->set_assigned_device_name(split_node->assigned_device_name());
  // Add edges for pad node.
  graph->AddEdge(split_node, split_index, pad_node, 0);
  graph->AddEdge(paddings_node, 0, pad_node, 1);
  graph->AddControlEdge(control_predecessor, pad_node);
  return pad_node;
}

// Adds split node and split dimension node to graph for sharding tiled inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
                                     const int num_dims, const int64_t padding,
                                     const int orig_src_output, DataType dtype,
                                     absl::string_view name_prefix,
                                     Node* control_predecessor, Node* orig_src,
                                     Graph* graph) {
  const std::string input_assigned_device = orig_src->assigned_device_name();
  Node* to_split_node = orig_src;
  int to_split_index = orig_src_output;
  if (padding > 0) {
    TF_ASSIGN_OR_RETURN(
        Node * pad_node,
        CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
                      orig_src, orig_src_output, graph));
    to_split_node = pad_node;
    to_split_index = 0;
  }

  // Add a split dimension node.
  NodeDef split_dim_def;
  split_dim_def.set_name(
      graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
  split_dim_def.set_op("Const");
  split_dim_def.set_device(input_assigned_device);
  AddNodeAttr("dtype", DT_INT32, &split_dim_def);
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.add_int_val(dim);
  TensorShape shape({});
  shape.AsProto(tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", tensor_proto, &split_dim_def);
  TF_ASSIGN_OR_RETURN(Node * split_dim_node, graph->AddNode(split_dim_def));
  // Add a split node.
  NodeDef split_def;
  split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
  split_def.set_op("Split");
  split_def.set_device(input_assigned_device);
  AddNodeAttr("num_split", num_splits, &split_def);
  AddNodeAttr("T", dtype, &split_def);
  split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
  split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
  TF_ASSIGN_OR_RETURN(Node * split_node, graph->AddNode(split_def));

  split_node->set_assigned_device_name(input_assigned_device);

  // Colocate the newly created split op with the source node of the input to
  // the TPU computation.
  split_node->AddAttr(kColocationAttrName,
                      std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                       orig_src->name())});

  graph->AddEdge(split_dim_node, 0, split_node, 0);
  graph->AddEdge(to_split_node, to_split_index, split_node, 1);

  // Add a control dependency from `control_predecessor` to newly created
  // constant node. This ensures that newly added split/split dim
  // nodes are placed inside correct while loop frames when TPUExecute
  // node is inside a host training loop.
  graph->AddControlEdge(control_predecessor, split_dim_node);
  return split_node;
}

int64_t GetPadding(const int split_dim, const int num_splits,
                   const PartialTensorShape& partial_tensor_shape) {
  // If the split dimension is not defined, uneven sharding is not supported.
  if (partial_tensor_shape.dim_size(split_dim) <= 0) {
    return 0;
  }
  int64_t per_split_size = tensorflow::MathUtil::CeilOfRatio<int64_t>(
      partial_tensor_shape.dim_size(split_dim), num_splits);
  int64_t total_padding =
      per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
  return total_padding;
}
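
// Worked example: splitting a dimension of size 10 into 3 shards gives a
// per-split size of ceil(10 / 3) = 4, so GetPadding returns
// 4 * 3 - 10 = 2 elements of padding.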

// Creates a set of split nodes that shard the tiled input node in the graph.
xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
    const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, int replica_id,
    int orig_src_output, Node* orig_src, Node* control_predecessor,
    Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }
  // Maps each input dimension to the number of splits with which that
  // dimension is sharded.
  std::map<int, int> split_dimension_map;
  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
      sharding, &split_dimension_map));
  TF_RET_CHECK(!split_dimension_map.empty())
      << "Unnecessary sharding attribute found.";

  // For a v1 while loop, nodes inside the loop body must either
  //  1) have data edges from the while loop input node, or
  //  2) have a direct control dependency from the while loop input control
  //     node.
  //
  // As such, if we are adding a Split node inside a while loop body, we must
  // manually add a control dependency from a node inside the while loop
  // (i.e. `control_predecessor`) to constant nodes without data in-edges, to
  // make sure that the added split nodes have the correct frame name.
  // Otherwise, the placer will complain when `BuildControlFlow()` is invoked.

  auto sharding_it = split_dimension_map.begin();
  std::queue<Node*> split_nodes_for_dimension;
  absl::flat_hash_map<Node*, int> node_to_split_dim;
  int split_dimension = sharding_it->first;
  int num_split = sharding_it->second;

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that input data is sharded in row major order.
  // Split nodes at ith depth from the original input node represent nodes
  // that split the input data at ith dimension.
  TF_ASSIGN_OR_RETURN(
      Node * root_split_node,
      CreateSplitNode(
          num_split, split_dimension, partial_tensor_shape.dims(),
          GetPadding(split_dimension, num_split, partial_tensor_shape),
          orig_src_output, dtype,
          absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
                       split_dimension),
          control_predecessor, orig_src, graph));
  sharding_it++;

  split_nodes_for_dimension.emplace(root_split_node);
  node_to_split_dim[root_split_node] = split_dimension;

  while (sharding_it != split_dimension_map.end()) {
    split_dimension = sharding_it->first;
    num_split = sharding_it->second;
    int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
    for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
      Node* input_split_node = split_nodes_for_dimension.front();
      split_nodes_for_dimension.pop();
      for (int src_output_index = 0;
           src_output_index < input_split_node->num_outputs();
           ++src_output_index) {
        TF_ASSIGN_OR_RETURN(
            Node * split_node,
            CreateSplitNode(
                num_split, split_dimension, partial_tensor_shape.dims(),
                GetPadding(split_dimension, num_split, partial_tensor_shape),
                src_output_index, dtype,
                absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
                             split_dimension),
                control_predecessor, input_split_node, graph));
        split_nodes_for_dimension.emplace(split_node);
        node_to_split_dim[split_node] = split_dimension;
      }
    }
    sharding_it++;
  }

  // `split_nodes_for_dimension` now includes the final split nodes from which
  // sharded data will be fed into TPUExecute nodes, sorted in row major
  // order.
  std::vector<NodeOut> sharded_inputs_list(
      sharding.tile_assignment_devices_size());
  int64_t next_core_tile_index = 0;
  while (!split_nodes_for_dimension.empty()) {
    Node* split_node = split_nodes_for_dimension.front();
    split_nodes_for_dimension.pop();
    int num_splits;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(split_node->def(), "num_split", &num_splits));
    for (int out_index = 0; out_index < num_splits; ++out_index) {
      int64_t repeat_count =
          sharding.replicate_on_last_tile_dim()
              ? *sharding.tile_assignment_dimensions().rbegin()
              : 1;
      for (int64_t i = 0; i < repeat_count; ++i) {
        int64_t next_core =
            sharding.tile_assignment_devices(next_core_tile_index++);
        sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
      }
    }
  }

  ShardedInputInfo sharded_input_info{root_split_node,
                                      std::move(sharded_inputs_list)};
  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
  return sharded_input_info;
}
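
// Illustrative example: for a rank-2 input tiled with
// tile_assignment_dimensions [2, 2], the tree is one root Split along
// dimension 0 followed by one Split along dimension 1 per root output. The
// four leaf outputs are assigned to cores in row major order, i.e. tiles
// (0, 0), (0, 1), (1, 0), (1, 1) are fed to tile_assignment_devices(0)
// through tile_assignment_devices(3) respectively.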

// Creates an XLA split node to shard an input, and adds that new node to a
// Graph.
StatusOr<Node*> CreateXlaSplitOp(absl::string_view node_name,
                                 const bool is_resource, const NodeOut& input,
                                 const PartialTensorShape& partial_tensor_shape,
                                 const std::vector<Node*>& control_inputs,
                                 const std::vector<Node*>& control_outputs,
                                 const DataType dtype, const int num_shards,
                                 const xla::OpSharding& sharding,
                                 Graph* graph) {
  const std::string& input_assigned_device = input.node->assigned_device_name();
  NodeDef xla_split_def;
  xla_split_def.set_name(graph->NewName(node_name));
  xla_split_def.set_op(is_resource ? "ReadVariableXlaSplitND" : "XlaSplitND");
  xla_split_def.set_device(input_assigned_device);
  AddNodeAttr("T", dtype, &xla_split_def);
  AddNodeAttr("N", num_shards, &xla_split_def);
  const std::vector<int64_t> num_splits(
      sharding.tile_assignment_dimensions().begin(),
      sharding.replicate_on_last_tile_dim()
          ? std::prev(sharding.tile_assignment_dimensions().end())
          : sharding.tile_assignment_dimensions().end());
  AddNodeAttr("num_splits", num_splits, &xla_split_def);
  const int rank = sharding.replicate_on_last_tile_dim()
                       ? sharding.tile_assignment_dimensions_size() - 1
                       : sharding.tile_assignment_dimensions_size();
  std::vector<int32> paddings;
  paddings.reserve(rank);
  for (int dim = 0; dim < rank; ++dim) {
    paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
                                  partial_tensor_shape));
  }
  AddNodeAttr("paddings", paddings, &xla_split_def);

  if (!is_resource) {
    AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &xla_split_def);
    AddNodeAttr(kColocationAttrName,
                std::vector<string>{
                    absl::StrCat(kColocationGroupPrefix, input.node->name())},
                &xla_split_def);
  }

  TF_ASSIGN_OR_RETURN(Node * xla_split, graph->AddNode(xla_split_def));
  if (is_resource) {
    xla_split->set_requested_device(input.node->requested_device());
  }
  xla_split->set_assigned_device_name(input_assigned_device);
  graph->AddEdge(input.node, input.index, xla_split, 0);
  for (Node* control_input : control_inputs) {
    graph->AddControlEdge(control_input, xla_split);
  }
  for (Node* control_output : control_outputs) {
    graph->AddControlEdge(xla_split, control_output);
  }
  return xla_split;
}
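
// Illustrative example of the attributes built above: for an input of shape
// [10, 7], num_shards 4, and a sharding with tile_assignment_dimensions
// [2, 2] (no replicated last tile dimension), the resulting XlaSplitND node
// gets N = 4, num_splits = [2, 2], and paddings = [0, 1], since dimension 0
// divides evenly while dimension 1 needs ceil(7 / 2) * 2 - 7 = 1 element of
// padding.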

// Creates a sharded tensor list for all input shards of an input with sharding.
xla::StatusOr<std::vector<NodeOut>> ShardInputWithXlaSplitOp(
    absl::string_view node_name, const bool is_resource, const NodeOut& input,
    const PartialTensorShape& partial_tensor_shape,
    const std::vector<Node*>& control_inputs,
    const std::vector<Node*>& control_outputs, const DataType dtype,
    const xla::OpSharding& sharding, Graph* graph) {
  const int repeat = sharding.replicate_on_last_tile_dim()
                         ? *sharding.tile_assignment_dimensions().rbegin()
                         : 1;
  const int num_shards = sharding.tile_assignment_devices_size() / repeat;

  TF_ASSIGN_OR_RETURN(
      Node * xla_split,
      CreateXlaSplitOp(node_name, is_resource, input, partial_tensor_shape,
                       control_inputs, control_outputs, dtype, num_shards,
                       sharding, graph));

  std::vector<NodeOut> sharded_inputs_list(
      sharding.tile_assignment_devices_size());

  for (int i = 0; i < num_shards; ++i) {
    for (int j = 0; j < repeat; ++j) {
      const int index = i * repeat + j;
      const int core = sharding.tile_assignment_devices(index);
      sharded_inputs_list[core] = {xla_split, i};
    }
  }

  return sharded_inputs_list;
}
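
// For example, with tile_assignment_dimensions [2, 2],
// replicate_on_last_tile_dim set, and tile_assignment_devices [0, 1, 2, 3],
// `repeat` is 2 and `num_shards` is 2: cores 0 and 1 both receive shard
// output 0 of the split node, while cores 2 and 3 both receive shard
// output 1.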

// Creates an XlaSplitND op to shard a per-replica arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
    const xla::OpSharding& sharding, const int replica_id,
    const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(orig_src->name(), "/replica_", replica_id, "_split"),
          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
          dtype, sharding, graph));

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
  return sharded_input_info;
}

// Creates an XlaSplitND op to shard a distributed arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForDistributedArg(
    const xla::OpSharding& sharding, const int num_replicas,
    const int replica_id, const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(orig_src->name(), "/distributed_split"),
          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
          dtype, sharding, graph));

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  for (int replica = 0; replica < num_replicas; ++replica) {
    (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
        sharded_input_info;
  }
  return sharded_input_info;
}

// Creates a ReadVariableXlaSplitND op to shard a variable arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForVariableArg(
    const xla::OpSharding& sharding, const int num_replicas,
    const int replica_id, const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::vector<Node*>* to_be_removed_nodes,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  DCHECK_EQ(orig_src->type_string(), "ReadVariableOp");
  std::vector<Node*> control_outputs;
  std::vector<const Edge*> edges_to_remove;
  for (const Edge* edge : orig_src->out_edges()) {
    if (edge->IsControlEdge()) {
      control_outputs.push_back(edge->dst());
    }
    edges_to_remove.push_back(edge);
  }

  to_be_removed_nodes->push_back(orig_src);

  const Edge* resource = nullptr;
  TF_RETURN_IF_ERROR(orig_src->input_edge(0, &resource));

  std::vector<Node*> control_inputs;
  for (const Edge* edge : orig_src->in_edges()) {
    if (edge->IsControlEdge()) {
      control_inputs.push_back(edge->src());
    }
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(resource->src()->name(), "/read_variable_split"),
          /*is_resource=*/true,
          /*input=*/{resource->src(), resource->src_output()},
          partial_tensor_shape, control_inputs, control_outputs, dtype,
          sharding, graph));

  for (const Edge* edge : edges_to_remove) {
    graph->RemoveControlEdge(edge);
  }

  DCHECK(orig_src->out_edges().empty());

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  for (int replica = 0; replica < num_replicas; ++replica) {
    ShardedInputIndex idx{replica, orig_arg_num};
    // Refrain from overwriting, if dummy inputs were already placed instead.
    arg_index_to_sharded_input_map->insert({idx, sharded_input_info});
  }
  return sharded_input_info;
}
1122 
1123 // Creates a concat node to be used for aggregating sharded retvals across
1124 // logical cores.
CreateConcatNode(int dim,int num_splits,DataType dtype,absl::string_view name_prefix,const std::vector<NodeOut> & inputs,Graph * graph,absl::string_view device)1125 xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
1126                                       absl::string_view name_prefix,
1127                                       const std::vector<NodeOut>& inputs,
1128                                       Graph* graph, absl::string_view device) {
1129   // Add a Concat dim node.
1130   NodeDef concat_dim_def;
1131   concat_dim_def.set_name(
1132       graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
1133   concat_dim_def.set_op("Const");
1134   AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
1135   concat_dim_def.set_device(std::string(device));
1136   TensorProto tensor_proto;
1137   tensor_proto.set_dtype(DT_INT32);
1138   tensor_proto.add_int_val(dim);
1139   TensorShape shape({});
1140   shape.AsProto(tensor_proto.mutable_tensor_shape());
1141   AddNodeAttr("value", tensor_proto, &concat_dim_def);
1142   TF_ASSIGN_OR_RETURN(Node * concat_dim_node, graph->AddNode(concat_dim_def));
1143 
1144   // Add a Concat node.
1145   NodeDef concat_def;
1146   concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
1147   concat_def.set_op("Concat");
1148   AddNodeAttr("N", num_splits, &concat_def);
1149   AddNodeAttr("T", dtype, &concat_def);
1150   concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
1151   concat_def.set_device(std::string(device));
1152   for (const auto& i : inputs) {
1153     concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
1154   }
1155   TF_ASSIGN_OR_RETURN(Node * concat_node, graph->AddNode(concat_def));
1156 
1157   graph->AddEdge(concat_dim_node, 0, concat_node, 0);
1158 
1159   // 0th input to concat node is a concat dim node. So we start from 1st input
1160   // and add all input edges.
1161   int dst_input = 1;
1162   for (const auto& i : inputs) {
1163     graph->AddEdge(i.node, i.index, concat_node, dst_input);
1164     ++dst_input;
1165   }
1166   return concat_node;
1167 }
1168 
1169 // Adds slice node after concat node to graph for uneven sharding tiled inputs.
CreateSliceNode(DataType dtype,const PartialTensorShape & shape,Node * concat_node,const int concat_out_index,Graph * graph,absl::string_view device)1170 xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
1171                                      const PartialTensorShape& shape,
1172                                      Node* concat_node,
1173                                      const int concat_out_index, Graph* graph,
1174                                      absl::string_view device) {
1175   Status s;
1176   // Add begin node for concat.
1177   NodeDef begin_def;
1178   begin_def.set_name(
1179       graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
1180   begin_def.set_op("Const");
1181   AddNodeAttr("dtype", DT_INT32, &begin_def);
1182   begin_def.set_device(std::string(device));
1183   TensorProto begin_tensor_proto;
1184   begin_tensor_proto.set_dtype(DT_INT32);
1185   for (int i = 0; i < shape.dims(); ++i) {
1186     begin_tensor_proto.add_int_val(0);
1187   }
1188   TensorShape begin_shape({shape.dims()});
1189   begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
1190   AddNodeAttr("value", begin_tensor_proto, &begin_def);
1191   TF_ASSIGN_OR_RETURN(Node * begin_node, graph->AddNode(begin_def));
1192 
1193   // Add size node.
1194   NodeDef size_def;
1195   size_def.set_name(
1196       graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
1197   size_def.set_op("Const");
1198   AddNodeAttr("dtype", DT_INT32, &size_def);
1199   size_def.set_device(std::string(device));
1200   TensorProto sizes_tensor_proto;
1201   sizes_tensor_proto.set_dtype(DT_INT32);
1202   for (int i = 0; i < shape.dims(); ++i) {
1203     sizes_tensor_proto.add_int_val(shape.dim_size(i));
1204   }
1205   TensorShape sizes_shape({shape.dims()});
1206   sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
1207   AddNodeAttr("value", sizes_tensor_proto, &size_def);
1208   TF_ASSIGN_OR_RETURN(Node * size_node, graph->AddNode(size_def));
1209 
1210   // Add Slice node.
1211   NodeDef slice_def;
1212   slice_def.set_name(
1213       graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
1214   slice_def.set_op("Slice");
1215   slice_def.set_device(std::string(device));
1216   AddNodeAttr("T", dtype, &slice_def);
1217   AddNodeAttr("Index", DT_INT32, &slice_def);
1218   slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
1219   slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
1220   slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
1221   TF_ASSIGN_OR_RETURN(Node * slice_node, graph->AddNode(slice_def));
1222   // Add edges for slice node.
1223   graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
1224   graph->AddEdge(begin_node, 0, slice_node, 1);
1225   graph->AddEdge(size_node, 0, slice_node, 2);
1226   return slice_node;
1227 }
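// Illustrative example of what CreateSliceNode emits (shape values are
// hypothetical): if the inferred full shape is [5, 4] but the concatenated
// shards cover a padded [6, 4] region, the Slice uses begin = [0, 0] and
// size = [5, 4], i.e. it keeps the leading shape.dim_size(i) elements in every
// dimension and drops the trailing padding rows.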
1228 
1229 // Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute
1230 // nodes into a single output. Sharded outputs are concatenated in row-major
1231 // order. That is, tiled outputs along the 0th dimension are concatenated last.
1232 xla::StatusOr<Node*> CreateConcatNodesForRetval(
1233     const xla::OpSharding& sharding, DataType dtype,
1234     const PartialTensorShape& inferred_shape, int replica_id,
1235     const std::vector<NodeOut>& orig_inputs, Graph* graph,
1236     absl::string_view device) {
1237   std::map<int, int> split_dimension_map;
1238   TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
1239       sharding, &split_dimension_map));
1240   std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
1241   bool has_paddings = false;
1242 
1243   for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
1244        it++) {
1245     auto dim = it->first;
1246     auto num_splits = it->second;
1247 
1248     int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits;
1249     int input_index_to_concat_node = 0;
1250 
1251     std::vector<NodeOut> new_concat_nodes;
1252     for (int i = 0; i < num_concat_nodes; ++i) {
1253       auto concat_input_it =
1254           inputs_to_sharded_retval.begin() + input_index_to_concat_node;
1255       std::vector<NodeOut> inputs(concat_input_it,
1256                                   concat_input_it + num_splits);
1257       input_index_to_concat_node += num_splits;
1258 
1259       TF_ASSIGN_OR_RETURN(
1260           Node * concat_node,
1261           CreateConcatNode(
1262               dim, num_splits, dtype,
1263               absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
1264               inputs, graph, device));
1265       int64_t paddings = GetPadding(dim, num_splits, inferred_shape);
1266       has_paddings |= paddings > 0;
1267       new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
1268     }
1269     inputs_to_sharded_retval = new_concat_nodes;
1270   }
1271 
1272   TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
1273   if (has_paddings) {
1274     TF_ASSIGN_OR_RETURN(Node * slice_node,
1275                         CreateSliceNode(dtype, inferred_shape,
1276                                         inputs_to_sharded_retval.at(0).node,
1277                                         /*concat_out_index*/ 0, graph, device));
1278     return slice_node;
1279   }
1280   return inputs_to_sharded_retval.at(0).node;
1281 }
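// Illustrative example of the concat order (shard names are hypothetical): for
// a retval tiled 2x2 (2 splits along dim 0, 2 along dim 1), orig_inputs arrive
// in row-major order {s00, s01, s10, s11}. The loop above walks the split
// dimensions in reverse, so it first builds Concat(dim=1, {s00, s01}) -> r0 and
// Concat(dim=1, {s10, s11}) -> r1, and then Concat(dim=0, {r0, r1}) gives the
// full retval, with a trailing Slice appended only when the tiling required
// padding.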
1282 
1283 xla::StatusOr<Node*> CreateXlaConcatNode(
1284     const xla::OpSharding& sharding, const int replica_id, DataType dtype,
1285     const PartialTensorShape& partial_tensor_shape,
1286     const std::vector<NodeOut>& orig_inputs, absl::string_view device,
1287     Graph* graph) {
1288   NodeDef xla_concat_def;
1289   xla_concat_def.set_name(graph->NewName(
1290       absl::StrCat("sharded_output/replica_", replica_id, "_concat")));
1291   xla_concat_def.set_op("XlaConcatND");
1292   xla_concat_def.set_device(std::string(device));
1293   AddNodeAttr("T", dtype, &xla_concat_def);
1294   AddNodeAttr("N", static_cast<int64_t>(orig_inputs.size()), &xla_concat_def);
1295   const std::vector<int64_t> num_concats(
1296       sharding.tile_assignment_dimensions().begin(),
1297       sharding.replicate_on_last_tile_dim()
1298           ? std::prev(sharding.tile_assignment_dimensions().end())
1299           : sharding.tile_assignment_dimensions().end());
1300   AddNodeAttr("num_concats", num_concats, &xla_concat_def);
1301   const int rank = sharding.replicate_on_last_tile_dim()
1302                        ? sharding.tile_assignment_dimensions_size() - 1
1303                        : sharding.tile_assignment_dimensions_size();
1304   std::vector<int32> paddings;
1305   paddings.reserve(rank);
1306   for (int dim = 0; dim < rank; ++dim) {
1307     paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
1308                                   partial_tensor_shape));
1309   }
1310   AddNodeAttr("paddings", paddings, &xla_concat_def);
1311 
1312   TF_ASSIGN_OR_RETURN(Node * xla_concat, graph->AddNode(xla_concat_def));
1313   for (int i = 0, e = orig_inputs.size(); i < e; ++i) {
1314     const NodeOut& input = orig_inputs[i];
1315     graph->AddEdge(input.node, input.index, xla_concat, i);
1316   }
1317   return xla_concat;
1318 }
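// Illustrative example of the attributes computed above (the sharding values
// are hypothetical): for an OpSharding with tile_assignment_dimensions
// [2, 2, 2] and replicate_on_last_tile_dim = true, the XlaConcatND node gets
// N = orig_inputs.size(), num_concats = [2, 2] (the trailing replication
// dimension is dropped), rank = 2, and paddings[d] = GetPadding(d, 2,
// partial_tensor_shape) for each of the two tiled dimensions.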
1319 
1320 // Sets the padding ops to the same devices as the original inputs. If the
1321 // original inputs are on TPUs, the padding ops will be placed on TPUs and XLA
1322 // on-demand mode will be triggered, so we don't need to copy the data back to
1323 // the host to do the padding.
1324 Status SetPaddingNodesDevices(Graph* graph) {
1325   for (Node* n : graph->op_nodes()) {
1326     bool tpu_padding_attr;
1327     if (n->type_string() == "Pad" &&
1328         GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
1329             .ok()) {
1330       Node* unpadded_input;
1331       TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
1332 
1333       const string& requested_device = unpadded_input->requested_device();
1334       const string& assigned_device = unpadded_input->assigned_device_name();
1335       if (!requested_device.empty() || !assigned_device.empty()) {
1336         // The output nodes of the original unpadded inputs include the padded
1337         // inputs and the real shapes of the inputs; we assign those to the same
1338         // device as the original inputs.
1339         for (Node* out : unpadded_input->out_nodes()) {
1340           if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
1341                           &tpu_padding_attr)
1342                   .ok()) {
1343             out->set_requested_device(requested_device);
1344             out->set_assigned_device_name(assigned_device);
1345           }
1346         }
1347         // There might be a tf.shape node added before TPUCompileOp; we need to
1348         // set its device as well.
1349         for (Node* out : n->out_nodes()) {
1350           if (out->type_string() == "Shape") {
1351             out->set_requested_device(requested_device);
1352             out->set_assigned_device_name(assigned_device);
1353           }
1354         }
1355       }
1356     }
1357   }
1358   return OkStatus();
1359 }
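// Illustrative example (the device name is hypothetical): if an unpadded input
// is already assigned to /job:worker/replica:0/task:0/device:TPU:0, the Pad op
// consuming it (and the companion node carrying the real input shape) is given
// that same TPU device above, so the padding runs in XLA on-demand mode instead
// of round-tripping the tensor through the host.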
1360 
1361 const string& AssignedOrRequestedDevice(const Node* node) {
1362   if (!node->assigned_device_name().empty()) {
1363     return node->assigned_device_name();
1364   }
1365   return node->requested_device();
1366 }
1367 
1368 bool IsTpuDevice(StringPiece device_string) {
1369   DeviceNameUtils::ParsedName device;
1370   return DeviceNameUtils::ParseFullName(device_string, &device) &&
1371          device.type == DEVICE_TPU_NODE;
1372 }
1373 
1374 bool CanAcceptTPUDevicePropagation(const Node& node) {
1375   // A set of ops that can be placed on TPU. There is no strict rule of thumb
1376   // for deciding which ops should be in the list, but empirically they are
1377   // mostly no-op-like ops such as Identity or control-flow-related ops.
1378   // However, one can also add other ops like Pad to allow data to stay on TPU.
1379   static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
1380       {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
1381        "NextIteration", "Shape", "_Retval"});
1382   return place_on_tpu_ops->contains(node.type_string());
1383 }
1384 
1385 xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
1386   xla::OpMetadata metadata;
1387   metadata.set_op_type(node.type_string());
1388   metadata.set_op_name(node.name());
1389   return metadata;
1390 }
1391 
1392 // Helper struct holding node (nullable) and associated sharding.
1393 struct NodeAndSharding {
1394   explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding)
1395       : node(node), sharding(sharding) {}
1396 
1397   const Node* node;
1398   xla::OpSharding sharding;
1399 };
1400 
1401 // Validate sharding configuration derived from XlaSharding attribute.
1402 // Infer the core id from the OpSharding, if necessary.
1403 Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding,
1404                                 const int num_cores_per_replica,
1405                                 int64_t* inferred_core_id,
1406                                 absl::optional<NodeAndSharding>* result) {
1407   if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) {
1408     int64_t core_annotation =
1409         node_and_sharding.sharding.tile_assignment_devices(0);
1410     TF_RETURN_IF_ERROR(
1411         ValidateCoreNumber(core_annotation, num_cores_per_replica));
1412     if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
1413       *inferred_core_id = core_annotation;
1414       result->emplace(node_and_sharding);
1415     }
1416   } else {
1417     if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) {
1418       for (int64_t core :
1419            node_and_sharding.sharding.tile_assignment_devices()) {
1420         TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
1421       }
1422     }
1423 
1424     if (!result->has_value()) {
1425       *result = node_and_sharding;
1426     } else {
1427       std::string result_value_serialized;
1428       xla::OpSharding result_value = result->value().sharding;
1429       result_value.clear_metadata();
1430       SerializeToStringDeterministic(result_value, &result_value_serialized);
1431 
1432       std::string sharding_serialized;
1433       xla::OpSharding sharding = node_and_sharding.sharding;
1434       sharding.clear_metadata();
1435       SerializeToStringDeterministic(sharding, &sharding_serialized);
1436 
1437       // TODO(lyandy): Choose the more granular sharding instead of always
1438       // assigning to core 0 (maximal).
1439       if (result_value_serialized != sharding_serialized) {
1440         // We see different shardings, assign to core 0.
1441         auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
1442         DCHECK_NE(node_and_sharding.node, nullptr);
1443         *core_zero_sharding.add_metadata() =
1444             CreateOpMetadataFromNode(*node_and_sharding.node);
1445         result->emplace(
1446             NodeAndSharding(node_and_sharding.node, core_zero_sharding));
1447       }
1448     }
1449   }
1450   return OkStatus();
1451 }
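// Illustrative example of the conflict handling above (the annotations are
// hypothetical): if one consumer annotates an argument with a 2-way tiled
// sharding and another consumer annotates it with replication, the two
// serialized shardings differ, so the result falls back to
// xla::sharding_builder::AssignDevice(0), i.e. maximal sharding on core 0. For
// purely MAXIMAL annotations, only the lowest annotated core is kept.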
1452 
1453 // As an XlaSharding node may be followed by a Cast op or an Identity op,
1454 // recursively walk the graph and aggregate the nodes connected to
1455 // |input_node| or to a Cast/Identity op following |input_node|.
1456 void FindNodesMaybeContainingShardingInfo(const Node& input_node,
1457                                           std::vector<const Node*>* nodes) {
1458   if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
1459     for (const Node* connected_node : input_node.out_nodes())
1460       FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
1461   }
1462   nodes->emplace_back(&input_node);
1463 }
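// Illustrative example (a hypothetical chain): for
//   arg -> Identity -> Cast -> XlaSharding,
// calling this on the Identity node collects {XlaSharding, Cast, Identity};
// callers then keep only the XlaSharding node when extracting sharding.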
1464 
1465 // Parse sharding configuration from |node| or its adjacent nodes.
1466 // XlaSharding configuration may be derived from
1467 //   a) Connected Identity op node.
1468 //   b) Connected Cast op node.
1469 xla::StatusOr<absl::optional<NodeAndSharding>>
1470 ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
1471                                    const Node& node) {
1472   // If |node| has a `device` attribute or is an XlaSharding op,
1473   // return the parsed OpSharding.
1474   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1475                       ParseShardingFromDevice(node, num_cores_per_replica,
1476                                               /*add_metadata=*/true));
1477   if (sharding.has_value()) {
1478     return absl::optional<NodeAndSharding>(NodeAndSharding(&node, *sharding));
1479   }
1480 
1481   // XlaShardingOp may be followed by an Identity op, or by an Identity op
1482   // and a Cast op.
1483   std::vector<const Node*> potential_nodes_with_input_sharding;
1484   FindNodesMaybeContainingShardingInfo(node,
1485                                        &potential_nodes_with_input_sharding);
1486   for (const Node* maybe_node_with_sharding_info :
1487        potential_nodes_with_input_sharding) {
1488     if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
1489 
1490     TF_ASSIGN_OR_RETURN(
1491         absl::optional<xla::OpSharding> sharding_config,
1492         ParseShardingFromDevice(*maybe_node_with_sharding_info,
1493                                 num_cores_per_replica, /*add_metadata=*/true));
1494     if (sharding_config.has_value()) {
1495       return absl::optional<NodeAndSharding>(
1496           NodeAndSharding(maybe_node_with_sharding_info, *sharding_config));
1497     }
1498   }
1499   return absl::optional<NodeAndSharding>();
1500 }
1501 
1502 // Walk the graph from an argument node to find OpSharding configuration
1503 // from its neighbor nodes. Sharding configuration may be inferred from
1504 //  1) Parsing the XlaSharding attribute from a neighboring node.
1505 //  2) If the argument node is a resource, parsing the adjacent nodes
1506 //     of the connected ReadVariableOp.
1507 Status ParseAndValidateShardingFromNeighbors(
1508     const int num_cores_per_replica, const std::string& arg_node_name,
1509     const Node& neighbor_node, int64_t* inferred_core_id, bool* is_fast_mem,
1510     absl::optional<NodeAndSharding>* result) {
1511   if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1512     *is_fast_mem = true;
1513     VLOG(2) << "place " << arg_node_name << " on fast memory because "
1514             << neighbor_node.name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
1515   }
1516 
1517   // XlaSharding information may be encoded on a node directly connected to the
1518   // argument node.
1519   TF_ASSIGN_OR_RETURN(
1520       absl::optional<NodeAndSharding> node_and_sharding,
1521       ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
1522   if (node_and_sharding.has_value()) {
1523     TF_RETURN_IF_ERROR(ParseAndValidateSharding(
1524         *node_and_sharding, num_cores_per_replica, inferred_core_id, result));
1525     return OkStatus();
1526   }
1527 
1528   // When we use a variable in a TPU computation, we always have a
1529   // ReadVariableOp followed by an XlaSharding op. As such, parse the users
1530   // of the ReadVariableOp for potential sharding configuration.
1531   if (neighbor_node.type_string() == "ReadVariableOp") {
1532     for (const Edge* e : neighbor_node.out_edges()) {
1533       if (e->IsControlEdge()) continue;
1534 
1535       if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1536         *is_fast_mem = true;
1537         VLOG(2) << "place " << arg_node_name << " on fast memory because "
1538                 << e->dst()->name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
1539       }
1540 
1541       TF_ASSIGN_OR_RETURN(
1542           absl::optional<NodeAndSharding> node_and_sharding,
1543           ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
1544       if (node_and_sharding.has_value()) {
1545         TF_RETURN_IF_ERROR(ParseAndValidateSharding(*node_and_sharding,
1546                                                     num_cores_per_replica,
1547                                                     inferred_core_id, result));
1548         return OkStatus();
1549       }
1550     }
1551   }
1552   return OkStatus();
1553 }
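// Illustrative example (a hypothetical variable argument): the typical pattern
//   _Arg(DT_RESOURCE) -> ReadVariableOp -> XlaSharding -> ...
// means the directly connected neighbor (the ReadVariableOp) carries no
// sharding itself, so the loop above inspects the ReadVariableOp's users and
// picks up the sharding (and any TPU_FAST_MEM_ATTR) from there.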
1554 
1555 }  // namespace
1556 
1557 // Inputs:
1558 //   replication_spec_string: the device to which the TPUReplicate node was
1559 //     assigned.
1560 //   device_set: the set of TF devices.
1561 // Outputs:
1562 //   tpu_compilation_device: the name of the TPU compilation device.
1563 //   num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
1564 //     have the same number of TPU devices.
1565 //   tpu_devices: the TPU devices, indexed by [task][device].
1566 static Status GetTPUDeviceNames(
1567     const string& replication_spec_string, const DeviceSet& device_set,
1568     string* tpu_compilation_device, int* num_tpus_per_task,
1569     std::vector<std::vector<Device*>>* tpu_devices) {
1570   // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
1571   // the tpu_system device, which we replace by the cpu device. We do this
1572   // replacement because we want to place the TPUCompileOp (and the compile
1573   // assert op) explicitly on cpu devices on the same job as the tpu_system
1574   // device.
1575   DeviceNameUtils::ParsedName replication_spec;
1576   Device* replication_device;
1577   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
1578       replication_spec_string, device_set, &replication_spec,
1579       &replication_device));
1580   *tpu_compilation_device =
1581       str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
1582                               DEVICE_CPU, /*replace_all=*/true);
1583 
1584   // Finds the set of TPU devices attached to the tasks in the job.
1585   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
1586       replication_spec, device_set, num_tpus_per_task, tpu_devices));
1587 
1588   return OkStatus();
1589 }
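// Illustrative example (the device names are hypothetical): if the TPUReplicate
// node was assigned to /job:worker/replica:0/task:0/device:TPU_SYSTEM:0, the
// computed tpu_compilation_device is /job:worker/replica:0/task:0/device:CPU:0,
// i.e. the TPU_SYSTEM device name with its device type replaced by CPU.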
1590 
1591 // Parses the topology attribute of TPUReplicate, and populates *topology with
1592 // a physical mesh coordinate to (task, device) mapping.
1593 static Status ParseTopologyAttr(const string& topology_attr,
1594                                 const tpu::TpuTopologyExternal& tpu_topology,
1595                                 int num_tasks, int num_tpus_per_task,
1596                                 xla::Array4D<std::pair<int, int>>* topology) {
1597   static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1598   tpu::TopologyProto proto;
1599   proto.ParseFromString(topology_attr);
1600   if (proto.mesh_shape_size() != kTPUTopologyRank) {
1601     return errors::InvalidArgument("TPU topology must be rank ",
1602                                    kTPUTopologyRank);
1603   }
1604   if (proto.num_tasks() != num_tasks) {
1605     return errors::InvalidArgument("Mismatched number of TPU tasks (",
1606                                    proto.num_tasks(), " != ", num_tasks, ")");
1607   }
1608   if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
1609     return errors::InvalidArgument("Mismatched number of TPUs per task (",
1610                                    proto.num_tpu_devices_per_task(),
1611                                    " != ", num_tpus_per_task, ").");
1612   }
1613   if (proto.device_coordinates_size() !=
1614       num_tasks * num_tpus_per_task * kTPUTopologyRank) {
1615     return errors::InvalidArgument(
1616         "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
1617         kTPUTopologyRank, "; got ", proto.device_coordinates_size());
1618   }
1619 
1620   int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1621   *topology = xla::Array4D<std::pair<int, int>>(
1622       tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1623       tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
1624   int pos = 0;
1625   for (int task = 0; task < num_tasks; ++task) {
1626     for (int device = 0; device < num_tpus_per_task; ++device) {
1627       int32_t x = proto.device_coordinates(pos++);
1628       int32_t y = proto.device_coordinates(pos++);
1629       int32_t z = proto.device_coordinates(pos++);
1630       int32_t core = proto.device_coordinates(pos++);
1631 
1632       if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1633           core >= devices_per_chip) {
1634         return errors::InvalidArgument(
1635             "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1636             ") are not valid for the current TPU topology");
1637       }
1638       if ((*topology)(x, y, z, core).first != -1) {
1639         return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1640                                        ",", z, ",", core, ") in TPU topology");
1641       }
1642       (*topology)(x, y, z, core) = {task, device};
1643     }
1644   }
1645   return OkStatus();
1646 }
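// Illustrative example of the expected attribute layout (coordinate values are
// hypothetical): with num_tasks = 1 and num_tpus_per_task = 2, the topology
// proto must carry 1 * 2 * 4 = 8 device_coordinates. If they are
// {0,0,0,0, 0,0,0,1}, then (*topology)(0, 0, 0, 0) = {task 0, device 0} and
// (*topology)(0, 0, 0, 1) = {task 0, device 1}.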
1647 
1648 // Parses the value of the device_assignment attribute to TPUReplicate.
1649 // Populates *device_assignment; *device_assignment must be a 2D array with
1650 // shape (num_replicas, num_cores_per_replica).
1651 static Status ParseDeviceAssignmentAttr(
1652     absl::Span<const int> device_assignment_attr,
1653     const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
1654     int num_cores_per_replica,
1655     xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
1656   static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1657 
1658   const int64_t device_assignment_attr_size =
1659       num_replicas * num_cores_per_replica * kTPUTopologyRank;
1660   if (device_assignment_attr.size() != device_assignment_attr_size) {
1661     return errors::InvalidArgument(
1662         "Length of device_assignment attribute must be equal to num_replicas (",
1663         num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
1664         ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
1665   }
1666   for (int core : device_assignment_attr) {
1667     if (core < 0 || core >= kTPUMaxTopologySize) {
1668       return errors::InvalidArgument(
1669           "Invalid core number in device assignment: ", core);
1670     }
1671   }
1672 
1673   *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
1674       num_replicas, num_cores_per_replica);
1675   int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1676   xla::Array4D<int> replica_assignment(
1677       tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1678       tpu_topology.chip_bounds().z, devices_per_chip, -1);
1679   int pos = 0;
1680   for (int replica = 0; replica < num_replicas; ++replica) {
1681     for (int logical_core = 0; logical_core < num_cores_per_replica;
1682          ++logical_core) {
1683       int32_t x = device_assignment_attr[pos++];
1684       int32_t y = device_assignment_attr[pos++];
1685       int32_t z = device_assignment_attr[pos++];
1686       int32_t core = device_assignment_attr[pos++];
1687 
1688       if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1689           core >= devices_per_chip) {
1690         return errors::InvalidArgument(
1691             "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1692             ") are not valid for the current TPU topology");
1693       }
1694       tpu::TpuCoreLocationExternal core_location =
1695           tpu_topology.Core(kTensorCore, x, y, z, core);
1696 
1697       if (replica_assignment(x, y, z, core) != -1) {
1698         return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1699                                        ",", z, ",", core,
1700                                        ") in TPU device assignment");
1701       }
1702       replica_assignment(x, y, z, core) = replica;
1703       (*device_assignment)(replica, logical_core) = core_location;
1704     }
1705   }
1706   return OkStatus();
1707 }
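// Illustrative example of the expected attribute layout (values are
// hypothetical): with num_replicas = 2 and num_cores_per_replica = 1, the
// attribute must contain 2 * 1 * 4 = 8 integers. If it is {0,0,0,0, 1,0,0,0},
// replica 0 runs on the core at chip (0,0,0) index 0 and replica 1 on the core
// at chip (1,0,0) index 0; listing the same coordinates twice would be rejected
// as a duplicate.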
1708 
1709 // Builds TensorFlow device assignments for the special case of a single core
1710 // computation that is replicated to every core in the mesh.
1711 // LINT.IfChange
1712 static Status BuildFullMeshDeviceAssignment(
1713     int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
1714     int num_tasks, int num_tpus_per_task,
1715     std::vector<std::vector<string>>* tf_device_assignment,
1716     std::vector<int>* devices_to_lock) {
1717   // Assign TensorFlow devices to replicas arbitrarily.
1718   for (int i = 0; i < num_replicas; ++i) {
1719     int task = i / num_tpus_per_task;
1720     int device = i % num_tpus_per_task;
1721     TF_RET_CHECK(task >= 0 && task < num_tasks);
1722     TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
1723 
1724     // We don't actually know which TF device corresponds to which physical
1725     // device, but it doesn't matter—they're all identical.
1726     (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
1727     devices_to_lock->push_back(i);
1728   }
1729   return OkStatus();
1730 }
1731 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
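// Illustrative example of the arithmetic above (the counts are hypothetical):
// with num_tpus_per_task = 8, replica 10 maps to task 10 / 8 = 1 and device
// 10 % 8 = 2, so its TF device is tpu_devices[1][2]->name().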
1732 
1733 // Builds TensorFlow device assignments for a replicated computation and
1734 // converts device_assignment into xla_device_assignment.
1735 static Status BuildGeneralDeviceAssignment(
1736     int num_replicas, int num_cores_per_replica,
1737     const std::vector<std::vector<Device*>>& tpu_devices,
1738     const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
1739     const xla::Array4D<std::pair<int, int>>& topology,
1740     std::vector<std::vector<string>>* tf_device_assignment,
1741     std::vector<int>* devices_to_lock,
1742     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1743   // Assign TensorFlow devices to each computation's replicas according to
1744   // device_assignment and 'topology'.
1745   *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
1746       num_replicas, num_cores_per_replica);
1747   for (int replica = 0; replica < num_replicas; ++replica) {
1748     for (int computation = 0; computation < num_cores_per_replica;
1749          ++computation) {
1750       const tpu::TpuCoreLocationExternal& core_location =
1751           device_assignment(replica, computation);
1752 
1753       int task;
1754       int device;
1755       std::tie(task, device) =
1756           topology(core_location.chip_coordinates().x,
1757                    core_location.chip_coordinates().y,
1758                    core_location.chip_coordinates().z, core_location.index());
1759 
1760       CHECK_LT(computation, num_cores_per_replica);
1761       (**xla_device_assignment)(replica, computation) = core_location.Id();
1762 
1763       // The communication pattern between replicas will be determined later by
1764       // BuildAllReduceRing.
1765       TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
1766       TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
1767       (*tf_device_assignment)[replica].push_back(
1768           tpu_devices[task][device]->name());
1769       devices_to_lock->push_back((task * tpu_devices[task].size()) + device);
1770     }
1771   }
1772   return OkStatus();
1773 }
1774 
1775 /*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
1776     const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
1777     const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
1778     int num_cores_per_replica, const string& topology_attr,
1779     absl::Span<const int> device_assignment_attr,
1780     std::vector<std::vector<string>>* tf_device_assignment,
1781     std::vector<int>* devices_to_lock,
1782     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1783   const int num_tasks = tpu_devices.size();
1784   const int num_tpu_devices = num_tasks * num_tpus_per_task;
1785   VLOG(2) << "num_tasks=" << num_tasks
1786           << " num_tpus_per_task=" << num_tpus_per_task;
1787 
1788   // Checks num_replicas is sane first to avoid integer overflow.
1789   if (num_replicas > num_tpu_devices) {
1790     return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1791                                    " but there are only ", num_tpu_devices,
1792                                    " cores in the TPU topology.");
1793   }
1794   if (num_replicas * num_cores_per_replica > num_tpu_devices) {
1795     return errors::InvalidArgument(
1796         "Requested num_replicas=", num_replicas, " with ",
1797         num_cores_per_replica, " cores per replica, but there are only ",
1798         num_tpu_devices, " cores in the TPU topology");
1799   }
1800 
1801   tf_device_assignment->clear();
1802   tf_device_assignment->resize(num_replicas);
1803 
1804   devices_to_lock->clear();
1805   devices_to_lock->reserve(num_replicas * num_cores_per_replica);
1806 
1807   // Special case: we allow the user to omit the topology and device assignment
1808   // information in two cases:
1809   // * there is only one replica and one core per replica. In this case, we
1810   //   don't need to know topology information because we don't communicate with
1811   //   other cores.
1812   // * the number of replicas is equal to the number of cores in the slice. In
1813   //   this case, all cores are running the same program so we don't need to
1814   //   know which is which.
1815   if (topology_attr.empty()) {
1816     // LINT.IfChange
1817     if (num_replicas != 1 && num_replicas != num_tpu_devices) {
1818       return errors::InvalidArgument(
1819           "TPUReplicate asked to create ", num_replicas,
1820           " replicas, but the number of cores in the TPU topology is ",
1821           num_tpu_devices,
1822           " and no TPU device assignment was supplied. "
1823           "A TPU device assignment is required if the number of replicas is "
1824           "not 1 or the number of cores in the topology (",
1825           num_tpu_devices, ")");
1826     }
1827 
1828     if (num_cores_per_replica != 1) {
1829       return errors::InvalidArgument(
1830           "A TPU topology must be provided if num_cores_per_replica != 1");
1831     }
1832 
1833     if (!device_assignment_attr.empty()) {
1834       return errors::InvalidArgument(
1835           "A TPU topology must be provided if device_assignment_attr is "
1836           "non-empty");
1837     }
1838 
1839     // If there is only one replica, assign the Tensorflow computation to task 0
1840     // device 0, and leave the XLA device assignment empty. We don't know which
1841     // core this is in the TPU topology, but it doesn't matter—we don't need to
1842     // communicate with any other cores.
1843     if (num_replicas == 1) {
1844       (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
1845       devices_to_lock->push_back(0);
1846       return OkStatus();
1847     }
1848 
1849     // Otherwise, num_replicas is equal to the number of cores, and we build a
1850     // device assignment that covers the entire mesh. We do not need to know
1851     // the topology to do so because all cores are identical.
1852     return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
1853                                          num_tpus_per_task,
1854                                          tf_device_assignment, devices_to_lock);
1855     // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1856   }
1857 
1858   // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
1859   xla::Array4D<std::pair<int, int>> topology;
1860   TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
1861                                        num_tpus_per_task, &topology));
1862 
1863   // Array that maps logical (replica, core) pairs to physical mesh coordinates.
1864   xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
1865   TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
1866       device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
1867       &device_assignment));
1868 
1869   return BuildGeneralDeviceAssignment(
1870       num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
1871       topology, tf_device_assignment, devices_to_lock, xla_device_assignment);
1872 }
1873 
1874 Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
1875     const NameAttrList& function, FunctionLibraryRuntime* flr,
1876     Graph* computation, DataTypeVector* arg_types,
1877     DataTypeVector* retval_types) {
1878   FunctionLibraryRuntime::Handle handle;
1879 
1880   TF_RETURN_IF_ERROR(
1881       flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
1882 
1883   const FunctionBody* fbody = flr->GetFunctionBody(handle);
1884 
1885   CopyGraph(*fbody->graph, computation);
1886   *arg_types = fbody->arg_types;
1887   *retval_types = fbody->ret_types;
1888   return OkStatus();
1889 }
1890 
1891 // Grab the InferredShape corresponding to an edge input.
1892 static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
1893                            const InferredShape** info) {
1894   auto it = shape_info.find(edge.src()->name());
1895   if (it == shape_info.end()) {
1896     return errors::InvalidArgument(
1897         "Input to replicated TPU computation is missing InferredShape: ",
1898         edge.src()->name());
1899   }
1900   TF_RET_CHECK(it->second.size() > edge.src_output());
1901   *info = &it->second[edge.src_output()];
1902   return OkStatus();
1903 }
1904 
1905 Status DistributedTPURewritePass::GetArgAndRetvalShapes(
1906     const GraphShapeInfo& shape_info, const Node& node,
1907     const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
1908     std::vector<InferredShape>* retval_shapes) {
1909   std::vector<const Edge*> input_edges;
1910   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
1911 
1912   // If any replica's arg shape is unknown, we will mark the computation's arg
1913   // shape as being unknown. If the shapes differ, the TpuExecute Op will raise a
1914   // runtime error.
1915   std::vector<bool> any_replica_shape_unknown(
1916       params_info.NumInputsToEachReplica());
1917   arg_shapes->clear();
1918   arg_shapes->resize(params_info.NumInputsToEachReplica());
1919   TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
1920   // Determines the shapes of the per-replica arguments and checks that all
1921   // replicas have identical shapes.
1922   int64_t edge_pos = 0;
1923   auto check_shape = [&](int input_index) -> Status {
1924     const InferredShape* info;
1925     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1926     ++edge_pos;
1927 
1928     if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
1929         (info->handle_type != DT_INVALID &&
1930          !info->handle_shape.IsFullyDefined())) {
1931       any_replica_shape_unknown[input_index] = true;
1932     }
1933     xla::StatusOr<InferredShape> status =
1934         MergeInferredShapes((*arg_shapes)[input_index], *info);
1935     if (!status.ok()) {
1936       return errors::InvalidArgument(
1937           "Mismatched shapes for input ", input_index, ": ",
1938           (*arg_shapes)[input_index].shape.DebugString(), " vs. ",
1939           info->shape.DebugString());
1940     }
1941     (*arg_shapes)[input_index] = status.ValueOrDie();
1942     return OkStatus();
1943   };
1944 
1945   for (int64_t i = 0; i < params_info.NumReplicas(); ++i) {
1946     for (int64_t j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
1947       TF_RETURN_IF_ERROR(check_shape(j));
1948     }
1949   }
1950 
1951   for (int64_t i = 0; i < params_info.NumDistributedArgs(); ++i) {
1952     TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
1953   }
1954 
1955   for (int64_t i = 0;
1956        i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
1957        ++i) {
1958     if (any_replica_shape_unknown[i]) {
1959       (*arg_shapes)[i].shape = PartialTensorShape();
1960       (*arg_shapes)[i].handle_shape = PartialTensorShape();
1961     }
1962   }
1963 
1964   // Determines the shape of the broadcast arguments.
1965   for (int64_t i = 0; i < params_info.NumBroadcastArgs(); ++i) {
1966     TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
1967     const InferredShape* info;
1968     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1969     (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
1970                   params_info.NumDistributedArgs()]
1971         .shape = info->shape;
1972     ++edge_pos;
1973   }
1974 
1975   // Determines the handle shape and handle type of the resource variable
1976   // arguments.
1977   for (int64_t i = 0; i < params_info.NumVariables(); ++i) {
1978     TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
1979     const InferredShape* info;
1980     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1981     InferredShape& arg_shape =
1982         (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
1983                       params_info.NumDistributedArgs() +
1984                       params_info.NumBroadcastArgs()];
1985     arg_shape.shape = TensorShape();  // Variables are always scalars.
1986     arg_shape.handle_shape = info->handle_shape;
1987     arg_shape.handle_type = info->handle_type;
1988     TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
1989         << " input edge: " << input_edges[edge_pos]->DebugString();
1990     ++edge_pos;
1991   }
1992 
1993   // Determines the shape of the guaranteed constants.
1994   // TODO(vinuraja): Can be removed because they are not required for any
1995   // calculations. Leaving them here for symmetry with other structures like
1996   // arg_types, arg_sharding, etc.
1997   for (int64_t i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
1998     TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
1999     const InferredShape* info;
2000     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2001     (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2002                   params_info.NumDistributedArgs() +
2003                   params_info.NumBroadcastArgs() + params_info.NumVariables()]
2004         .shape = info->shape;
2005     ++edge_pos;
2006   }
2007 
2008   // Extract the return value shapes.
2009   auto it = shape_info.find(node.name());
2010   retval_shapes->clear();
2011   if (it != shape_info.end()) {
2012     TF_RET_CHECK(it->second.size() >= node.num_outputs());
2013     retval_shapes->resize(node.num_outputs());
2014     for (int i = 0; i < node.num_outputs(); ++i) {
2015       (*retval_shapes)[i].shape = it->second[i].shape;
2016     }
2017   } else if (node.num_outputs() > 0) {
2018     return errors::InvalidArgument(
2019         "Replicated TPU computation is missing InferredShape: ",
2020         FormatNodeForError(node));
2021   }
2022   return OkStatus();
2023 }
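// Illustrative example of the argument layout used above (the counts are
// hypothetical): with 2 per-replica args, 1 distributed arg, 1 broadcast arg,
// 3 variables and no guaranteed constants, arg_shapes is indexed as
// [0,1] per-replica, [2] distributed, [3] broadcast and [4..6] variables,
// matching the NumPerReplicaArgs/NumDistributedArgs/NumBroadcastArgs offsets.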
2024 
2025 // Verifies that all nodes have legal sharding.
2026 static Status ValidateCoreNumbers(const Graph& graph,
2027                                   int num_cores_per_replica) {
2028   for (Node* n : graph.nodes()) {
2029     TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
2030                         ParseShardingFromDevice(*n, num_cores_per_replica,
2031                                                 /*add_metadata=*/true));
2032   }
2033   return OkStatus();
2034 }
2035 
2036 static Status InferXlaShardingFromNeighbors(
2037     const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
2038     CachedFunctionHandles* cached_function_handles,
2039     absl::optional<NodeAndSharding>* output_node_and_sharding,
2040     bool* is_fast_mem) {
2041   int64_t core = -1;
2042   absl::optional<NodeAndSharding> result;
2043   // We assume the variable has been allocated on fast memory if any consuming
2044   // op has the TPU_FAST_MEM_ATTR attribute. This is a protocol between the
2045   // runtime and the compiler.
2046   *is_fast_mem = false;
2047   for (const Edge* edge : n.out_edges()) {
2048     if (edge->IsControlEdge()) continue;
2049 
2050     TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2051         num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
2052         &result));
2053 
2054     if (!flr) continue;
2055 
2056     // The nodes deciding this arg's device assignment might be in
2057     // a FunctionDef. Instantiate the FunctionDefs associated with this node
2058     // and check nodes using this arg.
2059     std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
2060         [&](const Edge* call_edge) {
2061           auto associated_functions = GetAssociatedFunctions(
2062               *call_edge->dst(), flr->GetFunctionLibraryDefinition());
2063           for (auto& associated_function : associated_functions) {
2064             FunctionLibraryRuntime::Handle handle;
2065             TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
2066                 associated_function.func_name(),
2067                 AttrSlice(&associated_function.attrs()), &handle));
2068             const FunctionBody* body = flr->GetFunctionBody(handle);
2069             Graph* g = body->graph;
2070 
2071             for (Node* body_node : g->nodes()) {
2072               if (!body_node->IsArg()) continue;
2073 
2074               int index;
2075               TF_RETURN_IF_ERROR(
2076                   GetNodeAttr(body_node->attrs(), "index", &index));
2077               if (index != call_edge->dst_input()) continue;
2078 
2079               for (const Edge* out_edge : body_node->out_edges()) {
2080                 if (out_edge->IsControlEdge()) continue;
2081 
2082                 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2083                     num_cores_per_replica, n.name(), *out_edge->dst(), &core,
2084                     is_fast_mem, &result));
2085 
2086                 TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
2087               }
2088             }
2089           }
2090           return OkStatus();
2091         };
2092     TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
2093   }
2094   *output_node_and_sharding = result;
2095   return OkStatus();
2096 }
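// Illustrative example (a hypothetical call chain): if an argument feeds a
// function call whose body consumes the corresponding _Arg via an XlaSharding
// op, parse_sharding_from_function above instantiates that function, finds the
// _Arg with the matching index, and harvests the sharding from its users just
// as it would for a direct neighbor in the outer graph.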
2097 
2098 bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
2099   bool spmd_attr;
2100   if (!replicate_node ||
2101       !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
2102                       &spmd_attr)) {
2103     spmd_attr = false;
2104   }
2105   return spmd_attr;
2106 }
2107 
2108 std::string FormatNodeAndShardingMsg(
2109     const absl::optional<NodeAndSharding>& node_and_sharding) {
2110   DCHECK(node_and_sharding.has_value());
2111 
2112   xla::OpSharding sharding_no_metadata = node_and_sharding->sharding;
2113   sharding_no_metadata.clear_metadata();
2114   std::string escaped_sharding_str =
2115       absl::CEscape(sharding_no_metadata.SerializeAsString());
2116   if (node_and_sharding->node == nullptr) {
2117     return absl::StrCat(" via default sharding '", escaped_sharding_str, "'");
2118   }
2119 
2120   return absl::StrCat(" via node ", node_and_sharding->node->DebugString(),
2121                       " sharding '", escaped_sharding_str, "'");
2122 }
2123 
2124 Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
2125     int num_cores_per_replica, const ParameterInfo& params_info,
2126     const DataTypeVector& arg_types,
2127     const std::vector<InferredShape>& arg_shapes,
2128     const DataTypeVector& retval_types,
2129     const std::vector<InferredShape>& retval_shapes, const Graph& graph,
2130     const Node* replicate_node, FunctionLibraryRuntime* flr,
2131     bool allow_parameter_replication_for_spmd,
2132     std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
2133     std::vector<xla::OpSharding>* retval_sharding,
2134     std::vector<std::string>* arg_names) {
2135   // Builds vectors of the argument and return nodes.
2136   std::vector<Node*> args(arg_types.size());
2137   std::vector<Node*> retvals(retval_types.size());
2138   absl::flat_hash_map<int, Node*> partitioned_output_nodes;
2139   for (Node* node : graph.op_nodes()) {
2140     if (node->IsArg()) {
2141       int index;
2142       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2143       TF_RET_CHECK(index >= 0 && index < args.size());
2144       args[index] = node;
2145     } else if (node->IsRetval()) {
2146       int index;
2147       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2148       TF_RET_CHECK(index >= 0 && index < retvals.size());
2149       retvals[index] = node;
2150     }
2151   }
2152   for (const Edge* edge : replicate_node->out_edges()) {
2153     int num_partitioned_outputs = 0;
2154     for (const Edge* out_edge : edge->dst()->out_edges()) {
2155       if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
2156         partitioned_output_nodes[edge->src_output()] = out_edge->dst();
2157         num_partitioned_outputs++;
2158       }
2159     }
2160     if (num_partitioned_outputs > 1) {
2161       return errors::InvalidArgument(
2162           "More than one TPUPartitionedOutput per replicated output.");
2163     }
2164   }
2165 
2166   // Verifies there are no missing arguments/return values.
2167   for (int i = 0; i < args.size(); ++i) {
2168     if (args[i] == nullptr) {
2169       return errors::Internal("Missing function argument: ", i);
2170     }
2171   }
2172   for (int i = 0; i < retvals.size(); ++i) {
2173     if (retvals[i] == nullptr) {
2174       return errors::Internal("Missing function return value: ", i);
2175     }
2176   }
2177 
2178   // Assigns a core to each _Arg. Chooses the lowest-numbered core that
2179   // consumes the argument. We choose the lowest-numbered core so the
2180   // assignment is deterministic.
2181   TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
2182                                           arg_shapes);
2183   arg_sharding->resize(args.size());
2184   arg_names->resize(args.size());
2185   arg_fast_mem->resize(args.size());
2186   CachedFunctionHandles cached_function_handles(flr);
2187   const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
2188                          replicate_inputs_outputs_by_default_for_xla_spmd_) &&
2189                         allow_parameter_replication_for_spmd &&
2190                         num_cores_per_replica > 1;
2191 
2192   // Offset _TPUReplicate non-per-replica argument indices by
2193   // (num_replicas - 1) * num_per_replica_args, as _TPUReplicate nodes are
2194   // constructed with all per-replica args across all replicas while the
2195   // encapsulated function only has one replica's per-replica args. Per-replica
2196   // args are ordered by replica first, so the index here does not require an
2197   // offset, and the first replica's input nodes are sufficient for determining
2198   // argument sharding.
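  // Illustrative example (the counts are hypothetical): with 2 replicas and 3
  // per-replica args, the _TPUReplicate inputs are laid out as
  //   [r0a0, r0a1, r0a2, r1a0, r1a1, r1a2, <non-per-replica inputs>...],
  // so a non-per-replica argument at function index i maps to replicate-node
  // input i + (2 - 1) * 3 = i + 3.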
2199   const int index_offset =
2200       (params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
2201   for (int i = 0; i < args.size(); ++i) {
2202     const Node* n = args[i];
2203     absl::optional<int64_t> assigned_core;
2204     absl::optional<NodeAndSharding> node_and_sharding;
2205     bool is_fast_mem;
2206     TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
2207         *n, num_cores_per_replica, flr, &cached_function_handles,
2208         &node_and_sharding, &is_fast_mem));
2209 
2210     const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
2211     if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
2212       Node* input_node;
2213       TF_RETURN_IF_ERROR(replicate_node->input_node(
2214           i + (is_per_replica_arg ? 0 : index_offset), &input_node));
2215       if (input_node->type_string() == kTPUPartitionedInput) {
2216         TF_ASSIGN_OR_RETURN(
2217             absl::optional<xla::OpSharding> parsed_sharding,
2218             GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2219         if (!parsed_sharding.has_value())
2220           return errors::InvalidArgument("Missing _XlaSharding attr from: ",
2221                                          input_node->DebugString());
2222         node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2223         VLOG(1) << "Arg " << i << " parsed sharding information from "
2224                 << input_node->DebugString() << " : "
2225                 << parsed_sharding->DebugString();
2226       }
2227     }
2228 
2229     if (params_info.IsVariableArg(i)) {
2230       Node* input_node;
2231       TF_RETURN_IF_ERROR(
2232           replicate_node->input_node(i + index_offset, &input_node));
2233       if (input_node->type_string() == kVarHandleOp) {
2234         TF_ASSIGN_OR_RETURN(
2235             absl::optional<xla::OpSharding> parsed_sharding,
2236             GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2237         if (parsed_sharding.has_value()) {
2238           node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2239           VLOG(1) << "Arg " << i << " parsed sharding information from "
2240                   << input_node->DebugString() << " : "
2241                   << parsed_sharding->DebugString();
2242         }
2243       }
2244     }
2245 
2246     if (node_and_sharding.has_value() && enable_automatic_model_parallelism_) {
2247       return tensorflow::errors::InvalidArgument(
2248           "Specifying manual sharding is not allowed when automatic "
2249           "model parallelism is enabled.",
2250           node_and_sharding->sharding.DebugString());
2251     }
2252 
2253     if (!node_and_sharding.has_value()) {
2254       if (use_spmd &&
2255           (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2256            ((params_info.IsPerReplicaArg(i) ||
2257              params_info.IsDistributedArg(i)) &&
2258             arg_types[i] != DT_RESOURCE) ||
2259            params_info.IsConstantArg(i))) {
2260         // Use replication for host variables or non-variable per-replica
2261         // inputs.
2262         node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2263                                             xla::sharding_builder::Replicate());
2264       } else {
2265         // TODO(dlibenzi): Distributing variables to cores other than 0 makes
2266         // learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
2267         // For now distribute only per replica arguments, unless
2268         // tf_jf_distribute_vars is set, to allow debugging the issue.
2269         if (((params_info.IsPerReplicaArg(i) ||
2270               params_info.IsDistributedArg(i)) &&
2271              arg_types[i] != DT_RESOURCE) ||
2272             (distribute_vars_ && params_info.IsVariableArg(i))) {
2273           assigned_core = args_device_selector.RetrieveAssignment(i);
2274         } else {
2275           assigned_core = 0;
2276         }
2277         node_and_sharding = NodeAndSharding(
2278             /*node=*/nullptr,
2279             xla::sharding_builder::AssignDevice(*assigned_core));
2280       }
2281       *node_and_sharding->sharding.add_metadata() =
2282           CreateOpMetadataFromNode(*replicate_node);
2283     } else if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2284       if (use_spmd) {
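        // Under SPMD partitioning a maximal (single-core) sharding cannot be
        // honored, so fall back to replicating the argument on every core.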
2285         node_and_sharding->sharding = xla::sharding_builder::Replicate();
2286       } else {
2287         assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2288       }
2289     } else if (node_and_sharding->sharding.type() !=
2290                    xla::OpSharding::REPLICATED &&
2291                node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2292       return tensorflow::errors::InvalidArgument(
2293           "Unsupported argument sharding (for arg ", n->DebugString(),
2294           "): ", node_and_sharding->sharding.DebugString());
2295     }
2296     if (assigned_core.has_value()) {
2297       args_device_selector.ReportDeviceAssigned(*assigned_core, i);
2298       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2299               << ") to core " << *assigned_core
2300               << FormatNodeAndShardingMsg(node_and_sharding);
2301       args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2302     } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2303       for (int64_t core :
2304            node_and_sharding->sharding.tile_assignment_devices()) {
2305         TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2306             << "core " << core << " should be between [0, "
2307             << num_cores_per_replica << "). sharding is "
2308             << node_and_sharding->sharding.DebugString();
2309         args_device_selector.ReportDeviceAssigned(core, i);
2310       }
2311       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2312               << ") with tiled sharding to cores "
2313               << absl::StrJoin(
2314                      node_and_sharding->sharding.tile_assignment_devices(), ",")
2315               << " " << FormatNodeAndShardingMsg(node_and_sharding);
2316     } else {
2317       DCHECK_EQ(node_and_sharding->sharding.type(),
2318                 xla::OpSharding::REPLICATED);
2319       for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2320         args_device_selector.ReportDeviceAssigned(core, i);
2321       }
2322       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2323               << ") to all cores"
2324               << FormatNodeAndShardingMsg(node_and_sharding);
2325     }
2326     (*arg_sharding)[i] = node_and_sharding->sharding;
2327     (*arg_fast_mem)[i] = is_fast_mem;
2328     (*arg_names)[i] = n->name();
2329     if (is_fast_mem) {
2330       VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
2331               << args[i]->name();
2332     }
2333     args[i]->AddAttr(kShardingAttribute,
2334                      node_and_sharding->sharding.SerializeAsString());
2335   }
2336   TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
2337 
2338   // Assigns each _Retval node to the core that produces its value.
2339   TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
2340                                              retval_types, retval_shapes);
2341   retval_sharding->resize(retvals.size());
2342   for (int i = 0; i < retvals.size(); ++i) {
2343     const Edge* edge;
2344     TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
2345 
2346     TF_ASSIGN_OR_RETURN(
2347         absl::optional<xla::OpSharding> edge_sharding,
2348         ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
2349                                     /*add_metadata=*/true));
2350 
2351     absl::optional<NodeAndSharding> node_and_sharding;
2352     if (edge_sharding.has_value()) {
2353       node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding));
2354     }
2355 
2356     if (partitioned_output_nodes.contains(i)) {
2357       Node* output_node = partitioned_output_nodes[i];
2358       TF_ASSIGN_OR_RETURN(
2359           absl::optional<xla::OpSharding> parsed_sharding,
2360           GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
2361       if (parsed_sharding.has_value()) {
2362         node_and_sharding = NodeAndSharding(output_node, *parsed_sharding);
2363         VLOG(1) << "Retval " << i << " parsed sharding information from "
2364                 << output_node->DebugString() << " : "
2365                 << parsed_sharding->DebugString();
2366       }
2367     }
2368     absl::optional<int64_t> assigned_core;
2369     if (node_and_sharding.has_value()) {
2370       if (enable_automatic_model_parallelism_) {
2371         return tensorflow::errors::InvalidArgument(
2372             "Specifying manual sharding is not allowed when automatic "
2373             "model parallelism is enabled.",
2374             node_and_sharding->sharding.DebugString());
2375       }
2376 
2377       if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2378         if (use_spmd) {
2379           node_and_sharding->sharding = xla::sharding_builder::Replicate();
2380         } else {
2381           assigned_core =
2382               node_and_sharding->sharding.tile_assignment_devices(0);
2383           TF_RETURN_IF_ERROR(
2384               ValidateCoreNumber(*assigned_core, num_cores_per_replica));
2385         }
2386       } else if (node_and_sharding->sharding.type() !=
2387                      xla::OpSharding::REPLICATED &&
2388                  node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2389         return tensorflow::errors::InvalidArgument(
2390             "Unsupported argument sharding for retval ",
2391             retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
2392             node_and_sharding->sharding.DebugString());
2393       }
2394     } else {
2395       if (use_spmd) {
2396         node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2397                                             xla::sharding_builder::Replicate());
2398       } else {
2399         if (distribute_vars_) {
2400           assigned_core = retvals_device_selector.RetrieveAssignment(i);
2401         } else {
2402           assigned_core = 0;
2403         }
2404         node_and_sharding = NodeAndSharding(
2405             /*node=*/nullptr,
2406             xla::sharding_builder::AssignDevice(*assigned_core));
2407       }
2408       *node_and_sharding->sharding.add_metadata() =
2409           CreateOpMetadataFromNode(*replicate_node);
2410     }
2411     if (assigned_core.has_value() && !use_spmd) {
2412       retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2413       retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
2414       VLOG(3) << "Assigning return value " << i << " ("
2415               << retvals[i]->DebugString() << ") to core " << *assigned_core
2416               << FormatNodeAndShardingMsg(node_and_sharding);
2417     } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2418       for (int64_t core :
2419            node_and_sharding->sharding.tile_assignment_devices()) {
2420         TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2421             << "core " << core << " should be between [0, "
2422             << num_cores_per_replica << "). sharding is "
2423             << node_and_sharding->sharding.DebugString();
2424         retvals_device_selector.ReportDeviceAssigned(core, i);
2425       }
2426       VLOG(3) << "Assigning return value " << i << " ("
2427               << retvals[i]->DebugString() << ") with tiled sharding to cores "
2428               << absl::StrJoin(
2429                      node_and_sharding->sharding.tile_assignment_devices(), ",")
2430               << " " << FormatNodeAndShardingMsg(node_and_sharding);
2431     } else {
2432       if (use_spmd) {
2433         node_and_sharding->sharding = xla::sharding_builder::Replicate();
2434       }
2435       for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2436         retvals_device_selector.ReportDeviceAssigned(core, i);
2437       }
2438       VLOG(3) << "Assigning return value " << i << " ("
2439               << retvals[i]->DebugString() << ") to all cores"
2440               << FormatNodeAndShardingMsg(node_and_sharding);
2441     }
2442     retvals[i]->AddAttr(kShardingAttribute,
2443                         node_and_sharding->sharding.SerializeAsString());
2444     (*retval_sharding)[i] = node_and_sharding->sharding;
2445   }
2446   if (use_spmd &&
2447       (absl::c_any_of(*arg_sharding,
2448                       [](const xla::OpSharding& s) {
2449                         return s.type() == xla::OpSharding::MAXIMAL;
2450                       }) ||
2451        absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) {
2452          return s.type() == xla::OpSharding::MAXIMAL;
2453        }))) {
2454     return tensorflow::errors::InvalidArgument(
2455         "XLA SPMD only supports cases where all inputs/outputs "
2456         "exist on every partition (sharded or replicated).");
2457   }
2458   return OkStatus();
2459 }
2460 
2461 // Builds Shape nodes that compute the shapes of arguments whose shapes are not
2462 // statically known.
2463 /* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
2464     const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
2465     const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
2466     Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
2467   dynamic_shape_nodes->clear();
2468 
2469   std::vector<const Edge*> replicate_input_edges;
2470   TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2471 
2472   // The compiler determines the shape of each constant by inspecting the value
2473   // of its corresponding host-memory tensor; this happens when a step is run.
2474   // As a result, the shapes of constants are not needed at graph rewrite time.
2475   const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
2476   TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
2477                                params_info.NumDistributedArgs() +
2478                                params_info.NumBroadcastArgs() +
2479                                params_info.NumVariables());
2480 
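  // Arguments are laid out as per-replica args, then distributed args, then
  // broadcast args, then variables; guaranteed constants follow at the end and
  // are excluded from num_args above.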
2481   for (int i = 0; i < num_args; ++i) {
2482     const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
2483                                           ? &arg_shapes[i].shape
2484                                           : &arg_shapes[i].handle_shape;
2485     if (!shape->IsFullyDefined()) {
2486       NodeDef def;
2487       Node* src;
2488       int src_output;
2489       std::vector<Node*> control_inputs;
2490 
2491       if (params_info.IsVariableArg(i)) {
2492         int64_t var_num = i - params_info.NumPerReplicaArgs() -
2493                           params_info.NumDistributedArgs() -
2494                           params_info.NumBroadcastArgs();
2495         TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
2496         Node* read = variable_reads[var_num];
2497 
2498         DCHECK_EQ(read->type_string(), "ReadVariableOp");
2499 
2500         for (const Edge* edge : read->in_edges()) {
2501           if (edge->IsControlEdge()) {
2502             control_inputs.push_back(edge->src());
2503           }
2504         }
2505 
2506         const Edge* variable_input = nullptr;
2507         TF_RETURN_IF_ERROR(read->input_edge(/*idx=*/0, &variable_input));
2508         src = variable_input->src();
2509         src_output = variable_input->src_output();
2510 
2511         def.set_name(
2512             graph->NewName(strings::StrCat(src->name(), "/variable_shape")));
2513         def.set_op("VariableShape");
2514       } else {
2515         if (params_info.IsPerReplicaArg(i)) {
2516           TF_RET_CHECK(i < replicate_input_edges.size());
2517           // All replicas must have the same input shapes, so use the shape of
2518           // the input from the first replica.
2519           src = replicate_input_edges[i]->src();
2520           src_output = replicate_input_edges[i]->src_output();
2521         } else {
2522           DCHECK(params_info.IsDistributedArg(i) ||
2523                  params_info.IsBroadcastArg(i));
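          // The replicate node lists every replica's copy of the per-replica
          // args first, followed by a single instance of each distributed or
          // broadcast arg, hence the index arithmetic below.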
2524           int64_t input_num =
2525               params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
2526               params_info.NumPerReplicaArgs();
2527           TF_RET_CHECK(0 <= input_num &&
2528                        input_num < replicate_input_edges.size());
2529           src = replicate_input_edges[input_num]->src();
2530           src_output = replicate_input_edges[input_num]->src_output();
2531         }
2532 
2533         def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
2534         def.set_op("Shape");
2535         AddNodeAttr("T", src->output_type(src_output), &def);
2536       }
2537 
2538       def.set_device(src->assigned_device_name());
2539       AddNodeAttr("out_type", DT_INT64, &def);
2540       MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
2541 
2542       TF_ASSIGN_OR_RETURN(Node * shape_node, graph->AddNode(def));
2543       dynamic_shape_nodes->push_back(shape_node);
2544 
2545       shape_node->set_assigned_device_name(src->assigned_device_name());
2546       graph->AddEdge(src, src_output, shape_node, 0);
2547       for (Node* control_input : control_inputs) {
2548         graph->AddControlEdge(control_input, shape_node);
2549       }
2550     }
2551   }
2552   return OkStatus();
2553 }
2554 
2555 namespace {
2556 
2557 bool XlaBroadcastTypeSupported(const DataType dtype) {
2558   return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
2559           dtype == DT_BOOL);
2560 }
2561 
2562 bool XlaBroadcastKindSupported(
2563     const DistributedTPURewritePass::ParameterInfo& params_info,
2564     int param_num) {
2565   // NOTE: This is intended to cover non-sharded data parallel variables, for
2566   // training only. Is it correct to just check if the arg_type is
2567   // DT_RESOURCE?
2568   return params_info.IsVariableArg(param_num) &&
2569          !(params_info.IsPerReplicaArg(param_num) ||
2570            params_info.IsDistributedArg(param_num) ||
2571            params_info.IsBroadcastArg(param_num) ||
2572            params_info.IsConstantArg(param_num));
2573 }
2574 
2575 bool EnableXlaParamBroadcast(
2576     bool enable_xla_param_broadcast, bool mpmd,
2577     const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
2578     DataType dtype) {
2579   // Conditions necessary to use XLA collectives for arg broadcast:
2580   // 1. Globally enabled via enable_xla_param_broadcast.
2581   // 2. DataType must be supported.
2582   // 3. Parameter must be a variable, and not distributed or broadcasted.
2583   // 4. For multi-core models (num_cores_per_replica > 1), must use SPMD.
2584   return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
2585          XlaBroadcastKindSupported(params_info, param_num) && !mpmd;
2586 }
2587 
2588 }  // namespace
2589 
2590 // Builds a TPUCompile node that compiles the bodies of the function call
2591 // `nodes`.
2592 Status DistributedTPURewritePass::BuildCompileNode(
2593     const Node* replicate_node, const NameAttrList& function,
2594     uint64 library_fingerprint, const ParameterInfo& params_info,
2595     const std::vector<InferredShape>& arg_shapes,
2596     const DataTypeVector& arg_types,
2597     const std::vector<Node*>& guaranteed_constant_nodes,
2598     const string& session_handle,
2599     const std::vector<xla::OpSharding>& arg_sharding,
2600     const std::vector<bool>& arg_fast_mem,
2601     const std::vector<std::string>& arg_names,
2602     const std::vector<xla::OpSharding>& retval_sharding,
2603     int num_cores_per_replica, const string& compile_device,
2604     const xla::DeviceAssignment* xla_device_assignment,
2605     const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
2606     Node** compile_node, int64_t autotuner_thresh) {
2607   VLOG(1) << "BuildCompileNode";
2608 
2609   tpu::TPUCompileMetadataProto proto;
2610   if (replicate_node) {
2611     std::string str;
2612     TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node->attrs(),
2613                                    "tpu_compile_options_proto", &str));
2614     TF_RET_CHECK(proto.mutable_compile_options()->ParseFromString(str));
2615   }
2616   proto.set_num_replicas(params_info.NumReplicas());
2617   proto.set_num_cores_per_replica(num_cores_per_replica);
2618   proto.set_function_library_fingerprint(library_fingerprint);
2619   proto.set_enable_automatic_model_parallelism(
2620       enable_cross_replica_sharding_mirrored_variables_);
2621   const bool use_spmd =
2622       UseSpmdForXlaPartitioning(replicate_node) && allow_xla_spmd_partition_;
2623   proto.set_use_spmd_for_xla_partitioning(use_spmd);
2624   const bool mpmd = (num_cores_per_replica > 1) && !use_spmd;
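  // MPMD here means each core of a replica runs its own program (manual model
  // parallelism); with SPMD a single partitioned program is shared by all
  // cores of the replica.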
2625 
2626   // Set the step marker location, if one is specified on the replicate node.
2627   if (replicate_node != nullptr) {
2628     xla::DebugOptions::StepMarkerLocation location;
2629     TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
2630     proto.set_step_marker_location(location);
2631   }
2632 
2633   if (xla_device_assignment != nullptr) {
2634     TF_RETURN_IF_ERROR(
2635         xla_device_assignment->Serialize(proto.mutable_device_assignment()));
2636   }
2637 
2638   const int num_args = arg_types.size();
2639   const int num_guaranteed_constants = guaranteed_constant_nodes.size();
2640   const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
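  // Guaranteed constants occupy the tail of the argument list, so any index at
  // or beyond guaranteed_const_start_index refers to a guaranteed constant.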
2641   TF_RET_CHECK(num_args == arg_shapes.size());
2642   TF_RET_CHECK(num_args == arg_sharding.size())
2643       << num_args << " != " << arg_sharding.size();
2644 
2645   for (int i = 0; i < num_args; ++i) {
2646     tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
2647     DataType type = arg_types[i];
2648     const InferredShape& arg_shape = arg_shapes[i];
2649     arg->set_name(arg_names[i]);
2650     if (type == DT_RESOURCE) {
2651       TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
2652       arg->set_dtype(arg_shape.handle_type);
2653       arg_shape.handle_shape.AsProto(arg->mutable_shape());
2654       arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
2655       arg->set_fast_mem(arg_fast_mem[i]);
2656     } else {
2657       arg->set_dtype(type);
2658       arg_shape.shape.AsProto(arg->mutable_shape());
2659       if (i >= guaranteed_const_start_index) {
2660         const DataType edge_type =
2661             guaranteed_constant_nodes[i - guaranteed_const_start_index]
2662                 ->output_type(0);
2663         TF_RET_CHECK(type == edge_type)
2664             << "Arg type: " << type << " but edge type: " << edge_type;
2665         arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
2666       } else {
2667         arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
2668       }
2669     }
2670 
2671     // Use XLA collective primitives to distribute variables to all replicas.
2672     arg->set_requires_xla_broadcast(
2673         params_info.NumReplicas() > 1 &&
2674         EnableXlaParamBroadcast(enable_xla_param_broadcast_, mpmd, params_info,
2675                                 i, arg_shape.handle_type /*arg.dtype?*/));
2676 
2677     // As long as the argument is not a per-replica or distributed one, it should
2678     // have the same value for all replicas. For clarity, we keep the (redundant)
2679     // checks for variable, broadcast and constant types, to prevent bugs in case
2680     // new types with different semantics are introduced in the future.
2681     arg->set_is_same_data_across_replicas(
2682         !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
2683         (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2684          params_info.IsConstantArg(i)));
2685     if (params_info.mirrored_variable_indices().count(i) > 0) {
2686       TF_RET_CHECK(type == DT_RESOURCE)
2687           << "Arg type: " << type << " name: " << arg->name()
2688           << " shape: " << arg->shape().DebugString();
2689       arg->set_is_same_data_across_replicas(true);
2690       // 64-bit type is not shardable by XLA:TPU yet.
2691       bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
2692                                arg_shape.handle_type != DT_INT64 &&
2693                                arg_shape.handle_type != DT_UINT64 &&
2694                                arg_shape.handle_type != DT_DOUBLE);
2695       arg->set_enable_xla_sharding(
2696           sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
2697                            : tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
2698     }
2699     *arg->mutable_sharding() = arg_sharding[i];
2700   }
2701 
2702   const int num_retvals = retval_sharding.size();
2703   for (int i = 0; i < num_retvals; ++i) {
2704     *proto.add_retvals()->mutable_sharding() = retval_sharding[i];
2705   }
2706   proto.set_session_handle(session_handle);
2707 
2708   DataTypeVector constant_arg_types;
2709   constant_arg_types.reserve(num_guaranteed_constants);
2710   for (int i = 0; i < num_guaranteed_constants; ++i) {
2711     constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
2712   }
2713   proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
2714 
2715   string metadata;
2716   proto.SerializeToString(&metadata);
2717 
2718   NodeDef def;
2719   def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
2720   def.set_op("TPUCompile");
2721   def.set_device(compile_device);
2722   if (replicate_node) {
2723     MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
2724   }
2725 
2726   AddNodeAttr("function", function, &def);
2727   AddNodeAttr("num_computations", num_cores_per_replica, &def);
2728   AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
2729               &def);
2730   AddNodeAttr("metadata", metadata, &def);
2731   AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
2732 
2733   TF_ASSIGN_OR_RETURN(*compile_node, graph->AddNode(def));
2734 
2735   (*compile_node)->set_assigned_device_name(compile_device);
2736 
2737   for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
2738     graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
2739   }
2740 
2741   for (int i = 0; i < num_guaranteed_constants; ++i) {
2742     graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
2743                    dynamic_shape_nodes.size() + i);
2744   }
2745   VLOG(1) << "BuildCompileNode()";
2746   return OkStatus();
2747 }
2748 
2749 Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
2750     const Node& node, const NameRangeMap& input_range_map,
2751     std::vector<Node*>* guaranteed_constants) {
2752   std::vector<const Edge*> input_edges;
2753   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2754   std::pair<int, int> variables_limits =
2755       input_range_map.at("guaranteed_constants");
2756   for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2757     guaranteed_constants->push_back(input_edges[i]->src());
2758   }
2759   return OkStatus();
2760 }
2761 
2762 Status DistributedTPURewritePass::FindVariableInputs(
2763     const Node& node, const NameRangeMap& input_range_map,
2764     std::vector<VariableInput>* variables) {
2765   std::vector<const Edge*> input_edges;
2766   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2767   std::pair<int, int> variables_limits = input_range_map.at("variables");
2768   for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2769     Node* node = input_edges[i]->src();
2770 
2771     // Find the type of the VarHandleOp that feeds this node, looking through
2772     // any wrapping Enter or Switch nodes.
2773     while (node->IsEnter() || node->IsSwitch()) {
2774       TF_RETURN_IF_ERROR(node->input_node(0, &node));
2775     }
2776     // Fix the variable device assignment if it is requested with a full name.
2777     if (!node->has_assigned_device_name() &&
2778         !node->requested_device().empty()) {
2779       DeviceNameUtils::ParsedName var_device;
2780       TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
2781                                                   &var_device));
2782       if (var_device.has_job && var_device.has_replica && var_device.has_task &&
2783           var_device.has_type && var_device.has_id) {
2784         node->set_assigned_device_name(node->requested_device());
2785         if (node != input_edges[i]->src() &&
2786             !input_edges[i]->src()->has_assigned_device_name()) {
2787           input_edges[i]->src()->set_assigned_device_name(
2788               node->requested_device());
2789         }
2790       }
2791     }
2792     if (node->type_string() == kVarHandleOp) {
2793       DataType dtype;
2794       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
2795       variables->push_back(VariableInput{input_edges[i]->src(),
2796                                          input_edges[i]->src_output(), dtype});
2797     } else if (node->type_string() == "_Arg") {
2798       std::vector<DataType> dtypes;
2799       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
2800       if (dtypes.empty()) {
2801         return errors::Internal(
2802             "_Arg node with resource output must have non-empty _handle_dtypes "
2803             "attribute: ",
2804             node->DebugString());
2805       }
2806       variables->push_back(VariableInput{
2807           input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
2808     } else {
2809       return errors::Internal(
2810           "Cannot handle variable input with node type other than VarHandleOp "
2811           "and _Arg: ",
2812           node->DebugString());
2813     }
2814   }
2815   return OkStatus();
2816 }
2817 
2818 // Builds a NoOp node, used for building control dependencies.
2819 static Status BuildNoopNode(const Node& source, StringPiece name,
2820                             const string& device, Graph* graph, Node** node) {
2821   NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
2822   if (!device.empty()) {
2823     builder.Device(device);
2824   }
2825   NodeDef def;
2826   TF_RETURN_IF_ERROR(builder.Finalize(&def));
2827 
2828   TF_ASSIGN_OR_RETURN(*node, graph->AddNode(def));
2829   if (!device.empty()) {
2830     (*node)->set_assigned_device_name(device);
2831   }
2832   return OkStatus();
2833 }
2834 
2835 Status DistributedTPURewritePass::ConnectHostComputeNodes(
2836     Node* compile_node, Node* key_placeholder_node, Graph* graph) {
2837   // First find all the downstream nodes of the key placeholder node, since we
2838   // are going to delete the connecting edges from key_placeholder_node, which
2839   // would invalidate the out_nodes iterator while we iterate over it.
2840   std::vector<Node*> host_transfer_nodes;
2841   for (Node* node : key_placeholder_node->out_nodes()) {
2842     host_transfer_nodes.push_back(node);
2843   }
2844   for (Node* node : host_transfer_nodes) {
2845     int input_index = -1;
2846     for (int i = 0; i < node->num_inputs(); i++) {
2847       const Edge* e;
2848       TF_RETURN_IF_ERROR(node->input_edge(i, &e));
2849       if (e->src() == key_placeholder_node) {
2850         if (input_index != -1) {
2851           return errors::Internal(
2852               "Node ", node->name(),
2853               " has multiple input edges from key placeholder node");
2854         }
2855         input_index = e->dst_input();
2856       }
2857     }
2858     if (input_index == -1) {
2859       return errors::Internal("Node ", node->name(),
2860                               " has no input edge from key placeholder node");
2861     }
2862     const Edge* key_edge;
2863     TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
2864     graph->RemoveEdge(key_edge);
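    // Rewire the host transfer node to take its key from output 1 of the
    // TPUCompile node in place of the placeholder.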
2865     graph->AddEdge(compile_node, 1, node, input_index);
2866   }
2867   graph->RemoveNode(key_placeholder_node);
2868   return OkStatus();
2869 }
2870 
2871 Status DistributedTPURewritePass::BuildVariableReads(
2872     absl::Span<const VariableInput> variables, Node* control_predecessor,
2873     Graph* graph, std::vector<Node*>* variable_reads) {
2874   variable_reads->resize(variables.size());
2875   for (int i = 0; i < variables.size(); ++i) {
2876     string name =
2877         graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
2878     NodeDefBuilder builder(name, "ReadVariableOp",
2879                            NodeDebugInfo(*variables[i].node));
2880 
2881     builder.Attr("dtype", variables[i].dtype);
2882     builder.Device(variables[i].node->assigned_device_name());
2883     builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
2884     NodeDef def;
2885     TF_RETURN_IF_ERROR(builder.Finalize(&def));
2886 
2887     TF_ASSIGN_OR_RETURN(Node * read_node, graph->AddNode(def));
2888     (*variable_reads)[i] = read_node;
2889 
2890     read_node->set_requested_device(variables[i].node->requested_device());
2891     read_node->set_assigned_device_name(
2892         variables[i].node->assigned_device_name());
2893     graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
2894 
2895     graph->AddControlEdge(control_predecessor, read_node);
2896   }
2897   return OkStatus();
2898 }
2899 
2900 bool DistributedTPURewritePass::ContainsResourceWriteOp(
2901     const Graph& graph, const FunctionLibraryDefinition& fld) {
2902   for (const Node* n : graph.nodes()) {
2903     const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
2904     if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2905       VLOG(2) << "Found write resource op inside computation";
2906       return true;
2907     }
2908   }
2909   for (const string& func_name : fld.ListFunctionNames()) {
2910     const FunctionDef* func_def = fld.Find(func_name);
2911     for (const NodeDef& n : func_def->node_def()) {
2912       const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
2913       if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2914         VLOG(2) << "Found write resource op inside " << func_name;
2915         return true;
2916       }
2917     }
2918   }
2919   return false;
2920 }
2921 
2922 Status DistributedTPURewritePass::BuildVariableWrites(
2923     absl::Span<const VariableInput> variables, Node* control_successor,
2924     absl::Span<const VariableWrite> variable_writes, Graph* graph) {
2925   CHECK_EQ(variables.size(), variable_writes.size());
2926   for (int i = 0; i < variables.size(); ++i) {
2927     const VariableWrite& write = variable_writes[i];
2928     NodeDebugInfo debug_info(*variables[i].node);
2929 
2930     auto name = [&](string suffix) {
2931       return graph->NewName(
2932           strings::StrCat(variables[i].node->name(), "/", suffix));
2933     };
2934 
2935     Node* write_node;
2936     TF_RETURN_IF_ERROR(
2937         IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
2938             .AddAttr("dtype", variables[i].dtype)
2939             .Device(variables[i].node->assigned_device_name())
2940             .Build(graph, &write_node));
2941 
2942     // Colocate the control flow with the variable.
2943     CondBuilder cb(variables[i].node->name(),
2944                    variables[i].node->assigned_device_name(), debug_info,
2945                    graph);
2946 
2947     // Inputs to conditional.
2948     Node* switch_val;
2949     TF_RETURN_IF_ERROR(
2950         cb.AddInput("switch_val", variables[i].dtype,
2951                     /*device=*/write.value->assigned_device_name(), debug_info,
2952                     &switch_val));
2953     Node* switch_var;
2954     TF_RETURN_IF_ERROR(
2955         cb.AddInput("switch_var", DT_RESOURCE,
2956                     /*device=*/variables[i].node->assigned_device_name(),
2957                     debug_info, &switch_var));
2958     // Conditionally write the value back.
2959     graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
2960     graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
2961     graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
2962     // Add a control edge from the write to the value that will be merged. There
2963     // is no output from the write, so this control edge ensures the write completes.
2964     graph->AddControlEdge(write_node, cb.switch_t());
2965 
2966     graph->AddControlEdge(cb.control_successor(), control_successor);
2967 
2968     graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
2969     graph->AddEdge(write.value, write.value_output, switch_val, 0);
2970   }
2971   return OkStatus();
2972 }
2973 
2974 namespace {
2975 
2976 // Computes the shape of the sharded tensor, modifying `shape` in place.
2977 Status ComputeShardedArgShapes(TensorShape* shape,
2978                                const xla::OpSharding& sharding) {
2979   if (sharding.type() != xla::OpSharding::OTHER) {
2980     return OkStatus();
2981   }
2982   if (!shape->IsFullyDefined()) {
2983     return errors::Internal(
2984         "Arg shape must be fully defined before sharded shape inference.");
2985   }
2986   int sharded_rank = sharding.tile_assignment_dimensions_size();
2987   if (sharding.replicate_on_last_tile_dim()) {
2988     sharded_rank--;
2989   }
2990   for (int dim_idx = 0; dim_idx < sharded_rank; ++dim_idx) {
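  // Each sharded dimension is divided by its tile count, rounding up; e.g. a
  // [10, 4] shape with tile_assignment_dimensions [2, 1] yields [5, 4].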
2991     auto sharded_dim = tensorflow::MathUtil::CeilOfRatio<int64_t>(
2992         shape->dim_size(dim_idx), sharding.tile_assignment_dimensions(dim_idx));
2993     shape->set_dim(dim_idx, sharded_dim);
2994   }
2995   if (sharded_rank != shape->dims()) {
2996     LOG(WARNING) << "Rank of sharded arg should match sharding spec.  Rank: "
2997                  << sharded_rank << ", tiled shape: " << shape->DebugString()
2998                  << ", sharding: " << sharding.DebugString();
2999   }
3000 
3001   return OkStatus();
3002 }
3003 
3004 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
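// The dummy is built from two Const nodes (the shape as an int32 tensor and a
// scalar zero of the requested dtype) feeding a Fill node, all placed on the
// given host CPU device.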
3005 xla::StatusOr<Node*> CreateTpuExecuteDummyArg(const TensorShape& var_shape,
3006                                               const DataType& dtype,
3007                                               const string& host_cpu_device,
3008                                               Node* var_read, int replica_id,
3009                                               Graph* graph) {
3010   Status status;
3011 
3012   // Const - shape_as_tensor
3013   const std::string name_prefix = strings::StrCat(
3014       var_read->name(), absl::StrFormat("/dummy_%d", replica_id));
3015   NodeDef shape_tensor_def;
3016   shape_tensor_def.set_op("Const");
3017   shape_tensor_def.set_name(graph->NewName(
3018       strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
3019   shape_tensor_def.set_device(host_cpu_device);
3020   AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
3021   TensorProto tensorshape_proto;
3022   tensorshape_proto.set_dtype(DT_INT32);
3023   for (int i = 0; i < var_shape.dims(); ++i) {
3024     tensorshape_proto.add_int_val(var_shape.dim_size(i));
3025   }
3026   TensorShape shape_shape({var_shape.dims()});
3027   shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
3028   AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
3029   TF_ASSIGN_OR_RETURN(Node * shape_as_tensor_node,
3030                       graph->AddNode(shape_tensor_def));
3031 
3032   // Const - initializer value
3033   NodeDef init_val_def;
3034   init_val_def.set_op("Const");
3035   init_val_def.set_name(graph->NewName(
3036       strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
3037   init_val_def.set_device(host_cpu_device);
3038   TensorProto tensor_proto;
3039   tensor_proto.set_dtype(dtype);
3040   if (dtype == DT_FLOAT) {
3041     tensor_proto.add_float_val(0.0f);
3042   } else if (dtype == DT_BFLOAT16) {
3043     tensor_proto.add_half_val(0);
3044   } else if (dtype == DT_INT32) {
3045     tensor_proto.add_int_val(0);
3046   } else if (dtype == DT_BOOL) {
3047     tensor_proto.add_bool_val(false);
3048   } else {
3049     return errors::Internal(
3050         "Unable to create zero-init dummy arg tensor for type ", dtype);
3051   }
3052   TensorShape scalar_shape({});
3053   scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
3054   AddNodeAttr("value", tensor_proto, &init_val_def);
3055   AddNodeAttr("dtype", dtype, &init_val_def);
3056   TF_ASSIGN_OR_RETURN(Node * init_val_node, graph->AddNode(init_val_def));
3057 
3058   // Fill node
3059   NodeDef fill_def;
3060   fill_def.set_op("Fill");
3061   fill_def.set_device(host_cpu_device);
3062   fill_def.set_name(
3063       graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
3064   AddNodeAttr("T", dtype, &fill_def);
3065   AddNodeAttr("index_type", DT_INT32, &fill_def);
3066   TF_ASSIGN_OR_RETURN(Node * fill_node, graph->AddNode(fill_def));
3067   graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
3068   graph->AddEdge(init_val_node, 0, fill_node, 1);
3069 
3070   return fill_node;
3071 }
3072 
3073 // Creates dummy inputs for partitioned variables that are using XLA broadcast
3074 // for inputs.
3075 Status CreatePartitionedDummyVarArgs(
3076     const xla::OpSharding& sharding, const int num_replicas,
3077     const int replica_id, const InferredShape& raw_shape, Node* orig_var_read,
3078     const int orig_arg_num, DataType dtype, const string& device, Graph* graph,
3079     const std::vector<std::vector<string>>& tpu_device_names,
3080     absl::btree_map<ShardedPerHostInputIndex, Node*>* per_host_index,
3081     std::map<ShardedInputIndex, ShardedInputInfo>*
3082         arg_index_to_sharded_input_map) {
3083   ShardedInputIndex input_index{replica_id, orig_arg_num};
3084   auto iter = arg_index_to_sharded_input_map->find(input_index);
3085   if (iter != arg_index_to_sharded_input_map->end()) {
3086     return OkStatus();
3087   }
3088   const int repeat = sharding.replicate_on_last_tile_dim()
3089                          ? *sharding.tile_assignment_dimensions().rbegin()
3090                          : 1;
3091   const int num_shards = sharding.tile_assignment_devices_size() / repeat;
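  // When replicate_on_last_tile_dim is set, the trailing tile dimension counts
  // replicas of each shard rather than additional shards; e.g. 8 assigned
  // devices with a trailing dimension of 2 means 4 distinct shards, each fed
  // to 2 cores.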
3092 
3093   TensorShape var_shape;
3094   if (!raw_shape.handle_shape.AsTensorShape(&var_shape) &&
3095       !raw_shape.shape.AsTensorShape(&var_shape)) {
3096     return errors::FailedPrecondition("Failed to read arg shape.");
3097   }
3098   TF_RETURN_IF_ERROR(ComputeShardedArgShapes(&var_shape, sharding));
3099 
3100   for (int replica = 1; replica < num_replicas; ++replica) {
3101     std::vector<NodeOut> sharded_inputs_list(
3102         sharding.tile_assignment_devices_size());
3103     for (int i = 0; i < num_shards; ++i) {
3104       for (int j = 0; j < repeat; ++j) {
3105         const int index = i * repeat + j;
3106         const int core = sharding.tile_assignment_devices(index);
3107         string host_device;
3108         TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3109             tpu_device_names[replica][core], &host_device));
3110         ShardedPerHostInputIndex idx{host_device, orig_arg_num};
3111         if (!per_host_index->contains(idx)) {
3112           TF_ASSIGN_OR_RETURN(
3113               auto dummy_node,
3114               CreateTpuExecuteDummyArg(var_shape, dtype, host_device,
3115                                        orig_var_read, replica, graph));
3116           (*per_host_index)[idx] = dummy_node;
3117         }
3118         sharded_inputs_list[core] = {(*per_host_index)[idx], /*index=*/0};
3119       }
3120     }
3121     ShardedInputInfo sharded_input_info{nullptr,
3122                                         std::move(sharded_inputs_list)};
3123     (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
3124         sharded_input_info;
3125   }
3126 
3127   return OkStatus();
3128 }
3129 
3130 // Helper that creates an IdentityN node containing all of the variable
3131 // values on CPU device 'device', except for those that will be split across
3132 // cores. (For split variables, this may cause additional cross-host data
3133 // transfers if more than one device shares the same variable partition on a
3134 // remote host.)
3135 //
3136 // A previous iteration of this code built one Identity node per TPU core per
3137 // variable, but this can rapidly become hundreds of thousands of nodes. This
3138 // formulation creates a single IdentityN node containing all of the variables
3139 // on each host. This may cause some unnecessary variable copies if only a
3140 // subset of hosts consume a given variable, but has the virtue of being
3141 // simple, and most models use pure replication where all cores want all the
3142 // variables.
3143 //
3144 // If enable_xla_param_broadcast is set to true, then per-host dummy
3145 // tensor args are created on all hosts except for the primary host. In this
3146 // scheme, the dummy args feed the IdentityN node on their local host. All
3147 // are zero-initialized.
3148 //
3149 // Returns the node and its output index to be consumed by TPUExecute for the
3150 // requested variable index.
3151 xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
3152     const string& host_cpu_device, int64_t var_index,
3153     const std::vector<Node*>& variable_reads,
3154     const DistributedTPURewritePass::ParameterInfo& params_info,
3155     const std::vector<xla::OpSharding>& arg_shardings,
3156     const Node& replicate_node, const bool enable_xla_param_broadcast,
3157     const bool mpmd, const int num_cores_per_replica, int replica_id,
3158     const std::vector<InferredShape>& arg_shapes,
3159     absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
3160     Graph* graph) {
3161   auto it = per_host_var_copies->find(host_cpu_device);
3162   if (it != per_host_var_copies->end()) {
3163     return it->second[var_index];
3164   }
3165 
3166   DataTypeVector dtypes;
3167   // Per-variable data source for TPUExecute.
3168   std::vector<NodeOut> index_mapping;
3169   index_mapping.reserve(variable_reads.size());
3170   dtypes.reserve(variable_reads.size());
3171   for (int64_t i = 0; i < variable_reads.size(); ++i) {
3172     Node* read = variable_reads[i];
3173     int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3174                            params_info.NumDistributedArgs() +
3175                            params_info.NumBroadcastArgs();
3176     if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
3177       // We haven't built the IdentityN node yet, so temporarily use nullptr.
3178       index_mapping.push_back(
3179           NodeOut{nullptr, static_cast<int>(dtypes.size())});
3180       dtypes.push_back(read->output_type(0));
3181     } else {
3182       // Do not copy the full tensor of partitioned variables.
3183       index_mapping.push_back(NodeOut{read, 0});
3184     }
3185   }
3186   NodeDef ndef;
3187   ndef.set_name(graph->NewName(
3188       absl::StrCat(replicate_node.name(), "/", kTpuExecuteStagingNodeName)));
3189   ndef.set_op(kTpuExecuteStagingOp);
3190   ndef.set_device(host_cpu_device);
3191   AddNodeAttr("T", dtypes, &ndef);
3192   // TF meta-optimizer should skip this node for constant folding.
3193   AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
3194   TF_ASSIGN_OR_RETURN(Node * id_node, graph->AddNode(ndef));
3195   id_node->set_assigned_device_name(host_cpu_device);
3196 
3197   for (int64_t i = 0; i < variable_reads.size(); ++i) {
3198     Node* read = variable_reads[i];
3199     int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3200                            params_info.NumDistributedArgs() +
3201                            params_info.NumBroadcastArgs();
3202     DataType dtype = read->output_type(0);
3203     bool use_xla_broadcast =
3204         EnableXlaParamBroadcast(enable_xla_param_broadcast, mpmd, params_info,
3205                                 orig_arg_num, dtype) &&
3206         replica_id != 0;
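    // Replica 0 always consumes the real variable values; only the remaining
    // replicas are fed zero-valued dummies when XLA broadcast is enabled.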
3207     if (index_mapping[i].node == nullptr) {
3208       // Fill index_mapping with the actual IdentityN node.
3209       index_mapping[i].node = id_node;
3210       if (!use_xla_broadcast) {
3211         // Add the variable read edge to id_node.
3212         graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
3213       } else {
3214         // XLA param broadcast mode is enabled.  Create zero-valued dummy
3215         // tensors to use as variable args in the TPUExecuteOp, instead of
3216         // original variable reads.
3217         TensorShape var_shape;
3218         auto inferred_shape = arg_shapes[orig_arg_num];
3219         if (!inferred_shape.handle_shape.AsTensorShape(&var_shape) &&
3220             !inferred_shape.shape.AsTensorShape(&var_shape)) {
3221           return errors::FailedPrecondition("Failed to read arg shape.");
3222         }
3223         TF_ASSIGN_OR_RETURN(
3224             Node * dummy_read,
3225             CreateTpuExecuteDummyArg(var_shape, dtype, host_cpu_device,
3226                                      variable_reads[i], replica_id, graph));
3227         graph->AddEdge(dummy_read, 0, id_node, index_mapping[i].index);
3228       }
3229     }
3230   }
3231 
3232   auto result = index_mapping[var_index];
3233   (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
3234   return result;
3235 }
3236 
3237 }  // namespace
3238 
3239 Status DistributedTPURewritePass::BuildExecuteNodes(
3240     const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
3241     const Node& replicate_node, const std::vector<std::string>& arg_names,
3242     const DataTypeVector& arg_types,
3243     const std::vector<InferredShape>& arg_shapes,
3244     const DataTypeVector& retval_types,
3245     const std::vector<xla::OpSharding>& arg_shardings,
3246     const std::vector<xla::OpSharding>& retval_shardings,
3247     const std::vector<std::vector<string>>& tpu_device_names,
3248     Node* compile_node, const std::vector<Node*>& variable_reads,
3249     Node* control_predecessor, Node* control_successor, Node* multilock_acquire,
3250     std::vector<VariableWrite>* variable_writes, Graph* graph) {
3251   VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
3252   TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
3253 
3254   const int num_variables = variable_reads.size();
3255   const int num_retvals_per_replica = retval_types.size();
3256 
3257   variable_writes->resize(num_variables);
3258 
3259   std::vector<const Edge*> replicate_input_edges;
3260   TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
3261 
3262   // Map from replicate input index to the fan-in nodes (one per core).
3263   absl::flat_hash_map<int, std::vector<NodeAndPort>>
3264       replicate_input_fan_in_nodes;
3265   absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
3266   absl::flat_hash_map<int, std::vector<int>>
3267       replicate_output_fan_out_dst_inputs;
3268   std::vector<Node*> to_be_removed_nodes;
3269 
3270   const bool use_spmd =
3271       UseSpmdForXlaPartitioning(&replicate_node) && allow_xla_spmd_partition_;
3272   const bool mpmd = (num_cores_per_replica > 1) && !use_spmd;
3273 
3274   for (const Edge* e : replicate_input_edges) {
3275     if (e->src()->type_string() == kTPUPartitionedInput) {
3276       int num_users = 0;
3277       for (const auto& ue : e->src()->out_edges()) {
3278         if (!ue->IsControlEdge()) ++num_users;
3279       }
3280       if (num_users != 1) {
3281         return tensorflow::errors::InvalidArgument(
3282             e->src()->name(), " must only have one user. Found ", num_users);
3283       }
3284       to_be_removed_nodes.push_back(e->src());
3285       std::vector<NodeAndPort>& nodes =
3286           replicate_input_fan_in_nodes[e->dst_input()];
3287       nodes.resize(num_cores_per_replica, NodeAndPort(nullptr, 0));
3288       VLOG(2) << "allocate " << num_cores_per_replica
3289               << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
3290       std::vector<const Edge*> fan_in_edges;
3291       TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
3292       TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
3293 
3294       for (const Edge* fe : fan_in_edges) {
3295         nodes[fe->dst_input()].node = fe->src();
3296         nodes[fe->dst_input()].port = fe->src_output();
3297         VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
3298                 << fe->dst_input() << "] = " << fe->src()->name();
3299       }
3300     }
3301   }
3302 
3303   // Replicate output edges are sorted by replica id and then by outputs for
3304   // each replica. For example, if TPU Computation has outputs (output_1,
3305   // output_2, and output_3) and number of replicas is 2, then
3306   // replicate_output_edges order would be:
3307   // output_1_replica_1, output_2_replica_1, output_3_replica_1,
3308   // output_1_replica_2, output_2_replica_2, output_3_replica_2.
3309   std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
3310                                                   nullptr);
3311   for (const Edge* edge : replicate_node.out_edges()) {
3312     if (edge->IsControlEdge()) continue;
3313 
3314     int num_partitioned_outputs = 0;
3315 
3316     for (const Edge* out_edge : edge->dst()->out_edges()) {
3317       if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3318         num_partitioned_outputs++;
3319         // Paths between replicate_node and replicate_output_fan_out_nodes:
3320         // ReplicateNode->TpuOutIdentity->kTPUPartitionedOutput->fan-out-nodes
3321         TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
3322         to_be_removed_nodes.push_back(edge->dst());
3323         to_be_removed_nodes.push_back(out_edge->dst());
3324         // Get the right replicated id from the replicate_output_edge.
3325         std::vector<Node*>& nodes =
3326             replicate_output_fan_out_nodes[edge->src_output()];
3327         std::vector<int>& dst_inputs =
3328             replicate_output_fan_out_dst_inputs[edge->src_output()];
3329         nodes.resize(num_cores_per_replica, nullptr);
3330         dst_inputs.resize(num_cores_per_replica, 0);
3331         TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
3332                      num_cores_per_replica);
3333 
3334         for (const Edge* fe : out_edge->dst()->out_edges()) {
3335           nodes[fe->src_output()] = fe->dst();
3336           dst_inputs[fe->src_output()] = fe->dst_input();
3337           VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
3338                   << "][" << fe->src_output()
3339                   << "] = " << fe->dst()->DebugString() << " with dst_input "
3340                   << fe->dst_input();
3341         }
3342       }
3343     }
3344     replicate_output_edges[edge->src_output()] = edge;
3345     if (num_partitioned_outputs > 1) {
3346       return errors::InvalidArgument(
3347           "More than one TPUPartitionedOutput per replicated output.");
3348     }
3349   }
3350 
3351   const int num_execute_args =
3352       arg_shardings.size() - params_info.NumGuaranteedConstants();
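  // Guaranteed constants are fed to the TPUCompile node rather than to the
  // per-core TPUExecute nodes, so they are excluded here.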
3353   // Inverts the arg_shardings and retval_shardings mappings to
3354   // form core -> {argument number} maps.
3355   std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
3356   for (int i = 0; i < num_execute_args; ++i) {
3357     const auto& sharding = arg_shardings[i];
3358     if (sharding.type() == xla::OpSharding::MAXIMAL) {
3359       int core = sharding.tile_assignment_devices(0);
3360       TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3361       core_arg_nums[core].push_back(i);
3362     } else if (sharding.type() == xla::OpSharding::OTHER) {
3363       for (int64_t core : sharding.tile_assignment_devices()) {
3364         core_arg_nums[core].push_back(i);
3365       }
3366     } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3367       for (int core = 0; core < num_cores_per_replica; ++core) {
3368         core_arg_nums[core].push_back(i);
3369       }
3370     } else {
3371       return tensorflow::errors::InvalidArgument(
3372           "Unsupported argument sharding for arg=", arg_names[i],
3373           " shape=", arg_shapes[i].shape.DebugString(), ": ",
3374           sharding.DebugString());
3375     }
3376   }
3377   std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
3378   for (int i = 0; i < retval_shardings.size(); ++i) {
3379     const auto& sharding = retval_shardings[i];
3380     if (sharding.type() == xla::OpSharding::MAXIMAL) {
3381       int core = sharding.tile_assignment_devices(0);
3382       TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3383       core_retval_nums[core].push_back(i);
3384     } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3385       for (int core = 0; core < num_cores_per_replica; ++core) {
3386         core_retval_nums[core].push_back(i);
3387       }
3388     } else if (sharding.type() == xla::OpSharding::OTHER) {
3389       for (int64_t core : sharding.tile_assignment_devices()) {
3390         core_retval_nums[core].push_back(i);
3391       }
3392     } else {
3393       return tensorflow::errors::InvalidArgument(
3394           "Unsupported return value sharding: ", sharding.DebugString());
3395     }
3396   }
3397 
3398   // Maps host device name to a list of per-variable pairs (variable_copy_node,
3399   // output_index_of_copy_node).
3400   absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
3401 
3402   Node* execute_successor = control_successor;
3403 
3404   int num_total_cores = params_info.NumReplicas() * num_cores_per_replica;
3405   if (enable_multicore_locking_ && num_total_cores > 1) {
3406     // Add a node to release exclusive access once all the cores have finished
3407     // execution.
3408     NodeDef lock_def;
3409     lock_def.set_name(graph->NewName(
3410         strings::StrCat(compile_node->name(), "/", "tpu_release_multilock")));
3411     lock_def.set_op("ConsumeTpuMultilock");
3412     MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &lock_def);
3413     TF_ASSIGN_OR_RETURN(Node * multilock_release, graph->AddNode(lock_def));
3414     multilock_release->set_assigned_device_name(
3415         compile_node->assigned_device_name());
3416     TF_RET_CHECK(multilock_acquire != nullptr);
3417     graph->AddEdge(multilock_acquire, 0, multilock_release, 0);
3418     graph->AddControlEdge(multilock_release, control_successor);
3419     // Make sure all execute Ops happen before the multilock_release.
3420     execute_successor = multilock_release;
3421   }
3422 
3423   // Mapping from original resource arg number to a second level map. Second
3424   // level map is from core id to output index of updated variable value.
3425   absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
3426       orig_arg_num_to_output_index_mapping;
3427   // Mapping from retval index to a second level map. Second level map is from
3428   // core id to output index of sharded output value.
3429   std::unordered_map<int, std::unordered_map<int, int>>
3430       retval_index_to_output_index_mapping;
3431 
3432   // For each sharded input to a TPUExecute node, maps the argument index to
3433   // the Split node and the output index from which that shard will be fed
3434   // into the TPUExecute node.
3435   std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
3436 
3437   // Additional map of {host, arg_num} to dummy input. Per-task copies of the
3438   // inputs reduces cross-task communication and allows sharing across replicas.
3439   // inputs reduce cross-task communication and allow sharing across replicas.
3440 
3441   // Builds one TPUExecute node per core per replica.
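  // execute_nodes[replica] accumulates nodes in core order (the outer loop
  // below iterates over cores), so execute_nodes[replica][core] is the
  // TPUExecute node for logical core `core`.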
3442   std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
3443   for (int core = 0; core < num_cores_per_replica; ++core) {
3444     DataTypeVector core_retval_types;
3445     for (int output : core_retval_nums[core]) {
3446       core_retval_types.push_back(retval_types[output]);
3447     }
3448     DataTypeVector core_arg_types;
3449     std::vector<int> core_variable_writes;
3450     for (int input : core_arg_nums[core]) {
3451       // Resource variables can be passed either by reference (as a DT_RESOURCE
3452       // tensor) or by value (as the variable's current value). Per-replica or
3453       // distributed resource arguments are always passed by reference and
3454       // broadcast variables are always passed by value.
3455       if (arg_types[input] == DT_RESOURCE &&
3456           !params_info.IsPerReplicaArg(input) &&
3457           !params_info.IsDistributedArg(input)) {
3458         DataType handle_type = arg_shapes[input].handle_type;
3459         TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
3460         core_arg_types.push_back(handle_type);
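        // `base` is this argument's index among the variable arguments; it is
        // recorded in core_variable_writes so that the extra output produced by
        // TPUExecute can be routed to the matching entry in variable_writes.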
3461         int base = input - params_info.NumPerReplicaArgs() -
3462                    params_info.NumDistributedArgs() -
3463                    params_info.NumBroadcastArgs();
3464         // Variables passed by value will have a corresponding additional output
3465         // containing an updated value for the variable.
3466         core_variable_writes.push_back(base);
3467         core_retval_types.push_back(handle_type);
3468       } else {
3469         core_arg_types.push_back(arg_types[input]);
3470       }
3471     }
3472 
3473     NodeDef def;
3474     def.set_op("TPUExecute");
3475     MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
3476     AddNodeAttr("Targs", core_arg_types, &def);
3477     AddNodeAttr("Tresults", core_retval_types, &def);
3478 
3479     // If the producer name was set during inference, propagate the information
3480     // to the TPUExecute op so it can be accessed during metric collection.
3481     std::string producer_name;
3482     Status status =
3483         GetNodeAttr(replicate_node.attrs(), "_producer_name", &producer_name);
3484     if (status.ok()) {
3485       AddNodeAttr("_producer_name", producer_name, &def);
3486     }
3487 
3488     for (int64_t replica = 0; replica < params_info.NumReplicas(); ++replica) {
3489       def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
3490                                    "_", core));
3491 
3492       TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(def));
3493       execute_nodes[replica].push_back(node);
3494 
3495       node->set_assigned_device_name(tpu_device_names[replica][core]);
3496 
3497       // Add control edges to ensure that execution happens after
3498       // `control_predecessor`, happens before `execute_successor`, and is
3499       // triggered by evaluating any operator that depends on the original
3500       // TPUReplicate operator. See the comment at the top of the header file
3501       // for more details.
3502       graph->AddControlEdge(control_predecessor, node);
3503       graph->AddControlEdge(node, execute_successor);
3504 
3505       // Add data input edges.
3506       for (int64_t i = 0; i < core_arg_nums[core].size(); ++i) {
3507         int64_t orig_arg_num = core_arg_nums[core][i];
3508         VLOG(2) << " replica " << replica << " core " << core << " i " << i
3509                 << " orig_arg_num " << orig_arg_num;
3510         const bool is_per_replica_arg =
3511             params_info.IsPerReplicaArg(orig_arg_num);
3512         if (is_per_replica_arg || params_info.IsDistributedArg(orig_arg_num)) {
3513           // Per-replica input and distributed input
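          // Per-replica inputs are laid out replica-major in
          // replicate_input_edges: per-replica arg k of replica r is input
          // r * NumPerReplicaArgs() + k. Distributed args follow all of the
          // per-replica inputs, so distributed arg d sits at
          // NumReplicas() * NumPerReplicaArgs() + (d - NumPerReplicaArgs()).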
3514           const int64_t input_num =
3515               is_per_replica_arg ? replica * params_info.NumPerReplicaArgs() +
3516                                        core_arg_nums[core][i]
3517                                  : params_info.NumReplicas() *
3518                                            params_info.NumPerReplicaArgs() +
3519                                        core_arg_nums[core][i] -
3520                                        params_info.NumPerReplicaArgs();
3521 
3522           const Edge* edge = replicate_input_edges[input_num];
3523           VLOG(2) << "replicate_input_edges[" << input_num << "]";
3524           DataType dtype = edge->src()->output_type(edge->src_output());
3525           if (dtype == DT_RESOURCE) {
3526             DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
3527             if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
3528                           handle_dtype) == kTpuAllTypes.end()) {
3529               return errors::InvalidArgument(
3530                   "Unsupported resource variable data type for TPU: ",
3531                   DataTypeString(handle_dtype), ", caused by output ",
3532                   edge->src()->name(), ":", edge->src_output());
3533             }
3534           } else {
3535             if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3536                 kTpuAllTypes.end()) {
3537               return errors::InvalidArgument(
3538                   "Unsupported data type for TPU: ", DataTypeString(dtype),
3539                   ", caused by output ", edge->src()->name(), ":",
3540                   edge->src_output());
3541             }
3542           }
3543           if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3544             // Don't automatically add a split node when the input node is
3545             // kTPUPartitionedInput.
3546             if (edge->src()->type_string() == kTPUPartitionedInput) {
3547               VLOG(2)
3548                   << "Connect "
3549                   << replicate_input_fan_in_nodes[input_num][core].node->name()
3550                   << " to " << node->name() << " at " << i;
3551               graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3552                              replicate_input_fan_in_nodes[input_num][core].port,
3553                              node, i);
3554             } else {
3555               if (dtype == DT_RESOURCE) {
3556                 return errors::InvalidArgument(
3557                     "Tiled sharding for per-replica DT_RESOURCE input must ",
3558                     "be TPUPartitionedInput. Here got ",
3559                     edge->src()->type_string());
3560               }
3561               const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3562 
3563               ShardedInputInfo sharded_input_info;
3564               if (use_nd_sharding_ops_ && is_per_replica_arg) {
3565                 TF_ASSIGN_OR_RETURN(
3566                     sharded_input_info,
3567                     CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
3568                         sharding, replica, orig_arg_num, dtype,
3569                         PartialTensorShape(), edge->src(), edge->src_output(),
3570                         graph, &input_index_to_sharded_inputs));
3571               } else if (use_nd_sharding_ops_) {
3572                 TF_ASSIGN_OR_RETURN(
3573                     sharded_input_info,
3574                     CreateOrGetXlaSplitNodeForDistributedArg(
3575                         sharding, params_info.NumReplicas(), replica,
3576                         orig_arg_num, dtype, PartialTensorShape(), edge->src(),
3577                         edge->src_output(), graph,
3578                         &input_index_to_sharded_inputs));
3579               } else {
3580                 TF_ASSIGN_OR_RETURN(
3581                     sharded_input_info,
3582                     CreateOrGetSplitNodesForInputSharding(
3583                         sharding, orig_arg_num, dtype, PartialTensorShape(),
3584                         replica, edge->src_output(), edge->src(),
3585                         control_predecessor, graph,
3586                         &input_index_to_sharded_inputs));
3587               }
3588 
3589               NodeOut split_node_and_index =
3590                   sharded_input_info.sharded_inputs.at(core);
3591               // Connect with Split node output.
3592               graph->AddEdge(split_node_and_index.node,
3593                              split_node_and_index.index, node, i);
3594             }
3595           } else if (edge->src()->type_string() == kTPUPartitionedInput &&
3596                      arg_shardings[orig_arg_num].type() ==
3597                          xla::OpSharding::REPLICATED) {
3598             graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3599                            replicate_input_fan_in_nodes[input_num][core].port,
3600                            node, i);
3601           } else {
3602             graph->AddEdge(edge->src(), edge->src_output(), node, i);
3603           }
3604         } else if (params_info.IsBroadcastArg(orig_arg_num)) {
3605           // Broadcast input.
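          // The arg's offset among the broadcast args (orig_arg_num minus the
          // per-replica and distributed arg counts) is added to
          // FirstBroadcastArgFromHost() to locate the replicate input edge.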
3606           int64_t input_num = params_info.FirstBroadcastArgFromHost() +
3607                               core_arg_nums[core][i] -
3608                               params_info.NumPerReplicaArgs() -
3609                               params_info.NumDistributedArgs();
3610           const Edge* edge = replicate_input_edges[input_num];
3611           DataType dtype = edge->src()->output_type(edge->src_output());
3612           if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3613               kTpuAllTypes.end()) {
3614             return errors::InvalidArgument(
3615                 "Unsupported data type for TPU: ", DataTypeString(dtype),
3616                 ", caused by output ", edge->src()->name(), ":",
3617                 edge->src_output());
3618           }
3619           graph->AddEdge(edge->src(), edge->src_output(), node, i);
3620         } else {
3621           // Variable input.
3622           int64_t variable_num =
3623               orig_arg_num - params_info.NumPerReplicaArgs() -
3624               params_info.NumDistributedArgs() - params_info.NumBroadcastArgs();
3625           TF_RET_CHECK(variable_num < num_variables);
3626 
3627           Node* variable_read = variable_reads[variable_num];
3628           DataType dtype = variable_read->output_type(0);
3629           if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3630               kTpuAllTypes.end()) {
3631             return errors::InvalidArgument(
3632                 "Unsupported resource variable data type for TPU: ",
3633                 DataTypeString(dtype), ", caused by ReadVariableOp ",
3634                 variable_read->DebugString());
3635           }
3636           DeviceNameUtils::ParsedName requested_device;
3637           string requested = variable_read->requested_device();
3638           TF_RET_CHECK(
3639               DeviceNameUtils::ParseFullName(requested, &requested_device));
3640           if (requested_device.type != "TPU") {
3641             // Stage the value via the CPU device on the remote host. The graph
3642             // partitioner will introduce an intermediate copy rather than
3643             // copying the same tensor multiple times across the network, and we
3644             // would prefer that intermediate copy to be in host memory to avoid
3645             // running out of memory if the TPUExecute op on the staging device
3646             // starts running before the _Send ops to the other TPU devices on
3647             // the same host complete. We don't do this if the variables are
3648             // already placed on TPU, otherwise it will cause an unnecessary
3649             // round trip copy.
3650             // TODO(b/79580121): give each replica its own on-device variable
3651             // replica and then delete this code.
3652             string device;
3653             TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3654                 tpu_device_names[replica][core], &device));
3655             TF_ASSIGN_OR_RETURN(
3656                 auto var_data,
3657                 CreateOrGetPerHostVariableCopy(
3658                     device, variable_num, variable_reads, params_info,
3659                     arg_shardings, replicate_node, enable_xla_param_broadcast_,
3660                     mpmd, num_cores_per_replica, replica, arg_shapes,
3661                     &per_host_var_copies, graph));
3662 
3663             if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3664               ShardedInputInfo sharded_input_info;
3665 
3666               if (EnableXlaParamBroadcast(enable_xla_param_broadcast_, mpmd,
3667                                           params_info, orig_arg_num, dtype)) {
3668                 // Populates the sharded dummy vars for non-zero replicas.
3669                 TF_RETURN_IF_ERROR(CreatePartitionedDummyVarArgs(
3670                     arg_shardings[orig_arg_num], params_info.NumReplicas(),
3671                     replica, arg_shapes[orig_arg_num], var_data.node,
3672                     orig_arg_num, dtype, device, graph, tpu_device_names,
3673                     &sharded_per_host_index, &input_index_to_sharded_inputs));
3674               }
3675 
3676               if (use_nd_sharding_ops_) {
3677                 TF_ASSIGN_OR_RETURN(
3678                     sharded_input_info,
3679                     CreateOrGetXlaSplitNodeForVariableArg(
3680                         arg_shardings[orig_arg_num], params_info.NumReplicas(),
3681                         replica, orig_arg_num,
3682                         arg_shapes[orig_arg_num].handle_type,
3683                         arg_shapes[orig_arg_num].handle_shape, var_data.node,
3684                         var_data.index, graph, &to_be_removed_nodes,
3685                         &input_index_to_sharded_inputs));
3686               } else {
3687                 TF_ASSIGN_OR_RETURN(
3688                     sharded_input_info,
3689                     CreateOrGetSplitNodesForInputSharding(
3690                         arg_shardings[orig_arg_num], orig_arg_num,
3691                         arg_shapes[orig_arg_num].handle_type,
3692                         arg_shapes[orig_arg_num].handle_shape, replica,
3693                         var_data.index, var_data.node, control_predecessor,
3694                         graph, &input_index_to_sharded_inputs));
3695               }
3696 
3697               NodeOut split_node_and_index =
3698                   sharded_input_info.sharded_inputs[core];
3699               // Connect with Split node output.
3700               graph->AddEdge(split_node_and_index.node,
3701                              split_node_and_index.index, node, i);
3702 
3703             } else {
3704               graph->AddEdge(var_data.node, var_data.index, node, i);
3705             }
3706           } else {
3707             graph->AddEdge(variable_reads[variable_num], 0, node, i);
3708           }
3709         }
3710       }
3711 
3712       // Adds a program input edge from the compiler.
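      // Output 0 of the compile node is the compilation status; outputs
      // 1..num_cores_per_replica carry the per-core programs, so this core's
      // program is output core + 1.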
3713       graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
3714 
3715       // Add data output edges.
3716       int num_outputs = core_retval_nums[core].size();
3717       for (int i = 0; i < num_outputs; ++i) {
3718         int output_num =
3719             replica * num_retvals_per_replica + core_retval_nums[core][i];
3720         const auto& sharding = retval_shardings[core_retval_nums[core][i]];
3721         if (sharding.type() == xla::OpSharding::OTHER) {
3722           int retval_index = core_retval_nums[core][i];
3723           retval_index_to_output_index_mapping[retval_index][core] = i;
3724           bool is_last_core =
3725               core ==
3726               *std::max_element(sharding.tile_assignment_devices().begin(),
3727                                 sharding.tile_assignment_devices().end());
3728           bool is_partitioned_out_node = false;
3729 
3730           const Edge* e = replicate_output_edges[output_num];
3731           const Edge* e_out = nullptr;
3732           for (const Edge* out_edge : e->dst()->out_edges()) {
3733             if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3734               is_partitioned_out_node = true;
3735               e_out = out_edge;
3736             }
3737           }
3738           if (is_partitioned_out_node) {
3739             graph->AddEdge(
3740                 node, i, replicate_output_fan_out_nodes[output_num][core],
3741                 replicate_output_fan_out_dst_inputs[output_num][core]);
3742             VLOG(2) << "Connect " << node->name() << " at " << i << " to "
3743                     << replicate_output_fan_out_nodes[output_num][core]->name()
3744                     << " at "
3745                     << replicate_output_fan_out_dst_inputs[output_num][core];
3746             if (is_last_core) {
3747               graph->RemoveEdge(e);
3748               graph->RemoveEdge(e_out);
3749             }
3750             continue;
3751           }
3752 
3753           // Do this in the iteration of the last core in the tile assignment,
3754           // so all TPUExecute nodes have been created by now.
3755           if (!is_last_core) {
3756             continue;
3757           }
3758 
3759           // Add a Concat node.
3760           std::vector<NodeOut> orig_inputs;
3761           for (int64_t tile_index = 0;
3762                tile_index < sharding.tile_assignment_devices_size();
3763                ++tile_index) {
3764             int64_t last_tile_dim_size =
3765                 *sharding.tile_assignment_dimensions().rbegin();
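            // With replicate_on_last_tile_dim, each consecutive group of
            // last_tile_dim_size devices holds an identical shard, so only the
            // first device of each group contributes an input to the concat.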
3766             if (sharding.replicate_on_last_tile_dim() &&
3767                 tile_index % last_tile_dim_size != 0) {
3768               continue;
3769             }
3770             int64_t core_id = sharding.tile_assignment_devices(tile_index);
3771             int core_retval_index =
3772                 retval_index_to_output_index_mapping[retval_index][core_id];
3773             orig_inputs.push_back(
3774                 NodeOut{execute_nodes[replica][core_id],
3775                         static_cast<int>(
3776                             core_retval_nums[core_id][core_retval_index])});
3777           }
3778           DataType dtype = e->src()->output_type(e->src_output());
3779           Node* concat_node = nullptr;
3780           if (use_nd_sharding_ops_) {
3781             TF_ASSIGN_OR_RETURN(
3782                 concat_node, CreateXlaConcatNode(
3783                                  sharding, replica, dtype,
3784                                  /*partial_tensor_shape=*/PartialTensorShape(),
3785                                  orig_inputs, /*device=*/"", graph));
3786           } else {
3787             TF_ASSIGN_OR_RETURN(
3788                 concat_node,
3789                 CreateConcatNodesForRetval(
3790                     sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
3791                     replica, orig_inputs, graph, /*device=*/""));
3792           }
3793 
3794           const Edge* edge = replicate_output_edges[output_num];
3795           Node* dst = edge->dst();
3796           int dst_input = edge->dst_input();
3797           graph->RemoveEdge(edge);
3798           graph->AddEdge(concat_node, 0, dst, dst_input);
3799 
3800           continue;
3801         }
3802 
3803         // If this is a replicated output, outputs on all cores will be the
3804         // same, and we only take the output from core 0.
3805         if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3806           continue;
3807         }
3808 
3809         // If the output has maximal sharding, only use the output from the
3810         // TPUExecute node whose logical core id equals the core id defined by
3811         // the XLA sharding.
3812         if (sharding.type() == xla::OpSharding::MAXIMAL &&
3813             core != sharding.tile_assignment_devices(0)) {
3814           continue;
3815         }
3816 
3817         const Edge* replicate_edge_to_replace =
3818             replicate_output_edges[output_num];
3819         Node* dst = replicate_edge_to_replace->dst();
3820         int dst_input = replicate_edge_to_replace->dst_input();
3821         graph->RemoveEdge(replicate_edge_to_replace);
3822         graph->AddEdge(node, i, dst, dst_input);
3823       }
3824 
3825       // Feed the updated variable values from the first replica to the
3826       // variable write nodes.
3827       if (replica == 0) {
3828         for (int i = 0; i < core_variable_writes.size(); ++i) {
3829           int orig_arg_num =
3830               core_variable_writes[i] + params_info.NumPerReplicaArgs() +
3831               params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
3832           const auto& sharding = arg_shardings[orig_arg_num];
3833           // If this is a tiling sharded variable, concat variable updates from
3834           // all cores.
3835           if (sharding.type() == xla::OpSharding::OTHER) {
3836             orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
3837 
3838             // Do this in the iteration of the last core in the tile assignment,
3839             // so all TPUExecute nodes have been created by now.
3840             if (core !=
3841                 *std::max_element(sharding.tile_assignment_devices().begin(),
3842                                   sharding.tile_assignment_devices().end())) {
3843               continue;
3844             }
3845 
3846             // Add a Concat node.
3847             std::vector<NodeOut> orig_inputs;
3848             for (int64_t tile_index = 0;
3849                  tile_index < sharding.tile_assignment_devices_size();
3850                  ++tile_index) {
3851               int64_t last_tile_dim_size =
3852                   *sharding.tile_assignment_dimensions().rbegin();
3853               if (sharding.replicate_on_last_tile_dim() &&
3854                   tile_index % last_tile_dim_size != 0) {
3855                 continue;
3856               }
3857               int64_t core_id = sharding.tile_assignment_devices(tile_index);
3858               int core_retval_num =
3859                   orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
3860               orig_inputs.push_back(
3861                   NodeOut{execute_nodes[0][core_id],
3862                           static_cast<int>(core_retval_nums[core_id].size() +
3863                                            core_retval_num)});
3864             }
3865 
3866             // Use the variable read's device for the concat; the read and the
3867             // concat should both be collocated with the variable.
3868             absl::string_view device =
3869                 variable_reads[core_variable_writes[i]]->assigned_device_name();
3870             Node* concat_node = nullptr;
3871             if (use_nd_sharding_ops_) {
3872               TF_ASSIGN_OR_RETURN(
3873                   concat_node,
3874                   CreateXlaConcatNode(sharding, replica,
3875                                       arg_shapes[orig_arg_num].handle_type,
3876                                       arg_shapes[orig_arg_num].handle_shape,
3877                                       orig_inputs, device, graph));
3878             } else {
3879               TF_ASSIGN_OR_RETURN(
3880                   concat_node,
3881                   CreateConcatNodesForRetval(
3882                       sharding, arg_shapes[orig_arg_num].handle_type,
3883                       arg_shapes[orig_arg_num].handle_shape, replica,
3884                       orig_inputs, graph, device));
3885             }
3886             // Populate VariableWrite.
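            // The write is gated on the compile node's per-core
            // "may modify variables" outputs, which follow the program outputs
            // (hence num_cores_per_replica + core + 1).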
3887             VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3888             write.value = concat_node;
3889             write.value_output = 0;
3890             write.predicate = compile_node;
3891             write.predicate_output = num_cores_per_replica + core + 1;
3892 
3893             continue;
3894           }
3895 
3896           // If this is a replicated variable, outputs on all cores will be the
3897           // same, and we only take the output from core 0 for the variable
3898           // update.
3899           if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3900             continue;
3901           }
3902           VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3903           write.value = node;
3904           write.value_output = num_outputs + i;
3905           write.predicate = compile_node;
3906           write.predicate_output = num_cores_per_replica + core + 1;
3907         }
3908       }
3909     }
3910   }
3911 
3912   for (Node* node : to_be_removed_nodes) {
3913     graph->RemoveNode(node);
3914   }
3915   return OkStatus();
3916 }  // NOLINT(readability/fn_size)
3917 
3918 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
3919     int replica_index, const std::vector<Node*>& outside_compilation_nodes,
3920     const DeviceNameUtils::ParsedName& tpu_device,
3921     const DeviceNameUtils::ParsedName& partial_device,
3922     NodeToNodeReplicasMap* node_images, Graph* graph) {
3923   for (Node* node : outside_compilation_nodes) {
3924     NodeDef image_def = node->def();
3925     MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
3926     const string suffix = strings::StrCat("/R", replica_index);
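    // For example, a node named "foo/bar" in replica 2 is copied as
    // "foo/bar/R2".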
3927     // In addition to node name, make the frame name unique to avoid multiple
3928     // LoopCond nodes in one frame.
3929     TF_RETURN_IF_ERROR(
3930         AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
3931     TF_ASSIGN_OR_RETURN(Node * image, graph->AddNode(image_def));
3932     image->AddAttr(kXlaReplicaIdAttrName, replica_index);
3933     if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
3934       TF_RETURN_IF_ERROR(
3935           SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
3936     } else {
3937       const string& original_device_string =
3938           node->assigned_device_name().empty() ? node->requested_device()
3939                                                : node->assigned_device_name();
3940       DeviceNameUtils::ParsedName device;
3941       TF_RET_CHECK(
3942           DeviceNameUtils::ParseFullName(original_device_string, &device));
3943       // If the requested device can be merged with the replica's host device,
3944       // then do so. For example, if the requested device is "/CPU:0" or
3945       // "/GPU:0" then it will be placed on the CPU/GPU of the host where this
3946       // replica is running. But if the requested device is
3947       // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
3948       if (DeviceNameUtils::IsSpecification(device, partial_device)) {
3949         TF_RETURN_IF_ERROR(
3950             DeviceNameUtils::MergeDevNames(&device, partial_device));
3951       }
3952       image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
3953     }
3954     std::vector<Node*>& node_image_vector = (*node_images)[node];
3955     node_image_vector.resize(replica_index + 1);
3956     node_image_vector[replica_index] = image;
3957   }
3958   return OkStatus();
3959 }
3960 
3961 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
3962     const std::vector<std::vector<string>>& tf_device_assignment,
3963     const HostComputeCoreMap& host_compute_core,
3964     const OutsideCompilationNodeMap& outside_compilation_nodes,
3965     NodeToNodeReplicasMap* node_images, Graph* graph) {
3966   // Iterate over replicas.
3967   for (int i = 0; i < tf_device_assignment.size(); ++i) {
3968     const auto& core_devices = tf_device_assignment[i];
3969     for (const auto& oc_cluster_iter : outside_compilation_nodes) {
3970       const string& oc_cluster_name = oc_cluster_iter.first;
3971       const auto& oc_cluster_nodes = oc_cluster_iter.second;
3972       // We previously validated that host_compute_core contains an entry for
3973       // each cluster.
3974       int core = host_compute_core.at(oc_cluster_name);
3975       TF_RET_CHECK(core >= 0 && core < core_devices.size());
3976       // tpu_device is the device the HostCompute XLA Op for this cluster runs
3977       // on.
3978       DeviceNameUtils::ParsedName tpu_device;
3979       TF_RET_CHECK(
3980           DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
3981       // partial_device contains the replica and task but not the type or id.
3982       DeviceNameUtils::ParsedName partial_device = tpu_device;
3983       partial_device.has_type = false;
3984       partial_device.has_id = false;
3985 
3986       if (tf_device_assignment.size() == 1) {
3987         // With a single replica, don't copy any nodes; just put the original
3988         // nodes into the image map. We leave the device placement alone, except
3989         // that we have to fill in the correct core for the host send and
3990         // receive nodes.
3991         for (Node* node : oc_cluster_nodes) {
3992           (*node_images)[node] = {node};
3993           node->AddAttr(kXlaReplicaIdAttrName, 0);
3994           if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
3995             TF_RETURN_IF_ERROR(
3996                 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
3997           }
3998         }
3999       } else {
4000         // Iterate over outside_compilation clusters in this computation, adding
4001         // all the nodes with appropriate device assignments.
4002         TF_RETURN_IF_ERROR(
4003             CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
4004                                         partial_device, node_images, graph));
4005       }
4006     }
4007   }
4008   return OkStatus();
4009 }
4010 
4011 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
4012     const std::vector<Node*>& outside_compilation_nodes,
4013     const NodeToNodeReplicasMap& node_images,
4014     const std::unordered_map<string, Node*> outside_compilation_inputs,
4015     Graph* graph) {
4016   for (Node* node : outside_compilation_nodes) {
4017     const auto& images = node_images.at(node);
4018     // Make a copy of all edges and iterate on "in_edges", because we might
4019     // remove edges when iterating through them.
4020     std::vector<const Edge*> in_edges(node->in_edges().begin(),
4021                                       node->in_edges().end());
4022     for (const Edge* edge : in_edges) {
4023       Node* src = edge->src();
4024       const auto iter = node_images.find(src);
4025       if (iter == node_images.end()) {
4026         if (images.size() > 1) {
4027           // The source node is a 'normal' node not part of any
4028           // rewrite. Broadcast the value to all replicas. (If images.size() ==
4029           // 1 the cluster is not replicated and we can leave the original edge
4030           // in place.)
4031           for (Node* dst : images) {
4032             graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
4033           }
4034         }
4035         continue;
4036       }
4037 
4038       // The source node is a replicated outside_compilation node.
4039       const auto& src_images = iter->second;
4040       if (src_images.size() != images.size()) {
4041         return errors::InvalidArgument(
4042             "Graph contains an edge from node ", src->name(),
4043             " in an outside_compilation block replicated ", src_images.size(),
4044             " ways to node ", node->name(),
4045             " in an outside_compilation block replicated ", images.size(),
4046             " ways. Replication factors must match. Leave a comment on "
4047             "tracking bug b/76419636 if you need this to be supported.");
4048       }
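      // If the source is a lifted argument (or, below, a placeholder for an
      // argument), rewire each replica's image to read the matching per-replica
      // output of the cluster's IdentityN input node.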
4049       bool is_lifted_arg;
4050       string outside_compilation_cluster;
4051       if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
4052               .ok() &&
4053           GetNodeAttr(src->def(), kOutsideCompilationAttr,
4054                       &outside_compilation_cluster)
4055               .ok()) {
4056         const auto input_iter =
4057             outside_compilation_inputs.find(outside_compilation_cluster);
4058         TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4059         TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
4060         int dst_input = edge->dst_input();
4061         if (src_images.size() == 1) {
4062           graph->RemoveEdge(edge);
4063         }
4064         for (int i = 0; i < src_images.size(); ++i) {
4065           graph->AddEdge(input_iter->second, i, images[i], dst_input);
4066         }
4067         continue;
4068       }
4069 
4070       bool is_placeholder_for_arg;
4071       string outside_compilation_input_attr;
4072       if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
4073                       &is_placeholder_for_arg)
4074               .ok() &&
4075           GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
4076                       &outside_compilation_input_attr)
4077               .ok()) {
4078         const auto input_iter =
4079             outside_compilation_inputs.find(outside_compilation_input_attr);
4080         TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4081         TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
4082         int dst_input = edge->dst_input();
4083         if (src_images.size() == 1) {
4084           graph->RemoveEdge(edge);
4085         }
4086         for (int i = 0; i < src_images.size(); ++i) {
4087           graph->AddEdge(input_iter->second, i, images[i], dst_input);
4088         }
4089         continue;
4090       }
4091 
4092       if (images.size() > 1) {
4093         // If images.size() == 1 neither cluster is replicated and we can
4094         // leave the original edges in place.
4095         for (int i = 0; i < src_images.size(); ++i) {
4096           graph->AddEdge(src_images[i], edge->src_output(), images[i],
4097                          edge->dst_input());
4098         }
4099       }
4100     }
4101     for (const Edge* edge : node->out_edges()) {
4102       Node* dst = edge->dst();
4103       const auto iter = node_images.find(dst);
4104       if (iter == node_images.end()) {
4105         // The destination node is a 'normal' node not part of any rewrite.
4106         if (edge->IsControlEdge()) {
4107           // Make the dst node have a control dependency on every replica.
4108           if (images.size() > 1) {
4109             for (int i = 0; i < images.size(); ++i) {
4110               graph->AddControlEdge(images[i], dst);
4111             }
4112           }
4113           // else the cluster is not replicated so we can leave the original
4114           // edge in place.
4115         } else {
4116           // The edge
4117           // The edge is only valid if the outside_compilation block is not
4118           // replicated.
4119             return errors::InvalidArgument(
4120                 "Graph contains an edge from node ", node->name(),
4121                 " in an outside_compilation block replicated ", images.size(),
4122                 " ways to node ", dst->name(),
4123                 " that is not part of an outside_compilation block. Edges from "
4124                 "outside_compilation to regular graph nodes are only supported "
4125                 "for replication factors of 1. Leave a comment on tracking bug "
4126                 "b/76419636 if you need this to be supported.");
4127           }
4128           // else the cluster is not replicated so we can leave the original
4129           // edge in place.
4130         }
4131       }
4132       // The case where src and dst are both in node_images is covered elsewhere
4133       // when iterating over in_edges of dst.
4134     }
4135   }
4136   return OkStatus();
4137 }
4138 
4139 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
4140     const OutsideCompilationNodeMap& outside_compilation_nodes,
4141     const NodeToNodeReplicasMap& node_images,
4142     const std::unordered_map<string, Node*> outside_compilation_inputs,
4143     Graph* graph) {
4144   for (const auto& oc_cluster_iter : outside_compilation_nodes) {
4145     TF_RETURN_IF_ERROR(
4146         CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
4147                                     outside_compilation_inputs, graph));
4148   }
4149   return OkStatus();
4150 }
4151 
4152 /* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
4153     const NodeToNodeReplicasMap& node_images, Graph* graph) {
4154   for (const auto& iter : node_images) {
4155     if (iter.second.size() > 1) {
4156       // The cluster was replicated so remove the original node.
4157       Node* node = iter.first;
4158       graph->RemoveNode(node);
4159     }
4160   }
4161   return OkStatus();
4162 }
4163 
4164 /* static */ Status
4165 DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
4166     Graph* g, FunctionLibraryDefinition& flib_def,
4167     const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
4168   bool modified = false;
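  // Iterate to a fixed point: lowering one functional node can expose nested
  // functional nodes from its body, which are tagged below so that they get
  // lowered on a subsequent pass of this loop.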
4169   do {
4170     std::vector<Node*> nodes_to_lower;
4171     for (Node* n : g->op_nodes()) {
4172       if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
4173         continue;
4174       }
4175 
4176       if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
4177         // Only lower functional ops with a DT_RESOURCE input, because otherwise
4178         // the placer will complain. For normal cases, lowering will cause a
4179         // slowdown when the related functions are huge (b/139037679).
4180         bool has_resource_input = false;
4181         for (const Edge* e : n->in_edges()) {
4182           if (!e->IsControlEdge() &&
4183               e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4184             has_resource_input = true;
4185             break;
4186           }
4187         }
4188         if (has_resource_input) {
4189           nodes_to_lower.push_back(n);
4190         }
4191       }
4192     }
4193 
4194     modified = !nodes_to_lower.empty();
4195 
4196     auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
4197       // Clear the device assignment. Otherwise all lowered nodes will get a
4198       // device assignment, which is not what we want.
4199       n->set_requested_device("");
4200 
4201       int replica_id;
4202       TF_RETURN_IF_ERROR(
4203           GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4204 
4205       string outside_compilation_attr;
4206       TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
4207                                      &outside_compilation_attr));
4208 
4209       // There are two different kinds of functional outside compilation nodes:
4210       // 1. Nodes that are in outside compilation blocks already. They are
4211       //    generated by FunctionalizeControlFlowForXlaPass, and only have
4212       //    attribute kOutsideCompilationAttr.
4213       // 2. Mirrored control flow built for outside compilation in functional
4214       //    nodes. They are generated by ExtractOutsideCompilationPass, and have
4215       //    both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
4216       // When lowering them, they need to be treated differently.
4217       // For 1), their body functions are always V1 functions written by users,
4218       // and their "control outputs" are control inputs of _Retval nodes. They
4219       // should be lowered as V1 functions.
4220       // For 2), we always add necessary "control outputs"
4221       // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
4222       // FunctionDef's. They should be lowered as V2 functions.
4223       bool is_host_side_mirrored_control_flow =
4224           HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);
4225 
4226       int num_node_ids = g->num_node_ids();
4227       bool is_call_node = IsFunctionCall(flib_def, *n);
4228       if (n->IsWhileNode()) {
4229         TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, &flib_def,
4230                                             /*keep_node_fetchable=*/false));
4231       } else if (n->IsIfNode()) {
4232         TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
4233       } else {
4234         TF_RET_CHECK(is_call_node);
4235         // See comments for "is_host_side_mirrored_control_flow" above.
4236         // If this is a node that's in outside compilation block, lower it as
4237         // V1 function. This is controlled by removing
4238         // kLowerAsMultiDeviceFunctionAttr from the node.
4239         if (!is_host_side_mirrored_control_flow) {
4240           n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4241         } else {
4242           n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4243           n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
4244                      true);
4245         }
4246         TF_RETURN_IF_ERROR(
4247             RewriteFunctionCallNode(n, g, flib_def,
4248                                     /*keep_caller_fetchable=*/false));
4249       }
4250 
4251       for (int i = num_node_ids; i < g->num_node_ids(); i++) {
4252         Node* node = g->FindNodeId(i);
4253         if (!node) {
4254           continue;
4255         }
4256 
4257         if (!is_call_node && is_host_side_mirrored_control_flow &&
4258             IsFunctionCall(flib_def, *node)) {
4259           // For If/While nodes, if they are host side mirrored control flow,
4260           // mark their body function calls with kXlaHasHostTransferAttrName
4261           // attribute to make sure we lower them as V2 function.
4262           node->AddAttr(kXlaHasHostTransferAttrName, true);
4263         }
4264 
4265         if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
4266             node->IsIfNode()) {
4267           // Set kOutsideCompilationAttr attribute so we lower these
4268           // nested function call nodes later.
4269           node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
4270           // Set kXlaReplicaIdAttrName attribute so we know replica id when we
4271           // lower this function call node.
4272           node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4273         } else if (node->type_string() == "_XlaRecvAtHost" ||
4274                    node->type_string() == "_XlaSendFromHost") {
4275           // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
4276           // have kXlaReplicaIdAttrName attribute so later we know which host
4277           // device to assign.
4278           node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4279         }
4280       }
4281       return OkStatus();
4282     };
4283 
4284     for (Node* n : nodes_to_lower) {
4285       TF_RETURN_IF_ERROR(lower_functional_node(n));
4286     }
4287   } while (modified);
4288 
4289   // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
4290   for (Node* n : g->op_nodes()) {
4291     if (n->type_string() != "_XlaRecvAtHost" &&
4292         n->type_string() != "_XlaSendFromHost") {
4293       continue;
4294     }
4295 
4296     string replicate;
4297     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
4298     auto iter = tpu_replicate_device_names_mapping.find(replicate);
4299     TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
4300     const auto& tpu_device_names = iter->second;
4301 
4302     int replica_id;
4303     TF_RETURN_IF_ERROR(
4304         GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4305     TF_RET_CHECK(replica_id < tpu_device_names.size());
4306     const string& tpu_device_name = tpu_device_names[replica_id][0];
4307     string host_device_name;
4308     TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4309         tpu_device_name, &host_device_name));
4310     n->set_assigned_device_name(host_device_name);
4311     // We may run TPU rewrite passes again on the subgraphs of the resulting
4312     // graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
4313     // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
4314     // sure that TPU rewrite passes take no effect on host-side subgraphs for
4315     // outside compilation.
4316     n->ClearAttr(kTPUReplicateAttr);
4317     n->ClearAttr(kOutsideCompilationAttr);
4318   }
4319 
4320   // Remove IdentityN nodes generated for outside compilation. IdentityN is
4321   // exempt from resource edge colocation, but here we do need the inputs and
4322   // outputs of these IdentityN nodes to be colocated.
4323   std::vector<Node*> identityn_nodes;
4324   for (Node* n : g->op_nodes()) {
4325     if (n->type_string() == "IdentityN" &&
4326         HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
4327       identityn_nodes.push_back(n);
4328     }
4329   }
4330   for (Node* n : identityn_nodes) {
4331     std::vector<const Edge*> out_edges(n->out_edges().begin(),
4332                                        n->out_edges().end());
4333     for (const Edge* e : out_edges) {
4334       if (e->IsControlEdge()) {
4335         continue;
4336       }
4337 
4338       int src_output = e->src_output();
4339       const Edge* input_edge;
4340       TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
4341       Node* dst = e->dst();
4342       int dst_input = e->dst_input();
4343       g->RemoveEdge(e);
4344       g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
4345     }
4346     g->RemoveNode(n);
4347   }
4348 
4349   return OkStatus();
4350 }
4351 
4352 /* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
4353     const Node& replicate_node,
4354     const OutsideCompilationNodeMap& outside_compilation_nodes,
4355     HostComputeCoreMap* host_compute_core) {
4356   std::vector<string> hc_core_string;
4357   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
4358                                  &hc_core_string));
4359   TF_RETURN_IF_ERROR(
4360       ParseHostComputeCoreList(hc_core_string, host_compute_core));
4361   for (const auto& iter : outside_compilation_nodes) {
4362     const string& oc_cluster_name = iter.first;
4363     if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
4364       // By default put host compute Ops on replicated core 0.
4365       (*host_compute_core)[oc_cluster_name] = 0;
4366     }
4367   }
4368   return OkStatus();
4369 }
4370 
4371 /* static */ Status DistributedTPURewritePass::GetDeviceTopology(
4372     const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
4373     int* num_cores_per_replica, int* num_tasks,
4374     std::vector<std::vector<string>>* tf_device_assignment,
4375     std::vector<int>* devices_to_lock,
4376     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
4377     string* tpu_compilation_device) {
4378   TF_RETURN_IF_ERROR(
4379       GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
4380   if (*num_replicas < 1) {
4381     return errors::InvalidArgument("num_replicas must be >= 1, got ",
4382                                    *num_replicas);
4383   }
4384 
4385   // Find the set of TPU devices in the TF job.
4386   // Indexed by [task number][tpu device number].
4387   std::vector<std::vector<Device*>> tpu_devices;
4388   int num_tpus_per_task;
4389   TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
4390                                        device_set, tpu_compilation_device,
4391                                        &num_tpus_per_task, &tpu_devices));
4392   *num_tasks = tpu_devices.size();
4393 
4394   string topology;
4395   TF_RETURN_IF_ERROR(
4396       GetNodeAttr(replicate_node.attrs(), "topology", &topology));
4397   TF_RETURN_IF_ERROR(GetNodeAttr(
4398       replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
4399   std::vector<int> device_assignment;
4400   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
4401                                  &device_assignment));
4402 
4403   // TODO(cwhipkey): since we can control multiple pods of different shapes
4404   // from a single worker, it may be desirable to propagate the remote device
4405   // information around (e.g., in DeviceAttributes). This can lead to the mesh
4406   // topology proto being leaked to cloud TPU users (e.g. through GetStatus
4407   // calls); this may be okay, but to be conservative, just assume that the
4408   // master session has the proper flags set.
4409 
4410   // We do not initialize platform right now, but we can still retrieve the
4411   // TPU topology even with an uninitialized platform.
4412   auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
4413       /*initialize_platform=*/false);
4414   TF_RET_CHECK(tpu_platform);
4415   tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
4416   TF_RET_CHECK(num_tpus_per_task ==
4417                tpu_topology.LogicalDevicesPerHost(kTensorCore));
4418   TF_RETURN_IF_ERROR(BuildDeviceAssignment(
4419       tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
4420       *num_cores_per_replica, topology, device_assignment, tf_device_assignment,
4421       devices_to_lock, xla_device_assignment));
4422 
4423   return OkStatus();
4424 }
4425 
4426 /* static */ Status DistributedTPURewritePass::GetIOTypes(
4427     int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
4428     Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
4429     std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
4430     DataTypeVector* retval_types, ParameterInfo* params_info) {
4431   DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
4432   TF_RETURN_IF_ERROR(
4433       GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
4434   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
4435                                  &broadcast_input_types));
4436   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4437                                  "Tguaranteed_constants",
4438                                  &guaranteed_constant_types));
4439   int num_distributed_vars;
4440   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4441                                  "num_distributed_variables",
4442                                  &num_distributed_vars));
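  // Tinputs covers both the per-replica inputs and the distributed variables;
  // after removing the distributed variables, the remainder must divide evenly
  // among the replicas (checked below).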
4443   const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
4444 
4445   if (num_per_replica_inputs % num_replicas != 0) {
4446     return errors::InvalidArgument(
4447         "Number of inputs to TPUReplicate (", num_per_replica_inputs,
4448         ") is not divisible by the number of replicas (", num_replicas, ").");
4449   }
4450 
4451   int num_variables;
4452   TF_RETURN_IF_ERROR(
4453       GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
4454 
4455   NameRangeMap output_name_map;
4456   TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
4457                                        input_name_map, &output_name_map));
4458 
4459   TF_RETURN_IF_ERROR(
4460       GetNodeAttr(replicate_node.attrs(), "computation", function));
4461 
4462   *computation = absl::make_unique<Graph>(graph->op_registry());
4463   TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
4464       **function, flr, computation->get(), arg_types, retval_types));
4465 
4466   *params_info = ParameterInfo(
4467       num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
4468       broadcast_input_types.size(), num_variables,
4469       guaranteed_constant_types.size(), retval_types->size());
4470 
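  // Consistency checks: the computation's argument and result counts must
  // match the counts implied by the replicate node's attributes.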
4471   if (arg_types->size() != params_info->NumInputsToEachReplica()) {
4472     return errors::InvalidArgument(
4473         "Computation argument to TPUReplicate has wrong number of "
4474         "arguments. Expected ",
4475         params_info->NumInputsToEachReplica(), " inputs, got ",
4476         arg_types->size());
4477   }
4478   if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
4479     return errors::InvalidArgument(
4480         "Wrong number of outputs from TPUReplicate. Expected ",
4481         params_info->NumOutputsToHost(), " outputs, got ",
4482         replicate_node.num_outputs());
4483   }
4484   if (enable_cross_replica_sharding_mirrored_variables_) {
4485     std::vector<int> mirrored_variable_indices;
4486     TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4487                                    TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
4488                                    &mirrored_variable_indices));
4489     for (int index : mirrored_variable_indices) {
4490       TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
4491                    params_info->IsDistributedArg(index))
4492           << "Mirrored variable is not categorized as a per-replica or "
4493              "distributed argument, index: "
4494           << index;
4495       params_info->mutable_mirrored_variable_indices()->insert(index);
4496     }
4497   }
4498   return OkStatus();
4499 }
4500 
4501 /* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
4502     const string& tpu_compilation_device, const Node& replicate_node,
4503     Graph* graph, Node** host_transfer_sequencer, Node** control_before,
4504     Node** control_after) {
4505   *host_transfer_sequencer = nullptr;
4506 
4507   TF_RETURN_IF_ERROR(
4508       BuildNoopNode(replicate_node,
4509                     graph->NewName(strings::StrCat(replicate_node.name(), "/",
4510                                                    "control_before")),
4511                     /*device=*/"", graph, control_before));
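  // Redirect the replicate node's incoming control edges to control_before,
  // except for the host transfer sequencer, which is handled separately below.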
4512   for (const Edge* e : replicate_node.in_edges()) {
4513     if (!e->IsControlEdge()) {
4514       continue;
4515     }
4516     Node* predecessor = e->src();
4517     if (predecessor->IsSource()) continue;
4518     if (predecessor->type_string() == "NoOp" &&
4519         predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
4520       // The node is the sequencer for host transfer operations. Its control
4521       // dependency needs to be placed after the execute node, not before.
4522       if (*host_transfer_sequencer != nullptr) {
4523         return errors::Internal("Replicate node ", replicate_node.name(),
4524                                 " has two transfer sequencer nodes: ",
4525                                 (*host_transfer_sequencer)->name(), " and ",
4526                                 predecessor->name());
4527       }
4528       // Set the correct device to match the other sequencing nodes.
4529       predecessor->set_assigned_device_name(tpu_compilation_device);
4530       *host_transfer_sequencer = predecessor;
4531     } else {
4532       graph->AddControlEdge(predecessor, *control_before);
4533     }
4534   }
4535 
4536   TF_RETURN_IF_ERROR(
4537       BuildNoopNode(replicate_node,
4538                     graph->NewName(strings::StrCat(replicate_node.name(), "/",
4539                                                    "control_after")),
4540                     /*device=*/tpu_compilation_device, graph, control_after));
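  // Successors tagged as tail outside compilation must complete before
  // control_after fires; all other successors are sequenced after it.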
4541   for (Node* successor : replicate_node.out_nodes()) {
4542     if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
4543       graph->AddControlEdge(successor, *control_after);
4544     } else {
4545       graph->AddControlEdge(*control_after, successor);
4546     }
4547   }
4548   return OkStatus();
4549 }
4550 
4551 /* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
4552     const Node& replicate_node, const NameRangeMap& input_name_map,
4553     Graph* graph, Node* host_transfer_sequencer, Node* control_before,
4554     Node* control_after, absl::Span<const VariableInput> variable_nodes,
4555     std::vector<Node*>* guaranteed_constant_nodes,
4556     std::vector<Node*>* variable_reads) {
4557   TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
4558       replicate_node, input_name_map, guaranteed_constant_nodes));
4559 
4560   TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
4561                                         variable_reads));
4562   // Add the control dependency from host transfer nodes.
4563   if (host_transfer_sequencer != nullptr) {
4564     graph->AddControlEdge(host_transfer_sequencer, control_after);
4565   }
4566   return OkStatus();
4567 }
4568 
4569 /* static */ Status
4570 DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
4571     Node* replicate_node, Node* compile_node,
4572     absl::Span<const int> devices_to_lock, Node** control_after_compilation,
4573     Node** multilock_acquire, Graph* graph) {
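  // Find the TPUCompilationResult node, if any, attached to this replicate
  // cluster via a control edge; at most one such node is allowed.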
4574   const Edge* compilation_edge = nullptr;
4575   for (const auto* e : replicate_node->out_edges()) {
4576     if (e->IsControlEdge() &&
4577         e->dst()->type_string() == "TPUCompilationResult") {
4578       TF_RET_CHECK(compilation_edge == nullptr)
4579           << "Multiple compilation result nodes attached to the same replicate "
4580              "cluster.";
4581       compilation_edge = e;
4582     }
4583   }
4584 
4585   // TODO(jpienaar): This should be checked by default; the current tests not
4586   // using this are the ones that use the "abort upon successful compile" flag,
4587   // which will be removed. Leaving this in until then.
4588   if (compilation_edge != nullptr) {
4589     Node* compilation_status = compilation_edge->dst();
4590     const AttrValue* compile_status_cluster_attr =
4591         compilation_status->attrs().Find(kTPUCompilationResultAttr);
4592     TF_RET_CHECK(compile_status_cluster_attr != nullptr);
4593     const string& compile_status_cluster = compile_status_cluster_attr->s();
4594     TF_RET_CHECK(!compile_status_cluster.empty());
4595     const AttrValue* replicate_cluster_attr =
4596         replicate_node->attrs().Find(kTPUReplicateAttr);
4597     TF_RET_CHECK(replicate_cluster_attr != nullptr);
4598     const string& replicate_cluster = replicate_cluster_attr->s();
4599     TF_RET_CHECK(!replicate_cluster.empty());
4600     TF_RET_CHECK(compile_status_cluster == replicate_cluster);
4601 
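    // Replace the TPUCompilationResult node with an Identity node fed from
    // output 0 of the compile node, which carries the compilation status.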
4602     TF_RETURN_IF_ERROR(
4603         ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
4604     graph->AddEdge(compile_node, 0, compilation_status, 0);
4605   }
4606 
4607   NodeDef def;
4608   def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
4609   // Create an op to assert that compilation succeeded. The alternative would
4610   // have been to have each execute op check and return an error.
4611   def.set_op("TPUCompileSucceededAssert");
4612   MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
4613   TF_ASSIGN_OR_RETURN(Node * compile_succeeded, graph->AddNode(def));
4614   compile_succeeded->set_assigned_device_name(
4615       compile_node->assigned_device_name());
4616   graph->AddEdge(compile_node, 0, compile_succeeded, 0);
4617 
4618   Node* last_node_before_sequencer = compile_succeeded;
4619 
4620   if (enable_multicore_locking_ && devices_to_lock.size() > 1) {
4621     // Add a lock node to acquire exclusive access to all the cores that will
4622     // execute this program. The lock is required to prevent deadlock or
4623     // incorrect results when running concurrent multi-core programs in the
4624     // same distributed runtime when there is no direct graph dependency
4625     // between the programs (either because they are run from different sessions
4626     // or because they are in the same graph, but have no control or data
4627     // dependencies to sequence them). Consider the case of two multi-core
4628     // computations A and B whose cores overlap and include cores X and Y. With
4629     // no locking and no graph dependencies it is possible that A's program
4630     // gets enqueued before B's on core X, while B's program gets enqueued
4631     // before A's on core Y. This will lead either to deadlock or to
4632     // incorrect results, since the runtime has no mechanism to re-sequence
4633     // the programs on the cores. By adding a multi-lock acquisition for all
4634     // the cores before any TPUExecute ops are run, and releasing it after
4635     // they complete, we ensure that the programs are enqueued on the cores
4636     // in a consistent order.
4637     //
4638     // There is a risk when computations are in the same graph, and include a
4639     // data dependency, that the lock acquisition could provoke deadlock.
4640     // Suppose that A must happen before B because B's input depends on A's
4641     // output. Then it is obviously necessary that A's lock acquisition must
4642     // happen before B's lock acquisition, and so we must ensure that there is
4643     // a graph dependency causing B's lock acquisition to be sequenced after A's
4644     // lock acquisition. Right now that dependency is satisfied because the
4645     // shape inference code cannot determine the shape of A's outputs, and so
4646     // B's compilation, which precedes B's lock acquisition, is always sequenced
4647     // after A's execution. If the shape inference is improved it will be
4648     // necessary to add an explicit control edge between dependent lock
4649     // acquisition ops.
4650     NodeDef lock_def;
4651     lock_def.set_name(graph->NewName(
4652         strings::StrCat(compile_node->name(), "/", "tpu_acquire_multilock")));
4653     lock_def.set_op("TpuMultilock");
4654     AddNodeAttr("lock_list", devices_to_lock, &lock_def);
4655     MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &lock_def);
4656     TF_ASSIGN_OR_RETURN(*multilock_acquire, graph->AddNode(lock_def));
4657     (*multilock_acquire)
4658         ->set_assigned_device_name(compile_node->assigned_device_name());
4659     graph->AddControlEdge(compile_succeeded, *multilock_acquire);
4660     last_node_before_sequencer = *multilock_acquire;
4661   } else {
4662     *multilock_acquire = nullptr;
4663   }
4664 
4665   // Build a sequencing node for when compilation has completed.
4666   TF_RETURN_IF_ERROR(
4667       BuildNoopNode(*replicate_node,
4668                     graph->NewName(strings::StrCat(compile_node->name(), "/",
4669                                                    "after_compilation")),
4670                     /*device=*/"", graph, control_after_compilation));
4671   graph->AddControlEdge(last_node_before_sequencer, *control_after_compilation);
4672 
4673   return OkStatus();
4674 }
4675 
4676 // Updates the head and tail outside-compilation nodes so that they are placed
4677 // on the correct device, and removes the replication and outside compilation
4678 // attributes so that these nodes do not trigger further graph optimization passes.
4679 /* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
4680     const std::vector<std::vector<string>>& tf_device_assignment,
4681     const std::vector<Node*>& head_tail_outside_compilation_nodes) {
4682   for (Node* node : head_tail_outside_compilation_nodes) {
4683     int replica_id;
4684     TF_RETURN_IF_ERROR(
4685         GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
4686     // Since we set the device, this will now run on a task other than 0. We
4687     // clear the following two attributes so that we don't trigger encapsulation
4688     // again on the remote host (which will fail due to a missing
4689     // _TPUReplicateMetadata node for the cluster).
4690     for (const Edge* e : node->in_edges()) {
4691       // Resource-consuming ops should be colocated with their resource inputs.
4692       if (e->src()->IsArg() &&
4693           e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4694         node->set_requested_device(tf_device_assignment[replica_id][0]);
4695       }
4696     }
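    // If no resource input forced the node onto a TPU device above, place it
    // on the CPU device of the host that owns the replica's first core.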
4697     if (node->requested_device().empty()) {
4698       string cpu_device;
4699       TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4700           tf_device_assignment[replica_id][0], &cpu_device));
4701       node->set_requested_device(cpu_device);
4702     }
4703     node->ClearAttr(kTPUReplicateAttr);
4704     node->ClearAttr(kOutsideCompilationAttr);
4705   }
4706   return OkStatus();
4707 }
4708 
4709 // Performs the rewrite on a single TPUReplicate node.
4710 /* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
4711     const string& session_handle, const DeviceSet& device_set,
4712     Node* replicate_node, FunctionLibraryDefinition* flib_def,
4713     FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
4714     const OutsideCompilationNodeMap& outside_compilation_nodes,
4715     const std::vector<Node*>& head_tail_outside_compilation_nodes,
4716     NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
4717     const GraphShapeInfo& shape_info,
4718     TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
4719     int64_t autotuner_thresh) {
4720   VLOG(2) << "Rewriting node " << replicate_node->name();
4721 
4722   // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
4723   // of the computation) and cores (virtual cores within computations) specified
4724   // by the user. They will be mapped to physical TPU cores below.
4725   int num_replicas;
4726   int num_cores_per_replica;
4727   int num_tasks;
4728   std::vector<std::vector<string>> tf_device_assignment;
4729   std::vector<int> devices_to_lock;
4730   std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
4731   string tpu_compilation_device;
4732   TF_RETURN_IF_ERROR(GetDeviceTopology(
4733       device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
4734       &num_tasks, &tf_device_assignment, &devices_to_lock,
4735       &xla_device_assignment, &tpu_compilation_device));
4736 
4737   TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
4738       tf_device_assignment, head_tail_outside_compilation_nodes));
4739 
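  // Record the mapping from the replicate cluster name to its per-replica
  // device assignment; it is consulted later when lowering outside
  // compilation functional nodes.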
4740   string replicate;
4741   TF_RETURN_IF_ERROR(
4742       GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
4743   tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
4744 
4745   NameRangeMap input_name_map;
4746   const NameAttrList* function;
4747   std::unique_ptr<Graph> computation;
4748   DataTypeVector arg_types, retval_types;
4749   ParameterInfo params_info;
4750   TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
4751                                 &input_name_map, &function, &computation,
4752                                 &arg_types, &retval_types, &params_info));
4753 
4754   std::vector<InferredShape> arg_shapes, retval_shapes;
4755   TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
4756       shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
4757 
4758   TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
4759 
4760   std::vector<xla::OpSharding> arg_sharding;
4761   std::vector<bool> arg_fast_mem;
4762   std::vector<std::string> arg_names;
4763   std::vector<xla::OpSharding> retval_sharding;
4764   TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
4765       num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
4766       retval_shapes, *computation, replicate_node, flr,
4767       allow_xla_spmd_partition_, &arg_sharding, &arg_fast_mem, &retval_sharding,
4768       &arg_names));
4769 
4770   VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
4771                              flib_def);
4772 
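  // Fingerprint the function library definitions reachable from this graph;
  // the fingerprint is fed to the TPUCompile node built below.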
4773   GraphDef graph_def;
4774   graph->ToGraphDef(&graph_def);
4775   FunctionLibraryDefinition reachable_functions =
4776       flib_def->ReachableDefinitions(graph_def);
4777   uint64 library_fingerprint;
4778 
4779   TF_RETURN_IF_ERROR(
4780       FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
4781   VLOG(1) << "Fingerprint functions: "
4782           << absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
4783   VLOG(1) << "library_fingerprint: " << library_fingerprint;
4784 
4785   // Builds trigger nodes that put barriers around the expansion of
4786   // TPUReplicate. In particular, we must guarantee:
4787   // a) variable reads happen after all predecessors of the original
4788   //    TPUReplicate.
4789   // b) variable writes happen before all successors of the original
4790   //    TPUReplicate.
4791   // c) all replicas execute, even if output tensors are only requested from
4792   //    a subset of replicas. This is necessary both to ensure that variable
4793   //    updates happen and because Send/Recv will deadlock if only one half of
4794   //    the communicating pair runs.
4795   Node* host_transfer_sequencer;
4796   Node* control_before;
4797   Node* control_after;
4798   TF_RETURN_IF_ERROR(BuildSequencingNodes(
4799       tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
4800       &control_before, &control_after));
4801 
4802   // Build a vector of variable nodes that are inputs.
4803   std::vector<VariableInput> variable_inputs;
4804   TF_RETURN_IF_ERROR(
4805       FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
4806 
4807   std::vector<Node*> guaranteed_constant_nodes;
4808   std::vector<Node*> variable_reads;
4809   TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
4810       *replicate_node, input_name_map, graph, host_transfer_sequencer,
4811       control_before, control_after, variable_inputs,
4812       &guaranteed_constant_nodes, &variable_reads));
4813 
4814   // Builds Shape nodes that compute the dynamic shapes of arguments whose
4815   // shapes are not statically known.
4816   std::vector<Node*> dynamic_shape_nodes;
4817   TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
4818                                             params_info, variable_reads, graph,
4819                                             &dynamic_shape_nodes));
4820 
4821   // Builds a TPUCompile node that compiles `clusters` on `compile_device`.
4822   Node* compile_node;
4823   TF_RETURN_IF_ERROR(BuildCompileNode(
4824       replicate_node, *function, library_fingerprint, params_info, arg_shapes,
4825       arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
4826       arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
4827       /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
4828       dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
4829 
4830   // Compilation must be sequenced after the control node if the TPU
4831   // computation is in a control-flow construct, such as a loop.
4832   graph->AddControlEdge(control_before, compile_node);
4833 
4834   Node* control_after_compilation;
4835   Node* multilock_acquire;
4836   TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
4837       replicate_node, compile_node, devices_to_lock, &control_after_compilation,
4838       &multilock_acquire, graph));
4839 
4840   std::vector<VariableWrite> variable_writes;
4841   TF_RETURN_IF_ERROR(BuildExecuteNodes(
4842       params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
4843       arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
4844       tf_device_assignment, compile_node, variable_reads,
4845       control_after_compilation, control_after, multilock_acquire,
4846       &variable_writes, graph));
4847   bool contains_resource_write_op =
4848       ContainsResourceWriteOp(*graph, reachable_functions);
4849 
4850   VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
4851   // Skip the conditional write if there is no resource-writing op inside the
4852   // TPU computation.
4853   if (contains_resource_write_op) {
4854     TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
4855                                            variable_writes, graph));
4856   }
4857 
4858   if (host_compute_key_placeholder_node != nullptr) {
4859     TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
4860         compile_node, host_compute_key_placeholder_node, graph));
4861   }
4862 
4863   HostComputeCoreMap host_compute_core;
4864   TF_RETURN_IF_ERROR(ParseHostComputeCores(
4865       *replicate_node, outside_compilation_nodes, &host_compute_core));
4866   TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
4867       tf_device_assignment, host_compute_core, outside_compilation_nodes,
4868       outside_compilation_node_images, graph));
4869 
4870   graph->RemoveNode(replicate_node);
4871   return OkStatus();
4872 }
4873 
4874 // Adds sharded weight update optimization for each host training loop.
4875 //
4876 // For any host training loop found in the graph, TPUVariableReshard ops
4877 // are inserted to match the best layout chosen by XLA.
4878 /* static */ Status
4879 DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
4880     Graph* graph, FunctionLibraryDefinition* flib_def,
4881     FunctionLibraryRuntime* flr) {
4882   std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
4883   Status s = tpu::DetectHostTrainingLoop(
4884       /*current_function_name=*/nullptr,
4885       /*current_function_attr=*/nullptr, flib_def, graph, flr,
4886       &host_training_loops_info);
4887   if (!s.ok()) {
4888     VLOG(2) << "No valid host training loop found. Skipping sharded weight "
4889             << "update optimization.";
4890     return OkStatus();
4891   }
4892 
4893   for (const auto& host_loop : host_training_loops_info) {
4894     const auto& function_name = host_loop.encapsulating_function_name;
4895     // `function_name` has a value when the host training loop is inside a
4896     // function call node. In that case, in addition to adding
4897     // TPUVariableReshard ops, the function library definition needs to be
4898     // updated as well.
4899     if (function_name.has_value()) {
4900       const auto& function_attr = host_loop.encapsulating_function_attrs;
4901       TF_RET_CHECK(function_attr.has_value())
4902           << "Unable to find function attribute for function: "
4903           << *function_name;
4904 
4905       const FunctionDef* function_def = flib_def->Find(*function_name);
4906       TF_RET_CHECK(function_def)
4907           << "Unable to find function : " << *function_name;
4908 
4909       std::unique_ptr<FunctionBody> fbody;
4910       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
4911           *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
4912       Graph* function_graph = fbody->graph;
4913       TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
4914       TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
4915                                                      *function_name, flib_def));
4916     } else {
4917       TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
4918     }
4919   }
4920   return OkStatus();
4921 }
4922 
4923 Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
4924     Graph* graph) {
4925   PropagateDevices(CanAcceptTPUDevicePropagation, IsTpuDevice, graph);
4926   return OkStatus();
4927 }
4928 
4929 Status DistributedTPURewritePass::Run(
4930     const GraphOptimizationPassOptions& options) {
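  // Delegate to InternalRun and tag any resulting error with the TF_XLA_BRIDGE
  // error-source payload.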
4931   Status status = InternalRun(options);
4932   OkOrSetErrorCounterPayload(
4933       tensorflow::core::platform::ErrorSourceProto::TF_XLA_BRIDGE, status);
4934   return status;
4935 }
4936 
4937 Status DistributedTPURewritePass::InternalRun(
4938     const GraphOptimizationPassOptions& options) {
4939   VLOG(1) << "DistributedTPURewritePass::Run";
4940 
4941   Graph* graph = options.graph->get();
4942 
4943   VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
4944                              options.flib_def);
4945 
4946   const auto* config = &options.session_options->config;
4947   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
4948       new ProcessFunctionLibraryRuntime(
4949           nullptr, options.session_options->env, config,
4950           graph->versions().producer(), options.flib_def,
4951           config ? config->graph_options().optimizer_options()
4952                  : OptimizerOptions()));
4953 
4954   FunctionLibraryRuntime* flr =
4955       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
4956 
4957   // This pass can only run in the session master, which should fill
4958   // in the device_set field of the options.
4959   TF_RET_CHECK(options.device_set != nullptr);
4960 
4961   // Find all the replicate nodes before mutating the graph.
4962   std::vector<Node*> replicate_nodes;
4963   // Map from compiled subgraph cluster name to the outside_compilation nodes in
4964   // that cluster.
4965   std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
4966   std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
4967   TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
4968                                      &outside_compilation_nodes,
4969                                      &head_tail_outside_compilation_nodes));
4970 
4971   if (replicate_nodes.empty()) {
4972     // Remove unused TPUPartitionedInput nodes.
4973     for (Node* n : graph->nodes()) {
4974       if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
4975     }
4976     VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
4977                                options.flib_def);
4978     VLOG(1) << "Replicate nodes are empty. DistributedTPURewritePass::Run() "
4979                "finished";
4980     return OkStatus();
4981   }
4982 
4983   std::unordered_map<string, Node*> host_compute_key_placeholder_map;
4984   TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
4985       graph, replicate_nodes, &host_compute_key_placeholder_map));
4986 
4987   // This shape inference pass does not compute the shapes of outputs of
4988   // TPU computations. The concurrent multi-core locking implementation
4989   // *relies* on this behavior because it ensures that, if TPU computation B's
4990   // inputs depend on TPU computation A's outputs, then computation B's
4991   // compilation will be sequenced after A's execution, and this ensures that
4992   // locks are acquired in the correct order. If the shape inference is improved
4993   // to compute shapes of TPU computation outputs, it will be necessary to add
4994   // an explicit control edge between lock acquisitions for dependent
4995   // computations in order to avoid deadlock.
4996   GraphShapeInfo shape_info;
4997   TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
4998                                  flr->GetFunctionLibraryDefinition(),
4999                                  &shape_info));
5000   int64_t autotuner_thresh = options.session_options->config.experimental()
5001                                  .xla_fusion_autotuner_thresh();
5002 
5003   NodeToNodeReplicasMap outside_compilation_node_images;
5004   TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
5005   for (Node* node : replicate_nodes) {
5006     TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
5007         options.session_handle, *options.device_set, node, options.flib_def,
5008         flr, host_compute_key_placeholder_map[node->name()],
5009         outside_compilation_nodes[node->name()],
5010         head_tail_outside_compilation_nodes[node->name()],
5011         &outside_compilation_node_images, graph, shape_info,
5012         &tpu_replicate_device_names_mapping, autotuner_thresh));
5013   }
5014 
5015   // Place the padding nodes generated by the dynamic padder on the correct devices.
5016   // TODO(rxsang): Place padding ops on TPUs in the
5017   // PlaceUnassignedDeviceNodesOnTPUIfPossible function.
5018   TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
5019 
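  // Collect IdentityN nodes that carry lifted outside compilation arguments,
  // keyed by the value of their kXlaOutsideCompilationInputsAttrName
  // attribute; they are consumed when replicating outside compilation edges
  // below.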
5020   std::unordered_map<string, Node*> outside_compilation_inputs;
5021   for (Node* n : graph->op_nodes()) {
5022     string lifted_arg_inputs_attr;
5023     if (n->type_string() == "IdentityN" &&
5024         GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
5025                     &lifted_arg_inputs_attr)
5026             .ok()) {
5027       outside_compilation_inputs[lifted_arg_inputs_attr] = n;
5028     }
5029   }
5030   for (const auto& iter : outside_compilation_nodes) {
5031     TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
5032         iter.second, outside_compilation_node_images,
5033         outside_compilation_inputs, graph));
5034   }
5035   TF_RETURN_IF_ERROR(
5036       RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
5037   TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
5038       graph, *options.flib_def, tpu_replicate_device_names_mapping));
5039 
5040   TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
5041   VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
5042                              options.flib_def);
5043   VLOG(1) << "DistributedTPURewritePass::Run() finished";
5044 
5045   if (enable_cross_replica_sharding_mirrored_variables_) {
5046     VLOG(1) << "Starting host training loop optimization.";
5047     VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
5048                                options.flib_def);
5049     TF_RETURN_IF_ERROR(
5050         PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
5051     VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
5052                                options.flib_def);
5053     VLOG(1) << "Host training loop optimization finished.";
5054   }
5055 
5056   return OkStatus();
5057 }
5058 
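// Default values for the pass options; they may be overridden via
// SetDistributedTpuRewritePassOptions() below.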
5059 bool DistributedTPURewritePass::distribute_vars_ = false;
5060 bool DistributedTPURewritePass::allow_xla_spmd_partition_ = true;
5061 bool DistributedTPURewritePass::
5062     replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
5063 bool DistributedTPURewritePass::
5064     enable_cross_replica_sharding_mirrored_variables_ = true;
5065 bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
5066 bool DistributedTPURewritePass::enable_xla_param_broadcast_ = true;
5067 bool DistributedTPURewritePass::enable_multicore_locking_ = false;
5068 bool DistributedTPURewritePass::use_nd_sharding_ops_ = false;
5069 
5070 /*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
5071     bool distribute_vars, bool allow_xla_spmd_partition,
5072     bool replicate_inputs_outputs_by_default_for_xla_spmd,
5073     bool enable_cross_replica_sharding_mirrored_variables,
5074     bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
5075     bool enable_multicore_locking, bool use_nd_sharding_ops) {
5076   distribute_vars_ = distribute_vars;
5077   allow_xla_spmd_partition_ = allow_xla_spmd_partition;
5078   replicate_inputs_outputs_by_default_for_xla_spmd_ =
5079       replicate_inputs_outputs_by_default_for_xla_spmd;
5080   enable_cross_replica_sharding_mirrored_variables_ =
5081       enable_cross_replica_sharding_mirrored_variables;
5082   enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
5083   enable_xla_param_broadcast_ = enable_xla_param_broadcast;
5084   enable_multicore_locking_ = enable_multicore_locking;
5085   use_nd_sharding_ops_ = use_nd_sharding_ops;
5086 }
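// A minimal usage sketch (hypothetical caller, not part of this file): the
// options above are process-wide statics, so they must be set before this
// optimization pass runs, e.g.:
//
//   DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
//       /*distribute_vars=*/true,
//       /*allow_xla_spmd_partition=*/true,
//       /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
//       /*enable_cross_replica_sharding_mirrored_variables=*/true,
//       /*enable_automatic_model_parallelism=*/false,
//       /*enable_xla_param_broadcast=*/true,
//       /*enable_multicore_locking=*/false,
//       /*use_nd_sharding_ops=*/false);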
5087 
5088 }  // namespace tensorflow
5089