/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Compilation for distributed TPU (TPU_REPLICATED_CORE devices).

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

#include <algorithm>
#include <queue>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/escaping.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/common_runtime/device_propagation.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_while_op.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/error_payloads.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

namespace {

// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology.
constexpr int kTPUTopologyRank = 4;

// An upper bound on how many cores may be present in the topology.
static constexpr int kTPUMaxTopologySize = 4096;

// Attribute containing the serialized xla::OpSharding to be passed to the
// corresponding XLA HLO operation, which represents how a shape is distributed
// across logical cores, e.g., replication, single-device, or partitioning.
const char kShardingAttribute[] = "_XlaSharding";

const char kTPUPartitionedInput[] = "TPUPartitionedInput";
const char kTPUPartitionedOutput[] = "TPUPartitionedOutput";

const char kVarHandleOp[] = "VarHandleOp";

static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status";
static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite";

using NodeAndId = std::pair<const Node*, int>;

struct NodeAndPort {
  explicit NodeAndPort(Node* node, int port) : node(node), port(port) {}

  Node* node;
  // Port of the node, e.g. this can be the `src_output` index of an Edge.
  int port;
};

class IntrusiveHeapLink {
 public:
  using size_type = size_t;
  static constexpr size_type kNotMember = -1;

  IntrusiveHeapLink() = default;

  // Only IntrusiveHeap and LinkAccess objects should make these objects.
  explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {}

  // Only IntrusiveHeap and LinkAccess should get the value.
  size_type get() const { return pos_; }

 private:
  size_type pos_{kNotMember};
};

template <typename T, IntrusiveHeapLink T::*M>
struct IntrusiveHeapDataMemberLinkAccess {
  IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
  void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
};

template <typename T>
struct DefaultIntrusiveHeapLinkAccess {
  IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
  void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
};

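// A minimal, illustrative usage sketch for the intrusive min-heap below
// (`Task` and its fields are hypothetical and exist only for this example):
//
//   struct Task {
//     int64_t cost;
//     IntrusiveHeapLink heap;  // Found by DefaultIntrusiveHeapLinkAccess.
//     struct Compare {
//       bool operator()(const Task* a, const Task* b) const {
//         return a->cost < b->cost;  // Smallest cost at the top.
//       }
//     };
//   };
//
//   IntrusiveHeap<Task, Task::Compare> heap;
//   Task t0{/*cost=*/5}, t1{/*cost=*/2};
//   heap.Push(&t0);
//   heap.Push(&t1);
//   t0.cost = 1;
//   heap.Adjust(&t0);             // Re-establish the heap invariant for t0.
//   Task* cheapest = heap.Pop();  // Returns &t0.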
template <typename T, typename PtrCompare,
          typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
          typename Alloc = std::allocator<T*>>
class IntrusiveHeap {
 public:
  typedef typename IntrusiveHeapLink::size_type size_type;
  typedef T value_type;
  typedef T* pointer;
  typedef const T* const_pointer;
  typedef PtrCompare pointer_compare_type;
  typedef LinkAccess link_access_type;
  typedef Alloc allocator_type;

  explicit IntrusiveHeap(
      const pointer_compare_type& comp = pointer_compare_type(),
      const link_access_type& link_access = link_access_type(),
      const allocator_type& alloc = allocator_type())
      : rep_(comp, link_access, alloc) {}

  size_type size() const { return heap().size(); }

  bool empty() const { return heap().empty(); }

  // Return the top element, but don't remove it.
  pointer top() const {
    DCHECK(!empty());
    return heap()[0];
  }

  // Remove the top() pointer from the heap and return it.
  pointer Pop() {
    pointer t = top();
    Remove(t);
    return t;
  }

  // Insert 't' into the heap.
  void Push(pointer t) {
    SetPositionOf(t, heap().size());
    heap().push_back(t);
    FixHeapUp(t);
  }

  // Adjust the heap to accommodate changes in '*t'.
  void Adjust(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
      FixHeapUp(t);
    } else {
      FixHeapDown(t);
    }
  }

  // Remove the specified pointer from the heap.
  void Remove(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    SetPositionOf(t, IntrusiveHeapLink::kNotMember);
    if (h == heap().size() - 1) {
      // Fast path for removing from back of heap.
      heap().pop_back();
      return;
    }
    // Move the element from the back of the heap to overwrite 't'.
    pointer& elem = heap()[h];
    elem = heap().back();
    SetPositionOf(elem, h);  // Element has moved, so update its link.
    heap().pop_back();
    Adjust(elem);  // Restore the heap invariant.
  }

  void Clear() { heap().clear(); }

  bool Contains(const_pointer t) const {
    size_type h = GetPositionOf(t);
    return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
           heap()[h] == t;
  }

  void reserve(size_type n) { heap().reserve(n); }

  size_type capacity() const { return heap().capacity(); }

  allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }

 private:
  typedef std::vector<pointer, allocator_type> heap_type;

  // Empty base class optimization for pointer_compare and link_access.
  // The heap_ data member retains a copy of the allocator, so it is not
  // stored explicitly.
  struct Rep : pointer_compare_type, link_access_type {
    explicit Rep(const pointer_compare_type& cmp,
                 const link_access_type& link_access,
                 const allocator_type& alloc)
        : pointer_compare_type(cmp),
          link_access_type(link_access),
          heap_(alloc) {}
    heap_type heap_;  // NOLINT
  };

  const pointer_compare_type& compare() const { return rep_; }

  const link_access_type& link_access() const { return rep_; }

  const heap_type& heap() const { return rep_.heap_; }
  heap_type& heap() { return rep_.heap_; }

  size_type GetPositionOf(const_pointer t) const {
    return link_access().Get(t).get();
  }

  void SetPositionOf(pointer t, size_type pos) const {
    return link_access().Set(t, IntrusiveHeapLink(pos));
  }

  void FixHeapUp(pointer t) {
    size_type h = GetPositionOf(t);
    while (h != 0) {
      size_type parent = (h - 1) >> 1;
      if (compare()(heap()[parent], t)) {
        break;
      }
      heap()[h] = heap()[parent];
      SetPositionOf(heap()[h], h);
      h = parent;
    }
    heap()[h] = t;
    SetPositionOf(t, h);
  }

  void FixHeapDown(pointer t) {
    size_type h = GetPositionOf(t);
    for (;;) {
      size_type kid = (h << 1) + 1;
      if (kid >= heap().size()) {
        break;
      }
      if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
        ++kid;
      }
      if (compare()(t, heap()[kid])) {
        break;
      }
      heap()[h] = heap()[kid];
      SetPositionOf(heap()[h], h);
      h = kid;
    }

    heap()[h] = t;
    SetPositionOf(t, h);
  }

  Rep rep_;
};

string CoreDeviceLabel(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}

// Creates a unique node name with a particular prefix.
string UniqueNodeName(const StringPiece prefix, Graph* graph) {
  return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
}

Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
                                        const string& target_device_type,
                                        Node* node) {
  TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
  TF_RET_CHECK(device.has_id);
  TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));

  // Store the device instance as an attr on the Node.
  TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));

  // Place the execute Op on the TPU_SYSTEM device so it can access the cache of
  // compiled protos in the resource manager.
  device.type = target_device_type;
  device.id = 0;

  node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
  return OkStatus();
}

// Iterate over the nodes in the original graph and find all the TPUReplicate
// nodes, and all the nodes that are part of outside_compilation clusters.
Status FindTaggedNodes(
    Graph* graph, std::vector<Node*>* replicate_nodes,
    std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
        outside_compilation_nodes,
    std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "_TPUReplicate") {
      replicate_nodes->push_back(node);
      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
      if (cluster_attr == nullptr) {
        return errors::Internal("TPUReplicate node ", node->name(), " has no ",
                                kTPUReplicateAttr, " attr.");
      } else {
        const string& cluster = cluster_attr->s();
        if (cluster.empty()) {
          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
                                  node->name(), " has no string value.");
        }
        if (outside_compilation_nodes->find(cluster) !=
            outside_compilation_nodes->end()) {
          return errors::Internal(
              "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
              " attr value '", cluster,
              "' which is a duplicate of another TPUReplicate node in the "
              "graph.");
        }
        (*outside_compilation_nodes)[cluster] =
            DistributedTPURewritePass::OutsideCompilationNodeMap();
        (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
      }
    }
  }
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() != "_TPUReplicate") {
      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
      const AttrValue* outside_compilation_attr =
          node->attrs().Find(kOutsideCompilationAttr);
      if (cluster_attr == nullptr) {
        if (outside_compilation_attr != nullptr) {
          return errors::Internal("Node ", node->name(), " has ",
                                  kOutsideCompilationAttr, " attr but no ",
                                  kTPUReplicateAttr, " attr.");
        }
      } else {
        const string& cluster = cluster_attr->s();
        if (cluster.empty()) {
          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
                                  node->name(), " has no string value.");
        }
        const auto iter = outside_compilation_nodes->find(cluster);
        if (iter == outside_compilation_nodes->end()) {
          return errors::Internal(
              "Attr ", kTPUReplicateAttr, " on node ", node->name(),
              " does not correspond to a TPUReplicate node.");
        }
        if (outside_compilation_attr == nullptr) {
          return errors::Internal("Node ", node->name(), " has ",
                                  kTPUReplicateAttr, " attr but no ",
                                  kOutsideCompilationAttr, " attr.");
        }
        const string& oc_cluster = outside_compilation_attr->s();
        if (oc_cluster.empty()) {
          return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
                                  node->name(), " has no string value.");
        }

        // Outside compilation clusters at the head and tail of the TPU
        // computation have already been moved to the host and are already
        // replicated. As such, do not replicate outside compilation nodes that
        // carry a replica id attribute.
        int replica_id;
        if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
          const AttrValue* head_attr =
              node->attrs().Find("_xla_only_arg_or_oc_input");
          const AttrValue* tail_attr =
              node->attrs().Find("_xla_only_ret_or_oc_output");
          if (((head_attr != nullptr) && (head_attr->b())) ||
              ((tail_attr != nullptr) && (tail_attr->b()))) {
            // This is safe because head_tail_outside_compilation_nodes has the
            // same keys as outside_compilation_nodes, which we already know
            // contains this key.
            (*head_tail_outside_compilation_nodes)[cluster].push_back(node);
          }
          continue;
        }
        iter->second[oc_cluster].push_back(node);
      }
    }
  }
  return OkStatus();
}

// Helper class to spread TPU computation arguments and return values
// across cores.
// If all shapes are fully defined, balance them by their size.
// If some of them are not fully defined, the sizes of the undefined shapes
// are estimated as the average size of the fully defined ones.
// If none are defined, fall back to round-robin.
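//
// Illustrative usage sketch (values are hypothetical; `types` and `shapes`
// would normally come from the computation being rewritten):
//
//   TensorDevicePlacer placer(/*num_devices=*/2, types, shapes);
//   // Index 0 was explicitly pinned to device 0 by the user's sharding.
//   placer.ReportDeviceAssigned(/*device=*/0, /*index=*/0);
//   // The remaining values are placed greedily on the least-loaded device.
//   int64_t d1 = placer.RetrieveAssignment(/*index=*/1);
//   int64_t d2 = placer.RetrieveAssignment(/*index=*/2);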
class TensorDevicePlacer {
 public:
  // Creates a TensorDevicePlacer object to distribute arguments or
  // return values to a set of num_devices devices, where the types and
  // the inferred shapes of the inputs (arguments or return values) are
  // passed in types and shapes.
  TensorDevicePlacer(int64_t num_devices, const DataTypeVector& types,
                     const std::vector<InferredShape>& shapes)
      : index_nodes_(num_devices), sizes_(types.size()) {
    int64_t total_size = 0;
    int64_t num_defined = 0;
    for (int64_t i = 0; i < types.size(); ++i) {
      sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
      if (sizes_[i] >= 0) {
        total_size += sizes_[i];
        ++num_defined;
      }
    }
    // If a shape is undefined, select a size for it which is the average
    // of the defined shapes. If no shapes are defined, assign 1 so that we
    // get round-robin behavior.
    int64_t undefined_shape_size =
        (num_defined > 0) ? total_size / num_defined : 1;
    for (int64_t i = 0; i < sizes_.size(); ++i) {
      if (sizes_[i] < 0) {
        sizes_[i] = undefined_shape_size;
      }
    }

    for (int64_t i = 0; i < num_devices; ++i) {
      heap_.Push(&index_nodes_[i]);
    }
  }

  // Reports that the argument/return-value at index has been assigned
  // by the user to a given device.
  void ReportDeviceAssigned(int64_t device, int64_t index) {
    if (device >= index_nodes_.size()) {
      LOG(FATAL) << "Sharding assignment is out of bounds. "  // Crash OK
                    "Check that the number of nodes is properly set.";
    }
    DeviceNode* node = &index_nodes_.at(device);
    node->size += sizes_.at(index);
    heap_.Adjust(node);
  }

  // Retrieves the device to which the argument/return-value at index
  // should be assigned.
  int64_t RetrieveAssignment(int64_t index) {
    DeviceNode* node = heap_.top();
    int64_t device = node - index_nodes_.data();
    node->size += sizes_.at(index);
    heap_.Adjust(node);
    return device;
  }

 private:
  struct DeviceNode {
    struct Compare {
      // Compare functor to implement a min heap using the ::gtl::IntrusiveHeap
      // infrastructure.
      bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
        return lhs->size < rhs->size;
      }
    };

    IntrusiveHeapLink heap;
    int64_t size = 0;
  };

  static int64_t GetInferredShapeSize(const InferredShape& ishape,
                                      DataType dtype) {
    return ishape.shape.IsFullyDefined()
               ? ishape.shape.num_elements() * DataTypeSize(dtype)
               : -1;
  }

  std::vector<DeviceNode> index_nodes_;
  IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
  std::vector<int64_t> sizes_;
};

Status ValidateCoreNumber(int64_t core, int64_t num_cores_per_replica) {
  if (core < 0 || core >= num_cores_per_replica) {
    return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
                                               ". The valid core IDs are [0..",
                                               num_cores_per_replica, ")");
  }
  return OkStatus();
}

Status FindHostComputeKeyPlaceholderNodes(
    const Graph* graph, const std::vector<Node*>& replicate_nodes,
    std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
  host_compute_key_placeholder_map->clear();
  for (const auto node : replicate_nodes) {
    (*host_compute_key_placeholder_map)[node->name()] = nullptr;
  }

  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "Placeholder" &&
        str_util::EndsWith(node->name(), "_key_placeholder")) {
      const AttrValue* call_node_attr =
          node->attrs().Find("_host_compute_call_node");
      if (call_node_attr != nullptr) {
        auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
        if (iter == host_compute_key_placeholder_map->end()) {
          return errors::InvalidArgument(
              "Node ", node->name(), " has _host_compute_call_node attribute '",
              call_node_attr->s(), "' that doesn't correspond to a call node");
        }
        if (iter->second != nullptr) {
          return errors::InvalidArgument(
              "Key placeholder node ", node->name(), " for call node ",
              call_node_attr->s(), " previously found as ",
              iter->second->name());
        }
        iter->second = node;
      }
    }
  }

  return OkStatus();
}

Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
  Node* old_node = *node;
  // We want to replace the node with an identity node with the same name.
  const string& node_name = old_node->name();

  // Create identity node.
  TF_ASSIGN_OR_RETURN(
      Node * id_node,
      BuildIdentityNode(graph, node_name, DT_STRING,
                        /*input=*/nullptr, /*requested_device=*/""));

  // No incoming edges are copied as a new one will be added from compile node
  // to id_node.

  // Copy outgoing edges to the id node.
  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                     old_node->out_edges().end());
  for (const Edge* edge : out_edges) {
    Node* dst = edge->dst();
    int src_output = edge->src_output();
    int dst_input = edge->dst_input();

    if (src_output == Graph::kControlSlot) {
      graph->AddControlEdge(id_node, dst);
    } else {
      graph->AddEdge(id_node, src_output, dst, dst_input);
    }
    graph->RemoveEdge(edge);
  }
  graph->RemoveNode(old_node);

  *node = id_node;
  return OkStatus();
}

Status GetStepMarkerLocation(const Node& replicate_node,
                             xla::DebugOptions::StepMarkerLocation* location) {
  string step_marker_location_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
                                 &step_marker_location_attr));
  if (step_marker_location_attr.empty()) {
    *location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
  } else {
    if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
                                                     location)) {
      return errors::InvalidArgument("Malformed step_marker_location: ",
                                     step_marker_location_attr);
    }
  }
  return OkStatus();
}

// Extracts a map from split dimension index to number of splits for a tiled
// input from its xla sharding attribute.
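// For example (illustrative values only), a sharding with
// tile_assignment_dimensions = [2, 1, 4] produces the map {0: 2, 2: 4};
// dimensions that are tiled only once are omitted.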
Status GetDimensionIndicesAndNumSplitsFromSharding(
    const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
  int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size();
  if (sharding.replicate_on_last_tile_dim()) {
    tensor_tile_rank--;
  }
  for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
    if (sharding.tile_assignment_dimensions(dim_index) > 1) {
      split_dimension_map->emplace(
          dim_index, sharding.tile_assignment_dimensions(dim_index));
    }
  }

  if (split_dimension_map->empty()) {
    return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
                                   sharding.DebugString());
  }
  return OkStatus();
}

// Updates the contents of the function with `function_name` in the function
// library definition `flib_def` to `new_graph`. This is required when a graph
// transformation happens inside a function call body.
Status UpdateFunctionLibDefinition(const Graph& new_graph,
                                   const std::string& function_name,
                                   FunctionLibraryDefinition* flib_def) {
  FunctionDef graph_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
  TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
  return OkStatus();
}

struct NodeOut {
  Node* node;
  int index;
};

struct ShardedInputIndex {
  int replica_id;
  int argument_index;

  bool operator<(const ShardedInputIndex& rhs) const {
    return std::tie(replica_id, argument_index) <
           std::tie(rhs.replica_id, rhs.argument_index);
  }
};

struct ShardedPerHostInputIndex {
  string host_device;
  int argument_index;
  bool operator<(const ShardedPerHostInputIndex& rhs) const {
    return std::tie(host_device, argument_index) <
           std::tie(rhs.host_device, rhs.argument_index);
  }
  bool operator==(const ShardedPerHostInputIndex& rhs) const {
    return (argument_index == rhs.argument_index) &&
           (host_device == rhs.host_device);
  }
};

struct ShardedInputInfo {
  // Split node that will be connected to the tiled input node.
  Node* split_node;
  // List of split nodes and the output indices from which sharded inputs will
  // be connected to the TPUExecute node. The inputs are ordered by logical
  // core ids.
  std::vector<NodeOut> sharded_inputs;
};

// Adds a Pad node after the split node to the graph, to support unevenly
// sharded tiled inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
                                   const int split_dim, DataType dtype,
                                   Node* control_predecessor, Node* split_node,
                                   const int split_index, Graph* graph) {
  // Add paddings node.
  NodeDef paddings_def;
  paddings_def.set_name(
      graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
  paddings_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &paddings_def);
  paddings_def.set_device(split_node->assigned_device_name());
  TensorProto sizes_tensor_proto;
  sizes_tensor_proto.set_dtype(DT_INT32);
  for (int i = 0; i < num_dims; ++i) {
    sizes_tensor_proto.add_int_val(0);
    if (i == split_dim) {
      sizes_tensor_proto.add_int_val(padding);
    } else {
      sizes_tensor_proto.add_int_val(0);
    }
  }
  TensorShape sizes_shape({num_dims, 2});
  sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
  TF_ASSIGN_OR_RETURN(Node * paddings_node, graph->AddNode(paddings_def));

  // Add Pad node.
  NodeDef pad_def;
  pad_def.set_name(graph->NewName(
      absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
  pad_def.set_op("Pad");
  pad_def.set_device(split_node->assigned_device_name());
  AddNodeAttr("T", dtype, &pad_def);
  AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
  pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
  pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
  TF_ASSIGN_OR_RETURN(Node * pad_node, graph->AddNode(pad_def));
  pad_node->set_assigned_device_name(split_node->assigned_device_name());
  // Add edges for pad node.
  graph->AddEdge(split_node, split_index, pad_node, 0);
  graph->AddEdge(paddings_node, 0, pad_node, 1);
  graph->AddControlEdge(control_predecessor, pad_node);
  return pad_node;
}

// Adds a split node and a split dimension node to the graph for sharding
// tiled inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
                                     const int num_dims, const int64_t padding,
                                     const int orig_src_output, DataType dtype,
                                     absl::string_view name_prefix,
                                     Node* control_predecessor, Node* orig_src,
                                     Graph* graph) {
  const std::string input_assigned_device = orig_src->assigned_device_name();
  Node* to_split_node = orig_src;
  int to_split_index = orig_src_output;
  if (padding > 0) {
    TF_ASSIGN_OR_RETURN(
        Node * pad_node,
        CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
                      orig_src, orig_src_output, graph));
    to_split_node = pad_node;
    to_split_index = 0;
  }

  // Add a split dimension node.
  NodeDef split_dim_def;
  split_dim_def.set_name(
      graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
  split_dim_def.set_op("Const");
  split_dim_def.set_device(input_assigned_device);
  AddNodeAttr("dtype", DT_INT32, &split_dim_def);
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.add_int_val(dim);
  TensorShape shape({});
  shape.AsProto(tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", tensor_proto, &split_dim_def);
  TF_ASSIGN_OR_RETURN(Node * split_dim_node, graph->AddNode(split_dim_def));
  // Add a split node.
  NodeDef split_def;
  split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
  split_def.set_op("Split");
  split_def.set_device(input_assigned_device);
  AddNodeAttr("num_split", num_splits, &split_def);
  AddNodeAttr("T", dtype, &split_def);
  split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
  split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
  TF_ASSIGN_OR_RETURN(Node * split_node, graph->AddNode(split_def));

  split_node->set_assigned_device_name(input_assigned_device);

  // Colocate the newly created split op with the source node of the input to
  // the TPU computation.
  split_node->AddAttr(kColocationAttrName,
                      std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                       orig_src->name())});

  graph->AddEdge(split_dim_node, 0, split_node, 0);
  graph->AddEdge(to_split_node, to_split_index, split_node, 1);

  // Add a control dependency from `control_predecessor` to the newly created
  // constant node. This ensures that newly added split/split dim nodes are
  // placed inside the correct while loop frames when the TPUExecute node is
  // inside a host training loop.
  graph->AddControlEdge(control_predecessor, split_dim_node);
  return split_node;
}

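// Computes the total padding needed along `split_dim` so that the dimension
// divides evenly into `num_splits` pieces. For example, a dimension of size 10
// split 4 ways has a per-split size of ceil(10 / 4) = 3, so the total padding
// is 3 * 4 - 10 = 2. Returns 0 if the dimension is not statically known.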
int64_t GetPadding(const int split_dim, const int num_splits,
                   const PartialTensorShape& partial_tensor_shape) {
  // If the split dimension is not statically defined, uneven sharding is not
  // supported.
  if (partial_tensor_shape.dim_size(split_dim) <= 0) {
    return 0;
  }
  int64_t per_split_size = tensorflow::MathUtil::CeilOfRatio<int64_t>(
      partial_tensor_shape.dim_size(split_dim), num_splits);
  int64_t total_padding =
      per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
  return total_padding;
}

// Creates a set of split nodes that shard a tiled input node in the graph.
xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
    const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, int replica_id,
    int orig_src_output, Node* orig_src, Node* control_predecessor,
    Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }
  // Maps each input dimension to the number of splits with which that
  // dimension is sharded.
  std::map<int, int> split_dimension_map;
  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
      sharding, &split_dimension_map));
  TF_RET_CHECK(!split_dimension_map.empty())
      << "Unnecessary sharding attribute found.";

  // For a v1 while loop, nodes inside the loop body must either
  // 1) Have data edges from the while loop input node.
  // or
  // 2) Have a direct control dependency from the while loop input control
  //    node.
  //
  // As such, if we are adding a Split node inside a while loop body, we must
  // manually add a control dependency from a node inside the while loop
  // (i.e. `control_predecessor`) to constant nodes without data in-edges, to
  // make sure that the added split nodes have the correct frame name.
  // Otherwise, the placer will complain when `BuildControlFlow()` is invoked.

  auto sharding_it = split_dimension_map.begin();
  std::queue<Node*> split_nodes_for_dimension;
  absl::flat_hash_map<Node*, int> node_to_split_dim;
  int split_dimension = sharding_it->first;
  int num_split = sharding_it->second;

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that input data is sharded in row major order.
  // Split nodes at the ith depth from the original input node represent nodes
  // that split the input data at the ith dimension.
  TF_ASSIGN_OR_RETURN(
      Node * root_split_node,
      CreateSplitNode(
          num_split, split_dimension, partial_tensor_shape.dims(),
          GetPadding(split_dimension, num_split, partial_tensor_shape),
          orig_src_output, dtype,
          absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
                       split_dimension),
          control_predecessor, orig_src, graph));
  sharding_it++;

  split_nodes_for_dimension.emplace(root_split_node);
  node_to_split_dim[root_split_node] = split_dimension;

  while (sharding_it != split_dimension_map.end()) {
    split_dimension = sharding_it->first;
    num_split = sharding_it->second;
    int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
    for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
      Node* input_split_node = split_nodes_for_dimension.front();
      split_nodes_for_dimension.pop();
      for (int src_output_index = 0;
           src_output_index < input_split_node->num_outputs();
           ++src_output_index) {
        TF_ASSIGN_OR_RETURN(
            Node * split_node,
            CreateSplitNode(
                num_split, split_dimension, partial_tensor_shape.dims(),
                GetPadding(split_dimension, num_split, partial_tensor_shape),
                src_output_index, dtype,
                absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
                             split_dimension),
                control_predecessor, input_split_node, graph));
        split_nodes_for_dimension.emplace(split_node);
        node_to_split_dim[split_node] = split_dimension;
      }
    }
    sharding_it++;
  }

  // `split_nodes_for_dimension` now contains the final split nodes from which
  // sharded data will be fed into TPUExecute nodes -- sorted in row major
  // order.
  std::vector<NodeOut> sharded_inputs_list(
      sharding.tile_assignment_devices_size());
  int64_t next_core_tile_index = 0;
  while (!split_nodes_for_dimension.empty()) {
    Node* split_node = split_nodes_for_dimension.front();
    split_nodes_for_dimension.pop();
    int num_splits;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(split_node->def(), "num_split", &num_splits));
    for (int out_index = 0; out_index < num_splits; ++out_index) {
      int64_t repeat_count =
          sharding.replicate_on_last_tile_dim()
              ? *sharding.tile_assignment_dimensions().rbegin()
              : 1;
      for (int64_t i = 0; i < repeat_count; ++i) {
        int64_t next_core =
            sharding.tile_assignment_devices(next_core_tile_index++);
        sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
      }
    }
  }

  ShardedInputInfo sharded_input_info{root_split_node,
                                      std::move(sharded_inputs_list)};
  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
  return sharded_input_info;
}

// Creates an XlaSplitND (or ReadVariableXlaSplitND) node to shard an input,
// and adds that new node to the graph.
StatusOr<Node*> CreateXlaSplitOp(absl::string_view node_name,
                                 const bool is_resource, const NodeOut& input,
                                 const PartialTensorShape& partial_tensor_shape,
                                 const std::vector<Node*>& control_inputs,
                                 const std::vector<Node*>& control_outputs,
                                 const DataType dtype, const int num_shards,
                                 const xla::OpSharding& sharding,
                                 Graph* graph) {
  const std::string& input_assigned_device = input.node->assigned_device_name();
  NodeDef xla_split_def;
  xla_split_def.set_name(graph->NewName(node_name));
  xla_split_def.set_op(is_resource ? "ReadVariableXlaSplitND" : "XlaSplitND");
  xla_split_def.set_device(input_assigned_device);
  AddNodeAttr("T", dtype, &xla_split_def);
  AddNodeAttr("N", num_shards, &xla_split_def);
  const std::vector<int64_t> num_splits(
      sharding.tile_assignment_dimensions().begin(),
      sharding.replicate_on_last_tile_dim()
          ? std::prev(sharding.tile_assignment_dimensions().end())
          : sharding.tile_assignment_dimensions().end());
  AddNodeAttr("num_splits", num_splits, &xla_split_def);
  const int rank = sharding.replicate_on_last_tile_dim()
                       ? sharding.tile_assignment_dimensions_size() - 1
                       : sharding.tile_assignment_dimensions_size();
  std::vector<int32> paddings;
  paddings.reserve(rank);
  for (int dim = 0; dim < rank; ++dim) {
    paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
                                  partial_tensor_shape));
  }
  AddNodeAttr("paddings", paddings, &xla_split_def);

  if (!is_resource) {
    AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &xla_split_def);
    AddNodeAttr(kColocationAttrName,
                std::vector<string>{
                    absl::StrCat(kColocationGroupPrefix, input.node->name())},
                &xla_split_def);
  }

  TF_ASSIGN_OR_RETURN(Node * xla_split, graph->AddNode(xla_split_def));
  if (is_resource) {
    xla_split->set_requested_device(input.node->requested_device());
  }
  xla_split->set_assigned_device_name(input_assigned_device);
  graph->AddEdge(input.node, input.index, xla_split, 0);
  for (Node* control_input : control_inputs) {
    graph->AddControlEdge(control_input, xla_split);
  }
  for (Node* control_output : control_outputs) {
    graph->AddControlEdge(xla_split, control_output);
  }
  return xla_split;
}

// Creates a list of sharded inputs, one entry per logical core, for an input
// with tiled sharding.
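// For example (illustrative values only), with tile_assignment_dimensions =
// [2, 2], replicate_on_last_tile_dim = true and tile_assignment_devices =
// [0, 1, 2, 3], the split produces 2 shards; shard 0 is fed to cores 0 and 1,
// and shard 1 is fed to cores 2 and 3.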
xla::StatusOr<std::vector<NodeOut>> ShardInputWithXlaSplitOp(
    absl::string_view node_name, const bool is_resource, const NodeOut& input,
    const PartialTensorShape& partial_tensor_shape,
    const std::vector<Node*>& control_inputs,
    const std::vector<Node*>& control_outputs, const DataType dtype,
    const xla::OpSharding& sharding, Graph* graph) {
  const int repeat = sharding.replicate_on_last_tile_dim()
                         ? *sharding.tile_assignment_dimensions().rbegin()
                         : 1;
  const int num_shards = sharding.tile_assignment_devices_size() / repeat;

  TF_ASSIGN_OR_RETURN(
      Node * xla_split,
      CreateXlaSplitOp(node_name, is_resource, input, partial_tensor_shape,
                       control_inputs, control_outputs, dtype, num_shards,
                       sharding, graph));

  std::vector<NodeOut> sharded_inputs_list(
      sharding.tile_assignment_devices_size());

  for (int i = 0; i < num_shards; ++i) {
    for (int j = 0; j < repeat; ++j) {
      const int index = i * repeat + j;
      const int core = sharding.tile_assignment_devices(index);
      sharded_inputs_list[core] = {xla_split, i};
    }
  }

  return sharded_inputs_list;
}

// Creates an XlaSplitND op to shard a per-replica arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
    const xla::OpSharding& sharding, const int replica_id,
    const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(orig_src->name(), "/replica_", replica_id, "_split"),
          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
          dtype, sharding, graph));

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
  return sharded_input_info;
}

// Creates an XlaSplitND op to shard a distributed arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForDistributedArg(
    const xla::OpSharding& sharding, const int num_replicas,
    const int replica_id, const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(orig_src->name(), "/distributed_split"),
          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
          dtype, sharding, graph));

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  for (int replica = 0; replica < num_replicas; ++replica) {
    (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
        sharded_input_info;
  }
  return sharded_input_info;
}

// Creates a ReadVariableXlaSplitND op to shard a variable arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForVariableArg(
    const xla::OpSharding& sharding, const int num_replicas,
    const int replica_id, const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::vector<Node*>* to_be_removed_nodes,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  DCHECK_EQ(orig_src->type_string(), "ReadVariableOp");
  std::vector<Node*> control_outputs;
  std::vector<const Edge*> edges_to_remove;
  for (const Edge* edge : orig_src->out_edges()) {
    if (edge->IsControlEdge()) {
      control_outputs.push_back(edge->dst());
    }
    edges_to_remove.push_back(edge);
  }

  to_be_removed_nodes->push_back(orig_src);

  const Edge* resource = nullptr;
  TF_RETURN_IF_ERROR(orig_src->input_edge(0, &resource));

  std::vector<Node*> control_inputs;
  for (const Edge* edge : orig_src->in_edges()) {
    if (edge->IsControlEdge()) {
      control_inputs.push_back(edge->src());
    }
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(resource->src()->name(), "/read_variable_split"),
          /*is_resource=*/true,
          /*input=*/{resource->src(), resource->src_output()},
          partial_tensor_shape, control_inputs, control_outputs, dtype,
          sharding, graph));

  for (const Edge* edge : edges_to_remove) {
    graph->RemoveControlEdge(edge);
  }

  DCHECK(orig_src->out_edges().empty());

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  for (int replica = 0; replica < num_replicas; ++replica) {
    ShardedInputIndex idx{replica, orig_arg_num};
    // Refrain from overwriting if dummy inputs were already placed instead.
    arg_index_to_sharded_input_map->insert({idx, sharded_input_info});
  }
  return sharded_input_info;
}

// Creates a concat node to be used for aggregating sharded retvals across
// logical cores.
xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
                                      absl::string_view name_prefix,
                                      const std::vector<NodeOut>& inputs,
                                      Graph* graph, absl::string_view device) {
  // Add a Concat dim node.
  NodeDef concat_dim_def;
  concat_dim_def.set_name(
      graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
  concat_dim_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
  concat_dim_def.set_device(std::string(device));
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.add_int_val(dim);
  TensorShape shape({});
  shape.AsProto(tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", tensor_proto, &concat_dim_def);
  TF_ASSIGN_OR_RETURN(Node * concat_dim_node, graph->AddNode(concat_dim_def));

  // Add a Concat node.
  NodeDef concat_def;
  concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
  concat_def.set_op("Concat");
  AddNodeAttr("N", num_splits, &concat_def);
  AddNodeAttr("T", dtype, &concat_def);
  concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
  concat_def.set_device(std::string(device));
  for (const auto& i : inputs) {
    concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
  }
  TF_ASSIGN_OR_RETURN(Node * concat_node, graph->AddNode(concat_def));

  graph->AddEdge(concat_dim_node, 0, concat_node, 0);

  // 0th input to concat node is a concat dim node. So we start from 1st input
  // and add all input edges.
  int dst_input = 1;
  for (const auto& i : inputs) {
    graph->AddEdge(i.node, i.index, concat_node, dst_input);
    ++dst_input;
  }
  return concat_node;
}

// Adds a slice node after the concat node to the graph, to support unevenly
// sharded tiled outputs.
xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
                                     const PartialTensorShape& shape,
                                     Node* concat_node,
                                     const int concat_out_index, Graph* graph,
                                     absl::string_view device) {
  // Add begin node for the slice.
  NodeDef begin_def;
  begin_def.set_name(
      graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
  begin_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &begin_def);
  begin_def.set_device(std::string(device));
  TensorProto begin_tensor_proto;
  begin_tensor_proto.set_dtype(DT_INT32);
  for (int i = 0; i < shape.dims(); ++i) {
    begin_tensor_proto.add_int_val(0);
  }
  TensorShape begin_shape({shape.dims()});
  begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", begin_tensor_proto, &begin_def);
  TF_ASSIGN_OR_RETURN(Node * begin_node, graph->AddNode(begin_def));

  // Add size node.
  NodeDef size_def;
  size_def.set_name(
      graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
  size_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &size_def);
  size_def.set_device(std::string(device));
  TensorProto sizes_tensor_proto;
  sizes_tensor_proto.set_dtype(DT_INT32);
  for (int i = 0; i < shape.dims(); ++i) {
    sizes_tensor_proto.add_int_val(shape.dim_size(i));
  }
  TensorShape sizes_shape({shape.dims()});
  sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", sizes_tensor_proto, &size_def);
  TF_ASSIGN_OR_RETURN(Node * size_node, graph->AddNode(size_def));

  // Add Slice node.
  NodeDef slice_def;
  slice_def.set_name(
      graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
  slice_def.set_op("Slice");
  slice_def.set_device(std::string(device));
  AddNodeAttr("T", dtype, &slice_def);
  AddNodeAttr("Index", DT_INT32, &slice_def);
  slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
  slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
  slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
  TF_ASSIGN_OR_RETURN(Node * slice_node, graph->AddNode(slice_def));
  // Add edges for slice node.
  graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
  graph->AddEdge(begin_node, 0, slice_node, 1);
  graph->AddEdge(size_node, 0, slice_node, 2);
  return slice_node;
}

// Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute
// nodes into a single output. Sharded outputs are concatenated along row major
// order. That is, tiled output along 0th dimension will be concatenated last.
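// For example (illustrative values only), for an output tiled 2x2 into
// row-major ordered shards [t00, t01, t10, t11], the first pass concatenates
// t00 with t01 and t10 with t11 along dimension 1, and the second pass
// concatenates the two results along dimension 0.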
xla::StatusOr<Node*> CreateConcatNodesForRetval(
    const xla::OpSharding& sharding, DataType dtype,
    const PartialTensorShape& inferred_shape, int replica_id,
    const std::vector<NodeOut>& orig_inputs, Graph* graph,
    absl::string_view device) {
  std::map<int, int> split_dimension_map;
  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
      sharding, &split_dimension_map));
  std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
  bool has_paddings = false;

  for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
       it++) {
    auto dim = it->first;
    auto num_splits = it->second;

    int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits;
    int input_index_to_concat_node = 0;

    std::vector<NodeOut> new_concat_nodes;
    for (int i = 0; i < num_concat_nodes; ++i) {
      auto concat_input_it =
          inputs_to_sharded_retval.begin() + input_index_to_concat_node;
      std::vector<NodeOut> inputs(concat_input_it,
                                  concat_input_it + num_splits);
      input_index_to_concat_node += num_splits;

      TF_ASSIGN_OR_RETURN(
          Node * concat_node,
          CreateConcatNode(
              dim, num_splits, dtype,
              absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
              inputs, graph, device));
      int64_t paddings = GetPadding(dim, num_splits, inferred_shape);
      has_paddings |= paddings > 0;
      new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
    }
    inputs_to_sharded_retval = new_concat_nodes;
  }

  TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
  if (has_paddings) {
    TF_ASSIGN_OR_RETURN(Node * slice_node,
                        CreateSliceNode(dtype, inferred_shape,
                                        inputs_to_sharded_retval.at(0).node,
                                        /*concat_out_index*/ 0, graph, device));
    return slice_node;
  }
  return inputs_to_sharded_retval.at(0).node;
}
1282
CreateXlaConcatNode(const xla::OpSharding & sharding,const int replica_id,DataType dtype,const PartialTensorShape & partial_tensor_shape,const std::vector<NodeOut> & orig_inputs,absl::string_view device,Graph * graph)1283 xla::StatusOr<Node*> CreateXlaConcatNode(
1284 const xla::OpSharding& sharding, const int replica_id, DataType dtype,
1285 const PartialTensorShape& partial_tensor_shape,
1286 const std::vector<NodeOut>& orig_inputs, absl::string_view device,
1287 Graph* graph) {
1288 NodeDef xla_concat_def;
1289 xla_concat_def.set_name(graph->NewName(
1290 absl::StrCat("sharded_output/replica_", replica_id, "_concat")));
1291 xla_concat_def.set_op("XlaConcatND");
1292 xla_concat_def.set_device(std::string(device));
1293 AddNodeAttr("T", dtype, &xla_concat_def);
1294 AddNodeAttr("N", static_cast<int64_t>(orig_inputs.size()), &xla_concat_def);
1295 const std::vector<int64_t> num_concats(
1296 sharding.tile_assignment_dimensions().begin(),
1297 sharding.replicate_on_last_tile_dim()
1298 ? std::prev(sharding.tile_assignment_dimensions().end())
1299 : sharding.tile_assignment_dimensions().end());
1300 AddNodeAttr("num_concats", num_concats, &xla_concat_def);
1301 const int rank = sharding.replicate_on_last_tile_dim()
1302 ? sharding.tile_assignment_dimensions_size() - 1
1303 : sharding.tile_assignment_dimensions_size();
1304 std::vector<int32> paddings;
1305 paddings.reserve(rank);
1306 for (int dim = 0; dim < rank; ++dim) {
1307 paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
1308 partial_tensor_shape));
1309 }
1310 AddNodeAttr("paddings", paddings, &xla_concat_def);
1311
1312 TF_ASSIGN_OR_RETURN(Node * xla_concat, graph->AddNode(xla_concat_def));
1313 for (int i = 0, e = orig_inputs.size(); i < e; ++i) {
1314 const NodeOut& input = orig_inputs[i];
1315 graph->AddEdge(input.node, input.index, xla_concat, i);
1316 }
1317 return xla_concat;
1318 }
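// Illustrative example: for an OpSharding with tile_assignment_dimensions =
// [2, 2, 2] and replicate_on_last_tile_dim = true, the XlaConcatND node gets
// num_concats = [2, 2] (the trailing replication dimension is dropped),
// rank = 2, and one "paddings" entry per dimension computed by GetPadding()
// against the partial tensor shape; N is simply orig_inputs.size().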
1319
1320 // Set the padding ops to the same devices as the original inputs. If the
1321 // original inputs are on TPUs, the padding ops will be placed on TPUs and XLA
1322 // on-demand mode will be triggered, so we don't need to copy the data back to
1323 // the host to do the padding.
1324 Status SetPaddingNodesDevices(Graph* graph) {
1325 for (Node* n : graph->op_nodes()) {
1326 bool tpu_padding_attr;
1327 if (n->type_string() == "Pad" &&
1328 GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
1329 .ok()) {
1330 Node* unpadded_input;
1331 TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
1332
1333 const string& requested_device = unpadded_input->requested_device();
1334 const string& assigned_device = unpadded_input->assigned_device_name();
1335 if (!requested_device.empty() || !assigned_device.empty()) {
1336         // The output nodes of the original unpadded inputs include the padded
1337         // inputs and the real shapes of the inputs; we assign those to the same
1338         // device as the original inputs.
1339 for (Node* out : unpadded_input->out_nodes()) {
1340 if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
1341 &tpu_padding_attr)
1342 .ok()) {
1343 out->set_requested_device(requested_device);
1344 out->set_assigned_device_name(assigned_device);
1345 }
1346 }
1347         // There might be a tf.shape node added before TPUCompileOp, so we need
1348         // to set its device as well.
1349         for (Node* out : n->out_nodes()) {
1350           if (out->type_string() == "Shape") {
1351 out->set_requested_device(requested_device);
1352 out->set_assigned_device_name(assigned_device);
1353 }
1354 }
1355 }
1356 }
1357 }
1358 return OkStatus();
1359 }
1360
1361 const string& AssignedOrRequestedDevice(const Node* node) {
1362 if (!node->assigned_device_name().empty()) {
1363 return node->assigned_device_name();
1364 }
1365 return node->requested_device();
1366 }
1367
1368 bool IsTpuDevice(StringPiece device_string) {
1369 DeviceNameUtils::ParsedName device;
1370 return DeviceNameUtils::ParseFullName(device_string, &device) &&
1371 device.type == DEVICE_TPU_NODE;
1372 }
1373
1374 bool CanAcceptTPUDevicePropagation(const Node& node) {
1375   // A set of device ops can be placed on TPU. There is no strict rule of
1376   // thumb to decide which ops should be in the list, but empirically they are
1377   // mostly dummy ops like Identity-like ops or control-flow-related ops.
1378   // However, one can also add other ops like Pad to allow data to stay on TPU.
1379 static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
1380 {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
1381 "NextIteration", "Shape", "_Retval"});
1382 return place_on_tpu_ops->contains(node.type_string());
1383 }
1384
1385 xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
1386 xla::OpMetadata metadata;
1387 metadata.set_op_type(node.type_string());
1388 metadata.set_op_name(node.name());
1389 return metadata;
1390 }
1391
1392 // Helper struct holding node (nullable) and associated sharding.
1393 struct NodeAndSharding {
1394   explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding)
1395 : node(node), sharding(sharding) {}
1396
1397 const Node* node;
1398 xla::OpSharding sharding;
1399 };
1400
1401 // Validate sharding configuration derived from XlaSharding attribute.
1402 // Infer the core id from the OpSharding, if necessary.
1403 Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding,
1404 const int num_cores_per_replica,
1405 int64_t* inferred_core_id,
1406 absl::optional<NodeAndSharding>* result) {
1407 if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) {
1408 int64_t core_annotation =
1409 node_and_sharding.sharding.tile_assignment_devices(0);
1410 TF_RETURN_IF_ERROR(
1411 ValidateCoreNumber(core_annotation, num_cores_per_replica));
1412 if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
1413 *inferred_core_id = core_annotation;
1414 result->emplace(node_and_sharding);
1415 }
1416 } else {
1417 if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) {
1418 for (int64_t core :
1419 node_and_sharding.sharding.tile_assignment_devices()) {
1420 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
1421 }
1422 }
1423
1424 if (!result->has_value()) {
1425 *result = node_and_sharding;
1426 } else {
1427 std::string result_value_serialized;
1428 xla::OpSharding result_value = result->value().sharding;
1429 result_value.clear_metadata();
1430 SerializeToStringDeterministic(result_value, &result_value_serialized);
1431
1432 std::string sharding_serialized;
1433 xla::OpSharding sharding = node_and_sharding.sharding;
1434 sharding.clear_metadata();
1435 SerializeToStringDeterministic(sharding, &sharding_serialized);
1436
1437 // TODO(lyandy): Choose the more granular sharding instead of always
1438 // assigning to core 0 (maximal).
1439 if (result_value_serialized != sharding_serialized) {
1440 // We see different shardings, assign to core 0.
1441 auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
1442 DCHECK_NE(node_and_sharding.node, nullptr);
1443 *core_zero_sharding.add_metadata() =
1444 CreateOpMetadataFromNode(*node_and_sharding.node);
1445 result->emplace(
1446 NodeAndSharding(node_and_sharding.node, core_zero_sharding));
1447 }
1448 }
1449 }
1450 return OkStatus();
1451 }
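// Illustrative example: two MAXIMAL shardings annotating cores 3 and 1 resolve
// to core 1 (the smallest annotated core wins). Two differing non-MAXIMAL
// shardings for the same value fall back to assigning it to core 0, per the
// TODO in the function above.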
1452
1453 // As an XlaSharding node may be followed by a Cast op or an Identity op,
1454 // recursively walk the graph and aggregate nodes connected to
1455 // |input_node| or to Cast/Identity ops following |input_node|.
1456 void FindNodesMaybeContainingShardingInfo(const Node& input_node,
1457 std::vector<const Node*>* nodes) {
1458 if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
1459 for (const Node* connected_node : input_node.out_nodes())
1460 FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
1461 }
1462 nodes->emplace_back(&input_node);
1463 }
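// Illustrative example: if |input_node| is an Identity whose output feeds a
// Cast that in turn feeds an XlaSharding op, the traversal collects the
// XlaSharding node, the Cast, and the Identity itself, so callers can probe
// each of them for an attached sharding attribute.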
1464
1465 // Parse sharding configuration from |node| or its adjacent nodes.
1466 // XlaSharding configuration may be derived from
1467 // a) Connected Identity op node.
1468 // b) Connected Cast op node.
1469 xla::StatusOr<absl::optional<NodeAndSharding>>
1470 ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
1471 const Node& node) {
1472   // If |node| has a `device` attribute or is an XlaSharding op,
1473   // return the parsed OpSharding.
1474 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1475 ParseShardingFromDevice(node, num_cores_per_replica,
1476 /*add_metadata=*/true));
1477 if (sharding.has_value()) {
1478 return absl::optional<NodeAndSharding>(NodeAndSharding(&node, *sharding));
1479 }
1480
1481   // An XlaSharding op may be followed by an Identity op, or by an Identity op
1482   // and a Cast op.
1483 std::vector<const Node*> potential_nodes_with_input_sharding;
1484 FindNodesMaybeContainingShardingInfo(node,
1485 &potential_nodes_with_input_sharding);
1486 for (const Node* maybe_node_with_sharding_info :
1487 potential_nodes_with_input_sharding) {
1488 if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
1489
1490 TF_ASSIGN_OR_RETURN(
1491 absl::optional<xla::OpSharding> sharding_config,
1492 ParseShardingFromDevice(*maybe_node_with_sharding_info,
1493 num_cores_per_replica, /*add_metadata=*/true));
1494 if (sharding_config.has_value()) {
1495 return absl::optional<NodeAndSharding>(
1496 NodeAndSharding(maybe_node_with_sharding_info, *sharding_config));
1497 }
1498 }
1499 return absl::optional<NodeAndSharding>();
1500 }
1501
1502 // Walk the graph from an argument node to find OpSharding configuration
1503 // from its neighbor nodes. Sharding configuration may be inferred from
1504 // 1) Parsing XlaSharding attribute from neighboring node.
1505 // 2) If argument node is a resource, then by parsing adjacent nodes
1506 // of the connected ReadVariable op.
1507 Status ParseAndValidateShardingFromNeighbors(
1508 const int num_cores_per_replica, const std::string& arg_node_name,
1509 const Node& neighbor_node, int64_t* inferred_core_id, bool* is_fast_mem,
1510 absl::optional<NodeAndSharding>* result) {
1511 if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1512 *is_fast_mem = true;
1513 VLOG(2) << "place " << neighbor_node.name() << " on fast memory because "
1514 << arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute";
1515 }
1516
1517 // XlaSharding information may be encoded on node directly connected to the
1518 // argument node.
1519 TF_ASSIGN_OR_RETURN(
1520 absl::optional<NodeAndSharding> node_and_sharding,
1521 ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
1522 if (node_and_sharding.has_value()) {
1523 TF_RETURN_IF_ERROR(ParseAndValidateSharding(
1524 *node_and_sharding, num_cores_per_replica, inferred_core_id, result));
1525 return OkStatus();
1526 }
1527
1528   // When we use a variable in a TPU computation, we always have an
1529   // XlaSharding op followed by a ReadVariableOp. As such, parse
1530   // the users of the ReadVariableOp for potential sharding configuration.
1531 if (neighbor_node.type_string() == "ReadVariableOp") {
1532 for (const Edge* e : neighbor_node.out_edges()) {
1533 if (e->IsControlEdge()) continue;
1534
1535 if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1536 *is_fast_mem = true;
1537 VLOG(2) << "place " << arg_node_name << " on fast memory because "
1538                 << e->dst()->name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
1539 }
1540
1541 TF_ASSIGN_OR_RETURN(
1542 absl::optional<NodeAndSharding> node_and_sharding,
1543 ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
1544 if (node_and_sharding.has_value()) {
1545 TF_RETURN_IF_ERROR(ParseAndValidateSharding(*node_and_sharding,
1546 num_cores_per_replica,
1547 inferred_core_id, result));
1548 return OkStatus();
1549 }
1550 }
1551 }
1552 return OkStatus();
1553 }
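// Illustrative example: for a resource argument whose ReadVariableOp output is
// consumed by an XlaSharding op, the sharding is picked up from that consumer;
// if the neighbor or any such consumer carries TPU_FAST_MEM_ATTR, the argument
// is additionally marked for fast memory.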
1554
1555 } // namespace
1556
1557 // Inputs:
1558 // replication_spec_string: the device to which the TPUReplicate node was
1559 // assigned.
1560 // device_set: the set of TF devices.
1561 // Outputs:
1562 // tpu_compilation_device: the name of the TPU compilation device.
1563 // num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
1564 // have the same number of TPU devices.
1565 // tpu_devices: the TPU devices, indexed by [task][device].
1566 static Status GetTPUDeviceNames(
1567 const string& replication_spec_string, const DeviceSet& device_set,
1568 string* tpu_compilation_device, int* num_tpus_per_task,
1569 std::vector<std::vector<Device*>>* tpu_devices) {
1570 // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
1571 // the tpu_system device, which we replace by the cpu device. We do this
1572 // replacement because we want to place the TPUCompileOp (and the compile
1573 // assert op) explicitly on cpu devices on the same job as the tpu_system
1574 // device.
1575 DeviceNameUtils::ParsedName replication_spec;
1576 Device* replication_device;
1577 TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
1578 replication_spec_string, device_set, &replication_spec,
1579 &replication_device));
1580 *tpu_compilation_device =
1581 str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
1582 DEVICE_CPU, /*replace_all=*/true);
1583
1584 // Finds the set of TPU devices attached to the tasks in the job.
1585 TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
1586 replication_spec, device_set, num_tpus_per_task, tpu_devices));
1587
1588 return OkStatus();
1589 }
1590
1591 // Parses the topology attribute of TPUReplicate, and populates *topology with
1592 // a physical mesh coordinate to (task, device) mapping.
1593 static Status ParseTopologyAttr(const string& topology_attr,
1594 const tpu::TpuTopologyExternal& tpu_topology,
1595 int num_tasks, int num_tpus_per_task,
1596 xla::Array4D<std::pair<int, int>>* topology) {
1597 static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1598 tpu::TopologyProto proto;
1599 proto.ParseFromString(topology_attr);
1600 if (proto.mesh_shape_size() != kTPUTopologyRank) {
1601 return errors::InvalidArgument("TPU topology must be rank ",
1602 kTPUTopologyRank);
1603 }
1604 if (proto.num_tasks() != num_tasks) {
1605 return errors::InvalidArgument("Mismatched number of TPU tasks (",
1606 proto.num_tasks(), " != ", num_tasks, ")");
1607 }
1608 if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
1609 return errors::InvalidArgument("Mismatched number of TPUs per task (",
1610 proto.num_tpu_devices_per_task(),
1611 " != ", num_tpus_per_task, ").");
1612 }
1613 if (proto.device_coordinates_size() !=
1614 num_tasks * num_tpus_per_task * kTPUTopologyRank) {
1615 return errors::InvalidArgument(
1616 "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
1617 kTPUTopologyRank, "; got ", proto.device_coordinates_size());
1618 }
1619
1620 int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1621 *topology = xla::Array4D<std::pair<int, int>>(
1622 tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1623 tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
1624 int pos = 0;
1625 for (int task = 0; task < num_tasks; ++task) {
1626 for (int device = 0; device < num_tpus_per_task; ++device) {
1627 int32_t x = proto.device_coordinates(pos++);
1628 int32_t y = proto.device_coordinates(pos++);
1629 int32_t z = proto.device_coordinates(pos++);
1630 int32_t core = proto.device_coordinates(pos++);
1631
1632 if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1633 core >= devices_per_chip) {
1634 return errors::InvalidArgument(
1635 "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1636 ") are not valid for the current TPU topology");
1637 }
1638 if ((*topology)(x, y, z, core).first != -1) {
1639 return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1640 ",", z, ",", core, ") in TPU topology");
1641 }
1642 (*topology)(x, y, z, core) = {task, device};
1643 }
1644 }
1645 return OkStatus();
1646 }
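// Illustrative example: with two tasks and two TPUs per task on a mesh that
// contains chip (0, 1, 0), device_coordinates holds 2 * 2 * kTPUTopologyRank =
// 16 integers, i.e. four (x, y, z, core) tuples in task-major order. If the
// second tuple of task 1 is (0, 1, 0, 0), then (*topology)(0, 1, 0, 0) is set
// to {task 1, device 1}.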
1647
1648 // Parses the value of the device_assignment attribute to TPUReplicate.
1649 // Populates *device_assignment; *device_assignment must be a 2D array with
1650 // shape (num_replicas, num_cores_per_replica).
1651 static Status ParseDeviceAssignmentAttr(
1652 absl::Span<const int> device_assignment_attr,
1653 const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
1654 int num_cores_per_replica,
1655 xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
1656 static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1657
1658 const int64_t device_assignment_attr_size =
1659 num_replicas * num_cores_per_replica * kTPUTopologyRank;
1660 if (device_assignment_attr.size() != device_assignment_attr_size) {
1661 return errors::InvalidArgument(
1662 "Length of device_assignment attribute must be equal to num_replicas (",
1663 num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
1664 ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
1665 }
1666 for (int core : device_assignment_attr) {
1667 if (core < 0 || core >= kTPUMaxTopologySize) {
1668 return errors::InvalidArgument(
1669 "Invalid core number in device assignment: ", core);
1670 }
1671 }
1672
1673 *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
1674 num_replicas, num_cores_per_replica);
1675 int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1676 xla::Array4D<int> replica_assignment(
1677 tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1678 tpu_topology.chip_bounds().z, devices_per_chip, -1);
1679 int pos = 0;
1680 for (int replica = 0; replica < num_replicas; ++replica) {
1681 for (int logical_core = 0; logical_core < num_cores_per_replica;
1682 ++logical_core) {
1683 int32_t x = device_assignment_attr[pos++];
1684 int32_t y = device_assignment_attr[pos++];
1685 int32_t z = device_assignment_attr[pos++];
1686 int32_t core = device_assignment_attr[pos++];
1687
1688 if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1689 core >= devices_per_chip) {
1690 return errors::InvalidArgument(
1691             "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1692 ") are not valid for the current TPU topology");
1693 }
1694 tpu::TpuCoreLocationExternal core_location =
1695 tpu_topology.Core(kTensorCore, x, y, z, core);
1696
1697 if (replica_assignment(x, y, z, core) != -1) {
1698 return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1699 ",", z, ",", core,
1700 ") in TPU device assignment");
1701 }
1702 replica_assignment(x, y, z, core) = replica;
1703 (*device_assignment)(replica, logical_core) = core_location;
1704 }
1705 }
1706 return OkStatus();
1707 }
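// Illustrative example: with num_replicas = 2 and num_cores_per_replica = 1,
// device_assignment_attr is a flat list of 2 * 1 * 4 = 8 integers, one
// (x, y, z, core) tuple per (replica, logical core) pair in replica-major
// order. Each tuple is checked against the physical topology and mapped to a
// TpuCoreLocationExternal; reusing the same tuple twice is rejected as a
// duplicate.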
1708
1709 // Builds TensorFlow device assignments for the special case of a single core
1710 // computation that is replicated to every core in the mesh.
1711 // LINT.IfChange
1712 static Status BuildFullMeshDeviceAssignment(
1713 int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
1714 int num_tasks, int num_tpus_per_task,
1715 std::vector<std::vector<string>>* tf_device_assignment,
1716 std::vector<int>* devices_to_lock) {
1717 // Assign TensorFlow devices to replicas arbitrarily.
1718 for (int i = 0; i < num_replicas; ++i) {
1719 int task = i / num_tpus_per_task;
1720 int device = i % num_tpus_per_task;
1721 TF_RET_CHECK(task >= 0 && task < num_tasks);
1722 TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
1723
1724 // We don't actually know which TF device corresponds to which physical
1725 // device, but it doesn't matter—they're all identical.
1726 (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
1727 devices_to_lock->push_back(i);
1728 }
1729 return OkStatus();
1730 }
1731 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
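// Illustrative example: with 2 tasks and 4 TPUs per task, replica 5 is mapped
// to task 1, device 1 (5 / 4 = 1, 5 % 4 = 1) and index 5 is appended to
// devices_to_lock.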
1732
1733 // Builds TensorFlow device assignments for a replicated computation and
1734 // converts device_assignment into xla_device_assignment.
1735 static Status BuildGeneralDeviceAssignment(
1736 int num_replicas, int num_cores_per_replica,
1737 const std::vector<std::vector<Device*>>& tpu_devices,
1738 const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
1739 const xla::Array4D<std::pair<int, int>>& topology,
1740 std::vector<std::vector<string>>* tf_device_assignment,
1741 std::vector<int>* devices_to_lock,
1742 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1743 // Assign TensorFlow devices to each computation's replicas according to
1744 // device_assignment and 'topology'.
1745 *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
1746 num_replicas, num_cores_per_replica);
1747 for (int replica = 0; replica < num_replicas; ++replica) {
1748 for (int computation = 0; computation < num_cores_per_replica;
1749 ++computation) {
1750 const tpu::TpuCoreLocationExternal& core_location =
1751 device_assignment(replica, computation);
1752
1753 int task;
1754 int device;
1755 std::tie(task, device) =
1756 topology(core_location.chip_coordinates().x,
1757 core_location.chip_coordinates().y,
1758 core_location.chip_coordinates().z, core_location.index());
1759
1760 CHECK_LT(computation, num_cores_per_replica);
1761 (**xla_device_assignment)(replica, computation) = core_location.Id();
1762
1763 // The communication pattern between replicas will be determined later by
1764 // BuildAllReduceRing.
1765 TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
1766 TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
1767 (*tf_device_assignment)[replica].push_back(
1768 tpu_devices[task][device]->name());
1769 devices_to_lock->push_back((task * tpu_devices[task].size()) + device);
1770 }
1771 }
1772 return OkStatus();
1773 }
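// Illustrative example: if device_assignment(0, 1) names the core at chip
// (1, 0, 0) with index 0, and topology(1, 0, 0, 0) is {task 0, device 2}, then
// replica 0's second computation is placed on tpu_devices[0][2], its global
// core id is recorded in (*xla_device_assignment)(0, 1), and
// 0 * tpu_devices[0].size() + 2 is appended to devices_to_lock.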
1774
1775 /*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
1776 const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
1777 const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
1778 int num_cores_per_replica, const string& topology_attr,
1779 absl::Span<const int> device_assignment_attr,
1780 std::vector<std::vector<string>>* tf_device_assignment,
1781 std::vector<int>* devices_to_lock,
1782 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1783 const int num_tasks = tpu_devices.size();
1784 const int num_tpu_devices = num_tasks * num_tpus_per_task;
1785 VLOG(2) << "num_tasks=" << num_tasks
1786 << " num_tpus_per_task=" << num_tpus_per_task;
1787
1788 // Checks num_replicas is sane first to avoid integer overflow.
1789 if (num_replicas > num_tpu_devices) {
1790 return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1791 " but there are only ", num_tpu_devices,
1792 " cores in the TPU topology.");
1793 }
1794 if (num_replicas * num_cores_per_replica > num_tpu_devices) {
1795 return errors::InvalidArgument(
1796 "Requested num_replicas=", num_replicas, " with ",
1797 num_cores_per_replica, " cores per replica, but there are only ",
1798 num_tpu_devices, " cores in the TPU topology");
1799 }
1800
1801 tf_device_assignment->clear();
1802 tf_device_assignment->resize(num_replicas);
1803
1804 devices_to_lock->clear();
1805 devices_to_lock->reserve(num_replicas * num_cores_per_replica);
1806
1807 // Special case: we allow the user to omit the topology and device assignment
1808 // information in two cases:
1809 // * there is only one replica and one core per replica. In this case, we
1810 // don't need to know topology information because we don't communicate with
1811 // other cores.
1812 // * the number of replicas is equal to the number of cores in the slice. In
1813 // this case, all cores are running the same program so we don't need to
1814 // know which is which.
1815 if (topology_attr.empty()) {
1816 // LINT.IfChange
1817 if (num_replicas != 1 && num_replicas != num_tpu_devices) {
1818 return errors::InvalidArgument(
1819 "TPUReplicate asked to create ", num_replicas,
1820 " replicas, but the number of cores in the TPU topology is ",
1821 num_tpu_devices,
1822 " and no TPU device assignment was supplied. "
1823 "A TPU device assignment is required if the number of replicas is "
1824 "not 1 or the number of cores in the topology (",
1825 num_tpu_devices, ")");
1826 }
1827
1828 if (num_cores_per_replica != 1) {
1829 return errors::InvalidArgument(
1830 "A TPU topology must be provided if num_cores_per_replica != 1");
1831 }
1832
1833 if (!device_assignment_attr.empty()) {
1834 return errors::InvalidArgument(
1835 "A TPU topology must be provided if device_assignment_attr is "
1836 "non-empty");
1837 }
1838
1839 // If there is only one replica, assign the Tensorflow computation to task 0
1840 // device 0, and leave the XLA device assignment empty. We don't know which
1841 // core this is in the TPU topology, but it doesn't matter—we don't need to
1842 // communicate with any other cores.
1843 if (num_replicas == 1) {
1844 (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
1845 devices_to_lock->push_back(0);
1846 return OkStatus();
1847 }
1848
1849 // Otherwise, num_replicas is equal to the number of cores, and we build a
1850 // device assignment that covers the entire mesh. We do not need to know
1851 // the topology to do so because all cores are identical.
1852 return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
1853 num_tpus_per_task,
1854 tf_device_assignment, devices_to_lock);
1855 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1856 }
1857
1858 // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
1859 xla::Array4D<std::pair<int, int>> topology;
1860 TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
1861 num_tpus_per_task, &topology));
1862
1863 // Array that maps logical (replica, core) pairs to physical mesh coordinates.
1864 xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
1865 TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
1866 device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
1867 &device_assignment));
1868
1869 return BuildGeneralDeviceAssignment(
1870 num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
1871 topology, tf_device_assignment, devices_to_lock, xla_device_assignment);
1872 }
1873
1874 Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
1875 const NameAttrList& function, FunctionLibraryRuntime* flr,
1876 Graph* computation, DataTypeVector* arg_types,
1877 DataTypeVector* retval_types) {
1878 FunctionLibraryRuntime::Handle handle;
1879
1880 TF_RETURN_IF_ERROR(
1881 flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
1882
1883 const FunctionBody* fbody = flr->GetFunctionBody(handle);
1884
1885 CopyGraph(*fbody->graph, computation);
1886 *arg_types = fbody->arg_types;
1887 *retval_types = fbody->ret_types;
1888 return OkStatus();
1889 }
1890
1891 // Grab the InferredShape corresponding to an edge input.
1892 static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
1893 const InferredShape** info) {
1894 auto it = shape_info.find(edge.src()->name());
1895 if (it == shape_info.end()) {
1896 return errors::InvalidArgument(
1897 "Input to replicated TPU computation is missing InferredShape: ",
1898 edge.src()->name());
1899 }
1900 TF_RET_CHECK(it->second.size() > edge.src_output());
1901 *info = &it->second[edge.src_output()];
1902 return OkStatus();
1903 }
1904
1905 Status DistributedTPURewritePass::GetArgAndRetvalShapes(
1906 const GraphShapeInfo& shape_info, const Node& node,
1907 const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
1908 std::vector<InferredShape>* retval_shapes) {
1909 std::vector<const Edge*> input_edges;
1910 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
1911
1912   // If any replica's arg shape is unknown, we will mark the computation's arg
1913   // shape as being unknown. If the shapes differ, the TPUExecute op will raise
1914   // a runtime error.
1915 std::vector<bool> any_replica_shape_unknown(
1916 params_info.NumInputsToEachReplica());
1917 arg_shapes->clear();
1918 arg_shapes->resize(params_info.NumInputsToEachReplica());
1919 TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
1920 // Determines the shapes of the per-replica arguments and checks that all
1921 // replicas have identical shapes.
1922 int64_t edge_pos = 0;
1923 auto check_shape = [&](int input_index) -> Status {
1924 const InferredShape* info;
1925 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1926 ++edge_pos;
1927
1928 if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
1929 (info->handle_type != DT_INVALID &&
1930 !info->handle_shape.IsFullyDefined())) {
1931 any_replica_shape_unknown[input_index] = true;
1932 }
1933 xla::StatusOr<InferredShape> status =
1934 MergeInferredShapes((*arg_shapes)[input_index], *info);
1935 if (!status.ok()) {
1936 return errors::InvalidArgument(
1937 "Mismatched shapes for input ", input_index, ": ",
1938 (*arg_shapes)[input_index].shape.DebugString(), " vs. ",
1939 info->shape.DebugString());
1940 }
1941 (*arg_shapes)[input_index] = status.ValueOrDie();
1942 return OkStatus();
1943 };
1944
1945 for (int64_t i = 0; i < params_info.NumReplicas(); ++i) {
1946 for (int64_t j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
1947 TF_RETURN_IF_ERROR(check_shape(j));
1948 }
1949 }
1950
1951 for (int64_t i = 0; i < params_info.NumDistributedArgs(); ++i) {
1952 TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
1953 }
1954
1955 for (int64_t i = 0;
1956 i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
1957 ++i) {
1958 if (any_replica_shape_unknown[i]) {
1959 (*arg_shapes)[i].shape = PartialTensorShape();
1960 (*arg_shapes)[i].handle_shape = PartialTensorShape();
1961 }
1962 }
1963
1964 // Determines the shape of the broadcast arguments.
1965 for (int64_t i = 0; i < params_info.NumBroadcastArgs(); ++i) {
1966 TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
1967 const InferredShape* info;
1968 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1969 (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
1970 params_info.NumDistributedArgs()]
1971 .shape = info->shape;
1972 ++edge_pos;
1973 }
1974
1975 // Determines the handle shape and handle type of the resource variable
1976 // arguments.
1977 for (int64_t i = 0; i < params_info.NumVariables(); ++i) {
1978 TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
1979 const InferredShape* info;
1980 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1981 InferredShape& arg_shape =
1982 (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
1983 params_info.NumDistributedArgs() +
1984 params_info.NumBroadcastArgs()];
1985 arg_shape.shape = TensorShape(); // Variables are always scalars.
1986 arg_shape.handle_shape = info->handle_shape;
1987 arg_shape.handle_type = info->handle_type;
1988 TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
1989 << " input edge: " << input_edges[edge_pos]->DebugString();
1990 ++edge_pos;
1991 }
1992
1993 // Determines the shape of the guaranteed constants.
1994 // TODO(vinuraja): Can be removed because they are not required for any
1995 // calculations. Leaving them here for symmetry with other structures like
1996 // arg_types, arg_sharding, etc.
1997 for (int64_t i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
1998 TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
1999 const InferredShape* info;
2000 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2001 (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2002 params_info.NumDistributedArgs() +
2003 params_info.NumBroadcastArgs() + params_info.NumVariables()]
2004 .shape = info->shape;
2005 ++edge_pos;
2006 }
2007
2008 // Extract the return value shapes.
2009 auto it = shape_info.find(node.name());
2010 retval_shapes->clear();
2011 if (it != shape_info.end()) {
2012 TF_RET_CHECK(it->second.size() >= node.num_outputs());
2013 retval_shapes->resize(node.num_outputs());
2014 for (int i = 0; i < node.num_outputs(); ++i) {
2015 (*retval_shapes)[i].shape = it->second[i].shape;
2016 }
2017 } else if (node.num_outputs() > 0) {
2018 return errors::InvalidArgument(
2019 "Replicated TPU computation is missing InferredShape: ",
2020 FormatNodeForError(node));
2021 }
2022 return OkStatus();
2023 }
2024
2025 // Verifies that all nodes have legal sharding.
2026 static Status ValidateCoreNumbers(const Graph& graph,
2027 int num_cores_per_replica) {
2028 for (Node* n : graph.nodes()) {
2029 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
2030 ParseShardingFromDevice(*n, num_cores_per_replica,
2031 /*add_metadata=*/true));
2032 }
2033 return OkStatus();
2034 }
2035
2036 static Status InferXlaShardingFromNeighbors(
2037 const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
2038 CachedFunctionHandles* cached_function_handles,
2039 absl::optional<NodeAndSharding>* output_node_and_sharding,
2040 bool* is_fast_mem) {
2041 int64_t core = -1;
2042 absl::optional<NodeAndSharding> result;
2043   // We assume the variable has been allocated on fast memory if any consuming
2044   // op has the TPU_FAST_MEM_ATTR attribute. This is a protocol between the
2045   // runtime and the compiler.
2046 *is_fast_mem = false;
2047 for (const Edge* edge : n.out_edges()) {
2048 if (edge->IsControlEdge()) continue;
2049
2050 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2051 num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
2052 &result));
2053
2054 if (!flr) continue;
2055
2056     // The nodes deciding this arg's device assignment might be in a
2057     // FunctionDef. Instantiate the FunctionDefs associated with this node
2058     // and check the nodes using this arg.
2059 std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
2060 [&](const Edge* call_edge) {
2061 auto associated_functions = GetAssociatedFunctions(
2062 *call_edge->dst(), flr->GetFunctionLibraryDefinition());
2063 for (auto& associated_function : associated_functions) {
2064 FunctionLibraryRuntime::Handle handle;
2065 TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
2066 associated_function.func_name(),
2067 AttrSlice(&associated_function.attrs()), &handle));
2068 const FunctionBody* body = flr->GetFunctionBody(handle);
2069 Graph* g = body->graph;
2070
2071 for (Node* body_node : g->nodes()) {
2072 if (!body_node->IsArg()) continue;
2073
2074 int index;
2075 TF_RETURN_IF_ERROR(
2076 GetNodeAttr(body_node->attrs(), "index", &index));
2077 if (index != call_edge->dst_input()) continue;
2078
2079 for (const Edge* out_edge : body_node->out_edges()) {
2080 if (out_edge->IsControlEdge()) continue;
2081
2082 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2083 num_cores_per_replica, n.name(), *out_edge->dst(), &core,
2084 is_fast_mem, &result));
2085
2086 TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
2087 }
2088 }
2089 }
2090 return OkStatus();
2091 };
2092 TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
2093 }
2094 *output_node_and_sharding = result;
2095 return OkStatus();
2096 }
2097
2098 bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
2099 bool spmd_attr;
2100 if (!replicate_node ||
2101 !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
2102 &spmd_attr)) {
2103 spmd_attr = false;
2104 }
2105 return spmd_attr;
2106 }
2107
2108 std::string FormatNodeAndShardingMsg(
2109 const absl::optional<NodeAndSharding>& node_and_sharding) {
2110 DCHECK(node_and_sharding.has_value());
2111
2112 xla::OpSharding sharding_no_metadata = node_and_sharding->sharding;
2113 sharding_no_metadata.clear_metadata();
2114 std::string escaped_sharding_str =
2115 absl::CEscape(sharding_no_metadata.SerializeAsString());
2116 if (node_and_sharding->node == nullptr) {
2117 return absl::StrCat(" via default sharding '", escaped_sharding_str, "'");
2118 }
2119
2120 return absl::StrCat(" via node ", node_and_sharding->node->DebugString(),
2121 " sharding '", escaped_sharding_str, "'");
2122 }
2123
2124 Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
2125 int num_cores_per_replica, const ParameterInfo& params_info,
2126 const DataTypeVector& arg_types,
2127 const std::vector<InferredShape>& arg_shapes,
2128 const DataTypeVector& retval_types,
2129 const std::vector<InferredShape>& retval_shapes, const Graph& graph,
2130 const Node* replicate_node, FunctionLibraryRuntime* flr,
2131 bool allow_parameter_replication_for_spmd,
2132 std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
2133 std::vector<xla::OpSharding>* retval_sharding,
2134 std::vector<std::string>* arg_names) {
2135 // Builds vectors of the argument and return nodes.
2136 std::vector<Node*> args(arg_types.size());
2137 std::vector<Node*> retvals(retval_types.size());
2138 absl::flat_hash_map<int, Node*> partitioned_output_nodes;
2139 for (Node* node : graph.op_nodes()) {
2140 if (node->IsArg()) {
2141 int index;
2142 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2143 TF_RET_CHECK(index >= 0 && index < args.size());
2144 args[index] = node;
2145 } else if (node->IsRetval()) {
2146 int index;
2147 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2148 TF_RET_CHECK(index >= 0 && index < retvals.size());
2149 retvals[index] = node;
2150 }
2151 }
2152 for (const Edge* edge : replicate_node->out_edges()) {
2153 int num_partitioned_outputs = 0;
2154 for (const Edge* out_edge : edge->dst()->out_edges()) {
2155 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
2156 partitioned_output_nodes[edge->src_output()] = out_edge->dst();
2157 num_partitioned_outputs++;
2158 }
2159 }
2160 if (num_partitioned_outputs > 1) {
2161 return errors::InvalidArgument(
2162           "More than one TPUPartitionedOutput per replicated output.");
2163 }
2164 }
2165
2166 // Verifies there are no missing arguments/return values.
2167 for (int i = 0; i < args.size(); ++i) {
2168 if (args[i] == nullptr) {
2169 return errors::Internal("Missing function argument: ", i);
2170 }
2171 }
2172 for (int i = 0; i < retvals.size(); ++i) {
2173 if (retvals[i] == nullptr) {
2174 return errors::Internal("Missing function return value: ", i);
2175 }
2176 }
2177
2178 // Assigns a core to each _Arg. Chooses the lowest-numbered core that
2179 // consumes the argument. We choose the lowest-numbered core so the
2180 // assignment is deterministic.
2181 TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
2182 arg_shapes);
2183 arg_sharding->resize(args.size());
2184 arg_names->resize(args.size());
2185 arg_fast_mem->resize(args.size());
2186 CachedFunctionHandles cached_function_handles(flr);
2187 const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
2188 replicate_inputs_outputs_by_default_for_xla_spmd_) &&
2189 allow_parameter_replication_for_spmd &&
2190 num_cores_per_replica > 1;
2191
2192   // Offset _TPUReplicate non-per-replica argument indices by
2193   // (num_replicas - 1) * num_per_replica_args, as _TPUReplicate nodes are
2194   // constructed with all per-replica args across all replicas while the
2195   // encapsulated function only has one replica's per-replica args. Per-replica
2196   // args are ordered by replica first, so the index here does not require an
2197   // offset, and the first replica's input nodes are sufficient for determining
2198   // argument sharding.
2199 const int index_offset =
2200 (params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
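  // Illustrative example: with 4 replicas and 3 per-replica args,
  // index_offset = (4 - 1) * 3 = 9, so a distributed or variable arg at
  // position i in the encapsulated function corresponds to input i + 9 of the
  // _TPUReplicate node, while per-replica arg i maps directly to input i (the
  // first replica's input).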
2201 for (int i = 0; i < args.size(); ++i) {
2202 const Node* n = args[i];
2203 absl::optional<int64_t> assigned_core;
2204 absl::optional<NodeAndSharding> node_and_sharding;
2205 bool is_fast_mem;
2206 TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
2207 *n, num_cores_per_replica, flr, &cached_function_handles,
2208 &node_and_sharding, &is_fast_mem));
2209
2210 const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
2211 if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
2212 Node* input_node;
2213 TF_RETURN_IF_ERROR(replicate_node->input_node(
2214 i + (is_per_replica_arg ? 0 : index_offset), &input_node));
2215 if (input_node->type_string() == kTPUPartitionedInput) {
2216 TF_ASSIGN_OR_RETURN(
2217 absl::optional<xla::OpSharding> parsed_sharding,
2218 GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2219 if (!parsed_sharding.has_value())
2220 return errors::InvalidArgument("Missing _XlaSharding attr from: ",
2221 input_node->DebugString());
2222 node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2223 VLOG(1) << "Arg " << i << " parsed sharding information from "
2224 << input_node->DebugString() << " : "
2225 << parsed_sharding->DebugString();
2226 }
2227 }
2228
2229 if (params_info.IsVariableArg(i)) {
2230 Node* input_node;
2231 TF_RETURN_IF_ERROR(
2232 replicate_node->input_node(i + index_offset, &input_node));
2233 if (input_node->type_string() == kVarHandleOp) {
2234 TF_ASSIGN_OR_RETURN(
2235 absl::optional<xla::OpSharding> parsed_sharding,
2236 GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2237 if (parsed_sharding.has_value()) {
2238 node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2239 VLOG(1) << "Arg " << i << " parsed sharding information from "
2240 << input_node->DebugString() << " : "
2241 << parsed_sharding->DebugString();
2242 }
2243 }
2244 }
2245
2246 if (node_and_sharding.has_value() && enable_automatic_model_parallelism_) {
2247 return tensorflow::errors::InvalidArgument(
2248 "Specifying manual sharding is not allowed when automatic "
2249 "model parallelism is enabled.",
2250 node_and_sharding->sharding.DebugString());
2251 }
2252
2253 if (!node_and_sharding.has_value()) {
2254 if (use_spmd &&
2255 (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2256 ((params_info.IsPerReplicaArg(i) ||
2257 params_info.IsDistributedArg(i)) &&
2258 arg_types[i] != DT_RESOURCE) ||
2259 params_info.IsConstantArg(i))) {
2260 // Use replication for host variables or non-variable per-replica
2261 // inputs.
2262 node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2263 xla::sharding_builder::Replicate());
2264 } else {
2265 // TODO(dlibenzi): Distributing variables to cores other than 0 makes
2266 // learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
2267 // For now distribute only per replica arguments, unless
2268 // tf_jf_distribute_vars is set, to allow debugging the issue.
2269 if (((params_info.IsPerReplicaArg(i) ||
2270 params_info.IsDistributedArg(i)) &&
2271 arg_types[i] != DT_RESOURCE) ||
2272 (distribute_vars_ && params_info.IsVariableArg(i))) {
2273 assigned_core = args_device_selector.RetrieveAssignment(i);
2274 } else {
2275 assigned_core = 0;
2276 }
2277 node_and_sharding = NodeAndSharding(
2278 /*node=*/nullptr,
2279 xla::sharding_builder::AssignDevice(*assigned_core));
2280 }
2281 *node_and_sharding->sharding.add_metadata() =
2282 CreateOpMetadataFromNode(*replicate_node);
2283 } else if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2284 if (use_spmd) {
2285 node_and_sharding->sharding = xla::sharding_builder::Replicate();
2286 } else {
2287 assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2288 }
2289 } else if (node_and_sharding->sharding.type() !=
2290 xla::OpSharding::REPLICATED &&
2291 node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2292 return tensorflow::errors::InvalidArgument(
2293 "Unsupported argument sharding (for arg ", n->DebugString(),
2294 "): ", node_and_sharding->sharding.DebugString());
2295 }
2296 if (assigned_core.has_value()) {
2297 args_device_selector.ReportDeviceAssigned(*assigned_core, i);
2298 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2299 << ") to core " << *assigned_core
2300 << FormatNodeAndShardingMsg(node_and_sharding);
2301 args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2302 } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2303 for (int64_t core :
2304 node_and_sharding->sharding.tile_assignment_devices()) {
2305 TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2306 << "core " << core << " should be between [0, "
2307 << num_cores_per_replica << "). sharding is "
2308 << node_and_sharding->sharding.DebugString();
2309 args_device_selector.ReportDeviceAssigned(core, i);
2310 }
2311 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2312 << ") with tiled sharding to cores "
2313 << absl::StrJoin(
2314 node_and_sharding->sharding.tile_assignment_devices(), ",")
2315 << " " << FormatNodeAndShardingMsg(node_and_sharding);
2316 } else {
2317 DCHECK_EQ(node_and_sharding->sharding.type(),
2318 xla::OpSharding::REPLICATED);
2319 for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2320 args_device_selector.ReportDeviceAssigned(core, i);
2321 }
2322 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2323 << ") to all cores"
2324 << FormatNodeAndShardingMsg(node_and_sharding);
2325 }
2326 (*arg_sharding)[i] = node_and_sharding->sharding;
2327 (*arg_fast_mem)[i] = is_fast_mem;
2328 (*arg_names)[i] = n->name();
2329 if (is_fast_mem) {
2330 VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
2331 << args[i]->name();
2332 }
2333 args[i]->AddAttr(kShardingAttribute,
2334 node_and_sharding->sharding.SerializeAsString());
2335 }
2336 TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
2337
2338 // Assigns each _Retval node to the core that produces its value.
2339 TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
2340 retval_types, retval_shapes);
2341 retval_sharding->resize(retvals.size());
2342 for (int i = 0; i < retvals.size(); ++i) {
2343 const Edge* edge;
2344 TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
2345
2346 TF_ASSIGN_OR_RETURN(
2347 absl::optional<xla::OpSharding> edge_sharding,
2348 ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
2349 /*add_metadata=*/true));
2350
2351 absl::optional<NodeAndSharding> node_and_sharding;
2352 if (edge_sharding.has_value()) {
2353 node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding));
2354 }
2355
2356 if (partitioned_output_nodes.contains(i)) {
2357 Node* output_node = partitioned_output_nodes[i];
2358 TF_ASSIGN_OR_RETURN(
2359 absl::optional<xla::OpSharding> parsed_sharding,
2360 GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
2361 if (parsed_sharding.has_value()) {
2362 node_and_sharding = NodeAndSharding(output_node, *parsed_sharding);
2363 VLOG(1) << "Retval " << i << " parsed sharding information from "
2364 << output_node->DebugString() << " : "
2365 << parsed_sharding->DebugString();
2366 }
2367 }
2368 absl::optional<int64_t> assigned_core;
2369 if (node_and_sharding.has_value()) {
2370 if (enable_automatic_model_parallelism_) {
2371 return tensorflow::errors::InvalidArgument(
2372 "Specifying manual sharding is not allowed when automatic "
2373 "model parallelism is enabled.",
2374 node_and_sharding->sharding.DebugString());
2375 }
2376
2377 if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2378 if (use_spmd) {
2379 node_and_sharding->sharding = xla::sharding_builder::Replicate();
2380 } else {
2381 assigned_core =
2382 node_and_sharding->sharding.tile_assignment_devices(0);
2383 TF_RETURN_IF_ERROR(
2384 ValidateCoreNumber(*assigned_core, num_cores_per_replica));
2385 }
2386 } else if (node_and_sharding->sharding.type() !=
2387 xla::OpSharding::REPLICATED &&
2388 node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2389 return tensorflow::errors::InvalidArgument(
2390 "Unsupported argument sharding for retval ",
2391 retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
2392 node_and_sharding->sharding.DebugString());
2393 }
2394 } else {
2395 if (use_spmd) {
2396 node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2397 xla::sharding_builder::Replicate());
2398 } else {
2399 if (distribute_vars_) {
2400 assigned_core = retvals_device_selector.RetrieveAssignment(i);
2401 } else {
2402 assigned_core = 0;
2403 }
2404 node_and_sharding = NodeAndSharding(
2405 /*node=*/nullptr,
2406 xla::sharding_builder::AssignDevice(*assigned_core));
2407 }
2408 *node_and_sharding->sharding.add_metadata() =
2409 CreateOpMetadataFromNode(*replicate_node);
2410 }
2411 if (assigned_core.has_value() && !use_spmd) {
2412 retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2413 retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
2414 VLOG(3) << "Assigning return value " << i << " ("
2415 << retvals[i]->DebugString() << ") to core " << *assigned_core
2416 << FormatNodeAndShardingMsg(node_and_sharding);
2417 } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2418 for (int64_t core :
2419 node_and_sharding->sharding.tile_assignment_devices()) {
2420 TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2421 << "core " << core << " should be between [0, "
2422 << num_cores_per_replica << "). sharding is "
2423 << node_and_sharding->sharding.DebugString();
2424 retvals_device_selector.ReportDeviceAssigned(core, i);
2425 }
2426 VLOG(3) << "Assigning return value " << i << " ("
2427 << retvals[i]->DebugString() << ") with tiled sharding to cores "
2428 << absl::StrJoin(
2429 node_and_sharding->sharding.tile_assignment_devices(), ",")
2430 << " " << FormatNodeAndShardingMsg(node_and_sharding);
2431 } else {
2432 if (use_spmd) {
2433 node_and_sharding->sharding = xla::sharding_builder::Replicate();
2434 }
2435 for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2436 retvals_device_selector.ReportDeviceAssigned(core, i);
2437 }
2438 VLOG(3) << "Assigning return value " << i << " ("
2439 << retvals[i]->DebugString() << ") to all cores"
2440 << FormatNodeAndShardingMsg(node_and_sharding);
2441 }
2442 retvals[i]->AddAttr(kShardingAttribute,
2443 node_and_sharding->sharding.SerializeAsString());
2444 (*retval_sharding)[i] = node_and_sharding->sharding;
2445 }
2446 if (use_spmd &&
2447 (absl::c_any_of(*arg_sharding,
2448 [](const xla::OpSharding& s) {
2449 return s.type() == xla::OpSharding::MAXIMAL;
2450 }) ||
2451 absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) {
2452 return s.type() == xla::OpSharding::MAXIMAL;
2453 }))) {
2454 return tensorflow::errors::InvalidArgument(
2455 "XLA SPMD only supports cases where all inputs/outputs "
2456 "exist on every partition (sharded or replicated).");
2457 }
2458 return OkStatus();
2459 }
2460
2461 // Builds Shape nodes that compute the shapes of arguments whose shapes are not
2462 // statically known.
2463 /* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
2464 const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
2465 const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
2466 Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
2467 dynamic_shape_nodes->clear();
2468
2469 std::vector<const Edge*> replicate_input_edges;
2470 TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2471
2472 // The compiler determines the shape of each constant by inspecting the value
2473 // of its corresponding host-memory tensor; this happens when a step is run.
2474 // As a result, the shapes of constants are not needed at graph rewrite time.
2475 const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
2476 TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
2477 params_info.NumDistributedArgs() +
2478 params_info.NumBroadcastArgs() +
2479 params_info.NumVariables());
2480
2481 for (int i = 0; i < num_args; ++i) {
2482 const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
2483 ? &arg_shapes[i].shape
2484 : &arg_shapes[i].handle_shape;
2485 if (!shape->IsFullyDefined()) {
2486 NodeDef def;
2487 Node* src;
2488 int src_output;
2489 std::vector<Node*> control_inputs;
2490
2491 if (params_info.IsVariableArg(i)) {
2492 int64_t var_num = i - params_info.NumPerReplicaArgs() -
2493 params_info.NumDistributedArgs() -
2494 params_info.NumBroadcastArgs();
2495 TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
2496 Node* read = variable_reads[var_num];
2497
2498 DCHECK_EQ(read->type_string(), "ReadVariableOp");
2499
2500 for (const Edge* edge : read->in_edges()) {
2501 if (edge->IsControlEdge()) {
2502 control_inputs.push_back(edge->src());
2503 }
2504 }
2505
2506 const Edge* variable_input = nullptr;
2507 TF_RETURN_IF_ERROR(read->input_edge(/*idx=*/0, &variable_input));
2508 src = variable_input->src();
2509 src_output = variable_input->src_output();
2510
2511 def.set_name(
2512 graph->NewName(strings::StrCat(src->name(), "/variable_shape")));
2513 def.set_op("VariableShape");
2514 } else {
2515 if (params_info.IsPerReplicaArg(i)) {
2516 TF_RET_CHECK(i < replicate_input_edges.size());
2517 // All replicas must have the same input shapes. Uses the shape of the
2518 // inputs from the first replica.
2519 src = replicate_input_edges[i]->src();
2520 src_output = replicate_input_edges[i]->src_output();
2521 } else {
2522 DCHECK(params_info.IsDistributedArg(i) ||
2523 params_info.IsBroadcastArg(i));
2524 int64_t input_num =
2525 params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
2526 params_info.NumPerReplicaArgs();
2527 TF_RET_CHECK(0 <= input_num &&
2528 input_num < replicate_input_edges.size());
2529 src = replicate_input_edges[input_num]->src();
2530 src_output = replicate_input_edges[input_num]->src_output();
2531 }
2532
2533 def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
2534 def.set_op("Shape");
2535 AddNodeAttr("T", src->output_type(src_output), &def);
2536 }
2537
2538 def.set_device(src->assigned_device_name());
2539 AddNodeAttr("out_type", DT_INT64, &def);
2540 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
2541
2542 TF_ASSIGN_OR_RETURN(Node * shape_node, graph->AddNode(def));
2543 dynamic_shape_nodes->push_back(shape_node);
2544
2545 shape_node->set_assigned_device_name(src->assigned_device_name());
2546 graph->AddEdge(src, src_output, shape_node, 0);
2547 for (Node* control_input : control_inputs) {
2548 graph->AddControlEdge(control_input, shape_node);
2549 }
2550 }
2551 }
2552 return OkStatus();
2553 }
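// Illustrative example: a resource-variable argument whose handle shape is not
// fully defined gets a VariableShape node wired to the variable handle feeding
// its ReadVariableOp, while a per-replica tensor argument with an unknown
// shape gets a Shape node fed by the first replica's input edge; both emit
// DT_INT64 and are placed on the producing node's assigned device.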
2554
2555 namespace {
2556
2557 bool XlaBroadcastTypeSupported(const DataType dtype) {
2558 return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
2559 dtype == DT_BOOL);
2560 }
2561
2562 bool XlaBroadcastKindSupported(
2563 const DistributedTPURewritePass::ParameterInfo& params_info,
2564 int param_num) {
2565 // NOTE: This is intended to cover non-sharded data-parallel variables, for
2566 // training only. Is it correct to just check whether the arg_type is
2567 // DT_RESOURCE?
2568 return params_info.IsVariableArg(param_num) &&
2569 !(params_info.IsPerReplicaArg(param_num) ||
2570 params_info.IsDistributedArg(param_num) ||
2571 params_info.IsBroadcastArg(param_num) ||
2572 params_info.IsConstantArg(param_num));
2573 }
2574
2575 bool EnableXlaParamBroadcast(
2576 bool enable_xla_param_broadcast, bool mpmd,
2577 const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
2578 DataType dtype) {
2579 // Conditions necessary to use XLA collectives for arg broadcast:
2580 // 1. Globally enabled via enable_xla_param_broadcast.
2581 // 2. DataType must be supported.
2582 // 3. Parameter must be a variable, and not distributed or broadcasted.
2583 // 4. For multi-core models (num_cores_per_replica > 1), must use SPMD.
2584 return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
2585 XlaBroadcastKindSupported(params_info, param_num) && !mpmd;
2586 }
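// Added hedged example of the conditions above (illustrative only): a
// DT_FLOAT trainable variable in a purely data-parallel program
// (num_cores_per_replica == 1, so mpmd is false) qualifies for XLA
// broadcast when enable_xla_param_broadcast is set, whereas a DT_INT64
// variable, or any per-replica, distributed, broadcast or constant arg,
// does not.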
2587
2588 } // namespace
2589
2590 // Builds a TPUCompile node that compiles the bodies of the function call
2591 // `nodes`.
2592 Status DistributedTPURewritePass::BuildCompileNode(
2593 const Node* replicate_node, const NameAttrList& function,
2594 uint64 library_fingerprint, const ParameterInfo& params_info,
2595 const std::vector<InferredShape>& arg_shapes,
2596 const DataTypeVector& arg_types,
2597 const std::vector<Node*>& guaranteed_constant_nodes,
2598 const string& session_handle,
2599 const std::vector<xla::OpSharding>& arg_sharding,
2600 const std::vector<bool>& arg_fast_mem,
2601 const std::vector<std::string>& arg_names,
2602 const std::vector<xla::OpSharding>& retval_sharding,
2603 int num_cores_per_replica, const string& compile_device,
2604 const xla::DeviceAssignment* xla_device_assignment,
2605 const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
2606 Node** compile_node, int64_t autotuner_thresh) {
2607 VLOG(1) << "BuildCompileNode";
2608
2609 tpu::TPUCompileMetadataProto proto;
2610 if (replicate_node) {
2611 std::string str;
2612 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node->attrs(),
2613 "tpu_compile_options_proto", &str));
2614 TF_RET_CHECK(proto.mutable_compile_options()->ParseFromString(str));
2615 }
2616 proto.set_num_replicas(params_info.NumReplicas());
2617 proto.set_num_cores_per_replica(num_cores_per_replica);
2618 proto.set_function_library_fingerprint(library_fingerprint);
2619 proto.set_enable_automatic_model_parallelism(
2620 enable_cross_replica_sharding_mirrored_variables_);
2621 const bool use_spmd =
2622 UseSpmdForXlaPartitioning(replicate_node) && allow_xla_spmd_partition_;
2623 proto.set_use_spmd_for_xla_partitioning(use_spmd);
2624 const bool mpmd = (num_cores_per_replica > 1) && !use_spmd;
2625
2626 // Set the step marker location specified on the replicate node.
2627 if (replicate_node != nullptr) {
2628 xla::DebugOptions::StepMarkerLocation location;
2629 TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
2630 proto.set_step_marker_location(location);
2631 }
2632
2633 if (xla_device_assignment != nullptr) {
2634 TF_RETURN_IF_ERROR(
2635 xla_device_assignment->Serialize(proto.mutable_device_assignment()));
2636 }
2637
2638 const int num_args = arg_types.size();
2639 const int num_guaranteed_constants = guaranteed_constant_nodes.size();
2640 const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
2641 TF_RET_CHECK(num_args == arg_shapes.size());
2642 TF_RET_CHECK(num_args == arg_sharding.size())
2643 << num_args << " != " << arg_sharding.size();
2644
2645 for (int i = 0; i < num_args; ++i) {
2646 tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
2647 DataType type = arg_types[i];
2648 const InferredShape& arg_shape = arg_shapes[i];
2649 arg->set_name(arg_names[i]);
2650 if (type == DT_RESOURCE) {
2651 TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
2652 arg->set_dtype(arg_shape.handle_type);
2653 arg_shape.handle_shape.AsProto(arg->mutable_shape());
2654 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
2655 arg->set_fast_mem(arg_fast_mem[i]);
2656 } else {
2657 arg->set_dtype(type);
2658 arg_shape.shape.AsProto(arg->mutable_shape());
2659 if (i >= guaranteed_const_start_index) {
2660 const DataType edge_type =
2661 guaranteed_constant_nodes[i - guaranteed_const_start_index]
2662 ->output_type(0);
2663 TF_RET_CHECK(type == edge_type)
2664 << "Arg type: " << type << " but edge type: " << edge_type;
2665 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
2666 } else {
2667 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
2668 }
2669 }
2670
2671 // Use XLA collective primitives to distribute variables to all replicas.
2672 arg->set_requires_xla_broadcast(
2673 params_info.NumReplicas() > 1 &&
2674 EnableXlaParamBroadcast(enable_xla_param_broadcast_, mpmd, params_info,
2675 i, arg_shape.handle_type /*arg.dtype?*/));
2676
2677 // As long as the argument is not a per-replica one, it should have the same
2678 // value for all replicas. For clarity, we keep the (redundant) checks for
2679 // variable, broadcast and constant types, to prevent bugs in case new types
2680 // with different semantics are introduced in the future.
2681 arg->set_is_same_data_across_replicas(
2682 !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
2683 (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2684 params_info.IsConstantArg(i)));
2685 if (params_info.mirrored_variable_indices().count(i) > 0) {
2686 TF_RET_CHECK(type == DT_RESOURCE)
2687 << "Arg type: " << type << " name: " << arg->name()
2688 << " shape: " << arg->shape().DebugString();
2689 arg->set_is_same_data_across_replicas(true);
2690 // 64-bit types are not shardable by XLA:TPU yet.
2691 bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
2692 arg_shape.handle_type != DT_INT64 &&
2693 arg_shape.handle_type != DT_UINT64 &&
2694 arg_shape.handle_type != DT_DOUBLE);
2695 arg->set_enable_xla_sharding(
2696 sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
2697 : tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
2698 }
2699 *arg->mutable_sharding() = arg_sharding[i];
2700 }
2701
2702 const int num_retvals = retval_sharding.size();
2703 for (int i = 0; i < num_retvals; ++i) {
2704 *proto.add_retvals()->mutable_sharding() = retval_sharding[i];
2705 }
2706 proto.set_session_handle(session_handle);
2707
2708 DataTypeVector constant_arg_types;
2709 constant_arg_types.reserve(num_guaranteed_constants);
2710 for (int i = 0; i < num_guaranteed_constants; ++i) {
2711 constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
2712 }
2713 proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
2714
2715 string metadata;
2716 proto.SerializeToString(&metadata);
2717
2718 NodeDef def;
2719 def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
2720 def.set_op("TPUCompile");
2721 def.set_device(compile_device);
2722 if (replicate_node) {
2723 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
2724 }
2725
2726 AddNodeAttr("function", function, &def);
2727 AddNodeAttr("num_computations", num_cores_per_replica, &def);
2728 AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
2729 &def);
2730 AddNodeAttr("metadata", metadata, &def);
2731 AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
2732
2733 TF_ASSIGN_OR_RETURN(*compile_node, graph->AddNode(def));
2734
2735 (*compile_node)->set_assigned_device_name(compile_device);
2736
2737 for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
2738 graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
2739 }
2740
2741 for (int i = 0; i < num_guaranteed_constants; ++i) {
2742 graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
2743 dynamic_shape_nodes.size() + i);
2744 }
2745 VLOG(1) << "BuildCompileNode()";
2746 return OkStatus();
2747 }
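// Added note (inferred from the edges added above and in BuildExecuteNodes
// below): the TPUCompile node's data inputs are the dynamic-shape nodes
// followed by the guaranteed-constant nodes, and its output `core + 1` is
// later fed as the program input of the TPUExecute node for logical core
// `core`.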
2748
2749 Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
2750 const Node& node, const NameRangeMap& input_range_map,
2751 std::vector<Node*>* guaranteed_constants) {
2752 std::vector<const Edge*> input_edges;
2753 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2754 std::pair<int, int> variables_limits =
2755 input_range_map.at("guaranteed_constants");
2756 for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2757 guaranteed_constants->push_back(input_edges[i]->src());
2758 }
2759 return OkStatus();
2760 }
2761
2762 Status DistributedTPURewritePass::FindVariableInputs(
2763 const Node& node, const NameRangeMap& input_range_map,
2764 std::vector<VariableInput>* variables) {
2765 std::vector<const Edge*> input_edges;
2766 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2767 std::pair<int, int> variables_limits = input_range_map.at("variables");
2768 for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2769 Node* node = input_edges[i]->src();
2770
2771 // Find the type of the VarHandleOp that feeds this node, looking through
2772 // any wrapping Enter or Switch nodes.
2773 while (node->IsEnter() || node->IsSwitch()) {
2774 TF_RETURN_IF_ERROR(node->input_node(0, &node));
2775 }
2776 // Fix the variable device assignment if it is requested with a full name.
2777 if (!node->has_assigned_device_name() &&
2778 !node->requested_device().empty()) {
2779 DeviceNameUtils::ParsedName var_device;
2780 TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
2781 &var_device));
2782 if (var_device.has_job && var_device.has_replica && var_device.has_task &&
2783 var_device.has_type && var_device.has_id) {
2784 node->set_assigned_device_name(node->requested_device());
2785 if (node != input_edges[i]->src() &&
2786 !input_edges[i]->src()->has_assigned_device_name()) {
2787 input_edges[i]->src()->set_assigned_device_name(
2788 node->requested_device());
2789 }
2790 }
2791 }
2792 if (node->type_string() == kVarHandleOp) {
2793 DataType dtype;
2794 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
2795 variables->push_back(VariableInput{input_edges[i]->src(),
2796 input_edges[i]->src_output(), dtype});
2797 } else if (node->type_string() == "_Arg") {
2798 std::vector<DataType> dtypes;
2799 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
2800 if (dtypes.empty()) {
2801 return errors::Internal(
2802 "_Arg node with resource output must have non-empty _handle_dtypes "
2803 "attribute: ",
2804 node->DebugString());
2805 }
2806 variables->push_back(VariableInput{
2807 input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
2808 } else {
2809 return errors::Internal(
2810 "Cannot handle variable input with node type other than VarHandleOp "
2811 "and _Arg: ",
2812 node->DebugString());
2813 }
2814 }
2815 return OkStatus();
2816 }
2817
2818 // Builds a NoOp node, used for building control dependencies.
2819 static Status BuildNoopNode(const Node& source, StringPiece name,
2820 const string& device, Graph* graph, Node** node) {
2821 NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
2822 if (!device.empty()) {
2823 builder.Device(device);
2824 }
2825 NodeDef def;
2826 TF_RETURN_IF_ERROR(builder.Finalize(&def));
2827
2828 TF_ASSIGN_OR_RETURN(*node, graph->AddNode(def));
2829 if (!device.empty()) {
2830 (*node)->set_assigned_device_name(device);
2831 }
2832 return OkStatus();
2833 }
2834
2835 Status DistributedTPURewritePass::ConnectHostComputeNodes(
2836 Node* compile_node, Node* key_placeholder_node, Graph* graph) {
2837 // First collect all the downstream nodes of the key placeholder node,
2838 // because deleting the connecting edges from key_placeholder_node below
2839 // would invalidate the out_nodes iterator.
2840 std::vector<Node*> host_transfer_nodes;
2841 for (Node* node : key_placeholder_node->out_nodes()) {
2842 host_transfer_nodes.push_back(node);
2843 }
2844 for (Node* node : host_transfer_nodes) {
2845 int input_index = -1;
2846 for (int i = 0; i < node->num_inputs(); i++) {
2847 const Edge* e;
2848 TF_RETURN_IF_ERROR(node->input_edge(i, &e));
2849 if (e->src() == key_placeholder_node) {
2850 if (input_index != -1) {
2851 return errors::Internal(
2852 "Node ", node->name(),
2853 " has multiple input edges from key placeholder node");
2854 }
2855 input_index = e->dst_input();
2856 }
2857 }
2858 if (input_index == -1) {
2859 return errors::Internal("Node ", node->name(),
2860 " has no input edge from key placeholder node");
2861 }
2862 const Edge* key_edge;
2863 TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
2864 graph->RemoveEdge(key_edge);
2865 graph->AddEdge(compile_node, 1, node, input_index);
2866 }
2867 graph->RemoveNode(key_placeholder_node);
2868 return OkStatus();
2869 }
2870
2871 Status DistributedTPURewritePass::BuildVariableReads(
2872 absl::Span<const VariableInput> variables, Node* control_predecessor,
2873 Graph* graph, std::vector<Node*>* variable_reads) {
2874 variable_reads->resize(variables.size());
2875 for (int i = 0; i < variables.size(); ++i) {
2876 string name =
2877 graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
2878 NodeDefBuilder builder(name, "ReadVariableOp",
2879 NodeDebugInfo(*variables[i].node));
2880
2881 builder.Attr("dtype", variables[i].dtype);
2882 builder.Device(variables[i].node->assigned_device_name());
2883 builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
2884 NodeDef def;
2885 TF_RETURN_IF_ERROR(builder.Finalize(&def));
2886
2887 TF_ASSIGN_OR_RETURN(Node * read_node, graph->AddNode(def));
2888 (*variable_reads)[i] = read_node;
2889
2890 read_node->set_requested_device(variables[i].node->requested_device());
2891 read_node->set_assigned_device_name(
2892 variables[i].node->assigned_device_name());
2893 graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
2894
2895 graph->AddControlEdge(control_predecessor, read_node);
2896 }
2897 return OkStatus();
2898 }
2899
2900 bool DistributedTPURewritePass::ContainsResourceWriteOp(
2901 const Graph& graph, const FunctionLibraryDefinition& fld) {
2902 for (const Node* n : graph.nodes()) {
2903 const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
2904 if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2905 VLOG(2) << "Found write resource op inside computation";
2906 return true;
2907 }
2908 }
2909 for (const string& func_name : fld.ListFunctionNames()) {
2910 const FunctionDef* func_def = fld.Find(func_name);
2911 for (const NodeDef& n : func_def->node_def()) {
2912 const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
2913 if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2914 VLOG(2) << "Found write resource op inside " << func_name;
2915 return true;
2916 }
2917 }
2918 }
2919 return false;
2920 }
2921
2922 Status DistributedTPURewritePass::BuildVariableWrites(
2923 absl::Span<const VariableInput> variables, Node* control_successor,
2924 absl::Span<const VariableWrite> variable_writes, Graph* graph) {
2925 CHECK_EQ(variables.size(), variable_writes.size());
2926 for (int i = 0; i < variables.size(); ++i) {
2927 const VariableWrite& write = variable_writes[i];
2928 NodeDebugInfo debug_info(*variables[i].node);
2929
2930 auto name = [&](string suffix) {
2931 return graph->NewName(
2932 strings::StrCat(variables[i].node->name(), "/", suffix));
2933 };
2934
2935 Node* write_node;
2936 TF_RETURN_IF_ERROR(
2937 IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
2938 .AddAttr("dtype", variables[i].dtype)
2939 .Device(variables[i].node->assigned_device_name())
2940 .Build(graph, &write_node));
2941
2942 // Colocate the control flow with the variable.
2943 CondBuilder cb(variables[i].node->name(),
2944 variables[i].node->assigned_device_name(), debug_info,
2945 graph);
2946
2947 // Inputs to conditional.
2948 Node* switch_val;
2949 TF_RETURN_IF_ERROR(
2950 cb.AddInput("switch_val", variables[i].dtype,
2951 /*device=*/write.value->assigned_device_name(), debug_info,
2952 &switch_val));
2953 Node* switch_var;
2954 TF_RETURN_IF_ERROR(
2955 cb.AddInput("switch_var", DT_RESOURCE,
2956 /*device=*/variables[i].node->assigned_device_name(),
2957 debug_info, &switch_var));
2958 // Conditionally write the value back.
2959 graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
2960 graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
2961 graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
2962 // Add a control edge from the write to the value that will be merged. The
2963 // write has no output, so this control edge ensures the write completes.
2964 graph->AddControlEdge(write_node, cb.switch_t());
2965
2966 graph->AddControlEdge(cb.control_successor(), control_successor);
2967
2968 graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
2969 graph->AddEdge(write.value, write.value_output, switch_val, 0);
2970 }
2971 return OkStatus();
2972 }
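// Added sketch of the per-variable conditional write built above (names are
// the locals used in the loop; layout is illustrative only):
//
//   variable resource --> switch_var --(then)--> AssignVariableOp (write_node)
//   new value ----------> switch_val --(then)-------^
//   write.predicate ----> cb.pred()
//
// Since the AssignVariableOp has no data output, the control edge from
// write_node into cb.switch_t() is what ties the write into the
// conditional's merge path before cb.control_successor() is reached.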
2973
2974 namespace {
2975
2976 // Computes the shape of the sharded tensor and modifies it in place.
2977 Status ComputeShardedArgShapes(TensorShape* shape,
2978 const xla::OpSharding& sharding) {
2979 if (sharding.type() != xla::OpSharding::OTHER) {
2980 return OkStatus();
2981 }
2982 if (!shape->IsFullyDefined()) {
2983 return errors::Internal(
2984 "Arg shape must be fully defined before sharded shape inference.");
2985 }
2986 int sharded_rank = sharding.tile_assignment_dimensions_size();
2987 if (sharding.replicate_on_last_tile_dim()) {
2988 sharded_rank--;
2989 }
2990 for (int dim_idx = 0; dim_idx < sharded_rank; ++dim_idx) {
2991 auto sharded_dim = tensorflow::MathUtil::CeilOfRatio<int64_t>(
2992 shape->dim_size(dim_idx), sharding.tile_assignment_dimensions(dim_idx));
2993 shape->set_dim(dim_idx, sharded_dim);
2994 }
2995 if (sharded_rank != shape->dims()) {
2996 LOG(WARNING) << "Rank of sharded arg should match sharding spec. Rank: "
2997 << sharded_rank << ", tiled shape: " << shape->DebugString()
2998 << ", sharding: " << sharding.DebugString();
2999 }
3000
3001 return OkStatus();
3002 }
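// Added worked example (illustrative only): for an arg of shape [10, 21] and
// a sharding with tile_assignment_dimensions = [2, 4], each shard gets shape
// [ceil(10/2), ceil(21/4)] = [5, 6]. If replicate_on_last_tile_dim() is set
// with tile_assignment_dimensions = [2, 4, 2], only the first two dimensions
// are treated as tiled (sharded_rank == 2).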
3003
3004 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
3005 xla::StatusOr<Node*> CreateTpuExecuteDummyArg(const TensorShape& var_shape,
3006 const DataType& dtype,
3007 const string& host_cpu_device,
3008 Node* var_read, int replica_id,
3009 Graph* graph) {
3010 Status status;
3011
3012 // Const - shape_as_tensor
3013 const std::string name_prefix = strings::StrCat(
3014 var_read->name(), absl::StrFormat("/dummy_%d", replica_id));
3015 NodeDef shape_tensor_def;
3016 shape_tensor_def.set_op("Const");
3017 shape_tensor_def.set_name(graph->NewName(
3018 strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
3019 shape_tensor_def.set_device(host_cpu_device);
3020 AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
3021 TensorProto tensorshape_proto;
3022 tensorshape_proto.set_dtype(DT_INT32);
3023 for (int i = 0; i < var_shape.dims(); ++i) {
3024 tensorshape_proto.add_int_val(var_shape.dim_size(i));
3025 }
3026 TensorShape shape_shape({var_shape.dims()});
3027 shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
3028 AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
3029 TF_ASSIGN_OR_RETURN(Node * shape_as_tensor_node,
3030 graph->AddNode(shape_tensor_def));
3031
3032 // Const - initializer value
3033 NodeDef init_val_def;
3034 init_val_def.set_op("Const");
3035 init_val_def.set_name(graph->NewName(
3036 strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
3037 init_val_def.set_device(host_cpu_device);
3038 TensorProto tensor_proto;
3039 tensor_proto.set_dtype(dtype);
3040 if (dtype == DT_FLOAT) {
3041 tensor_proto.add_float_val(0.0f);
3042 } else if (dtype == DT_BFLOAT16) {
3043 tensor_proto.add_half_val(0);
3044 } else if (dtype == DT_INT32) {
3045 tensor_proto.add_int_val(0);
3046 } else if (dtype == DT_BOOL) {
3047 tensor_proto.add_bool_val(false);
3048 } else {
3049 return errors::Internal(
3050 "Unable to create zero-init dummy arg tensor for type ", dtype);
3051 }
3052 TensorShape scalar_shape({});
3053 scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
3054 AddNodeAttr("value", tensor_proto, &init_val_def);
3055 AddNodeAttr("dtype", dtype, &init_val_def);
3056 TF_ASSIGN_OR_RETURN(Node * init_val_node, graph->AddNode(init_val_def));
3057
3058 // Fill node
3059 NodeDef fill_def;
3060 fill_def.set_op("Fill");
3061 fill_def.set_device(host_cpu_device);
3062 fill_def.set_name(
3063 graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
3064 AddNodeAttr("T", dtype, &fill_def);
3065 AddNodeAttr("index_type", DT_INT32, &fill_def);
3066 TF_ASSIGN_OR_RETURN(Node * fill_node, graph->AddNode(fill_def));
3067 graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
3068 graph->AddEdge(init_val_node, 0, fill_node, 1);
3069
3070 return fill_node;
3071 }
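// Added sketch of the subgraph built above (for orientation; the layout is
// illustrative):
//
//   Const(shape_as_tensor: int32 [d0, d1, ...]) --+
//                                                 +--> Fill --> dummy arg
//   Const(scalar zero of `dtype`) ----------------+
//
// On replicas that receive the real value via XLA broadcast, the Fill output
// stands in for the variable value fed to TPUExecute.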
3072
3073 // Creates dummy inputs for partitioned variables that use XLA broadcast
3074 // for inputs.
3075 Status CreatePartitionedDummyVarArgs(
3076 const xla::OpSharding& sharding, const int num_replicas,
3077 const int replica_id, const InferredShape& raw_shape, Node* orig_var_read,
3078 const int orig_arg_num, DataType dtype, const string& device, Graph* graph,
3079 const std::vector<std::vector<string>>& tpu_device_names,
3080 absl::btree_map<ShardedPerHostInputIndex, Node*>* per_host_index,
3081 std::map<ShardedInputIndex, ShardedInputInfo>*
3082 arg_index_to_sharded_input_map) {
3083 ShardedInputIndex input_index{replica_id, orig_arg_num};
3084 auto iter = arg_index_to_sharded_input_map->find(input_index);
3085 if (iter != arg_index_to_sharded_input_map->end()) {
3086 return OkStatus();
3087 }
3088 const int repeat = sharding.replicate_on_last_tile_dim()
3089 ? *sharding.tile_assignment_dimensions().rbegin()
3090 : 1;
3091 const int num_shards = sharding.tile_assignment_devices_size() / repeat;
3092
3093 TensorShape var_shape;
3094 if (!raw_shape.handle_shape.AsTensorShape(&var_shape) &&
3095 !raw_shape.shape.AsTensorShape(&var_shape)) {
3096 return errors::FailedPrecondition("Failed to read arg shape.");
3097 }
3098 TF_RETURN_IF_ERROR(ComputeShardedArgShapes(&var_shape, sharding));
3099
3100 for (int replica = 1; replica < num_replicas; ++replica) {
3101 std::vector<NodeOut> sharded_inputs_list(
3102 sharding.tile_assignment_devices_size());
3103 for (int i = 0; i < num_shards; ++i) {
3104 for (int j = 0; j < repeat; ++j) {
3105 const int index = i * repeat + j;
3106 const int core = sharding.tile_assignment_devices(index);
3107 string host_device;
3108 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3109 tpu_device_names[replica][core], &host_device));
3110 ShardedPerHostInputIndex idx{host_device, orig_arg_num};
3111 if (!per_host_index->contains(idx)) {
3112 TF_ASSIGN_OR_RETURN(
3113 auto dummy_node,
3114 CreateTpuExecuteDummyArg(var_shape, dtype, host_device,
3115 orig_var_read, replica, graph));
3116 (*per_host_index)[idx] = dummy_node;
3117 }
3118 sharded_inputs_list[core] = {(*per_host_index)[idx], /*index=*/0};
3119 }
3120 }
3121 ShardedInputInfo sharded_input_info{nullptr,
3122 std::move(sharded_inputs_list)};
3123 (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
3124 sharded_input_info;
3125 }
3126
3127 return OkStatus();
3128 }
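// Added illustrative example of the indexing above: with
// tile_assignment_devices_size() == 8, replicate_on_last_tile_dim() == true
// and a last tile dimension of 2, we get repeat == 2 and num_shards == 4;
// shard i = 0..3 is replicated onto the two devices at indices i * 2 and
// i * 2 + 1 of tile_assignment_devices().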
3129
3130 // Helper that creates an IdentityN node containing all of the variable
3131 // values on CPU device 'device', except for those that will be split across
3132 // cores. (For split variables, this may cause additional cross-host data
3133 // transfers if more than one device shares the same variable partition on a
3134 // remote host.)
3135 //
3136 // A previous iteration of this code built one Identity node per TPU core per
3137 // variable, but this can rapidly become hundreds of thousands of nodes. This
3138 // formulation creates a single IdentityN node containing all of the variables
3139 // on each host. This may cause some unnecessary variable copies if only a
3140 // subset of hosts consume a given variable, but has the virtue of being
3141 // simple, and most models use pure replication where all cores want all the
3142 // variables.
3143 //
3144 // If enable_xla_param_broadcast is set to true, then per-host dummy
3145 // tensor args are created on all hosts except for the primary host. In this
3146 // scheme, the dummy args feed the IdentityN node on their local host. All
3147 // are zero-initialized.
3148 //
3149 // Returns the node and its output index to be consumed by TPUExecute for the
3150 // requested variable index.
3151 xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
3152 const string& host_cpu_device, int64_t var_index,
3153 const std::vector<Node*>& variable_reads,
3154 const DistributedTPURewritePass::ParameterInfo& params_info,
3155 const std::vector<xla::OpSharding>& arg_shardings,
3156 const Node& replicate_node, const bool enable_xla_param_broadcast,
3157 const bool mpmd, const int num_cores_per_replica, int replica_id,
3158 const std::vector<InferredShape>& arg_shapes,
3159 absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
3160 Graph* graph) {
3161 auto it = per_host_var_copies->find(host_cpu_device);
3162 if (it != per_host_var_copies->end()) {
3163 return it->second[var_index];
3164 }
3165
3166 DataTypeVector dtypes;
3167 // Per-variable data source for TPUExecute.
3168 std::vector<NodeOut> index_mapping;
3169 index_mapping.reserve(variable_reads.size());
3170 dtypes.reserve(variable_reads.size());
3171 for (int64_t i = 0; i < variable_reads.size(); ++i) {
3172 Node* read = variable_reads[i];
3173 int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3174 params_info.NumDistributedArgs() +
3175 params_info.NumBroadcastArgs();
3176 if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
3177 // We haven't built the IdentityN node yet, so temporarily use nullptr.
3178 index_mapping.push_back(
3179 NodeOut{nullptr, static_cast<int>(dtypes.size())});
3180 dtypes.push_back(read->output_type(0));
3181 } else {
3182 // Do not copy the full tensor of partitioned variables.
3183 index_mapping.push_back(NodeOut{read, 0});
3184 }
3185 }
3186 NodeDef ndef;
3187 ndef.set_name(graph->NewName(
3188 absl::StrCat(replicate_node.name(), "/", kTpuExecuteStagingNodeName)));
3189 ndef.set_op(kTpuExecuteStagingOp);
3190 ndef.set_device(host_cpu_device);
3191 AddNodeAttr("T", dtypes, &ndef);
3192 // TF meta-optimizer should skip this node for constant folding.
3193 AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
3194 TF_ASSIGN_OR_RETURN(Node * id_node, graph->AddNode(ndef));
3195 id_node->set_assigned_device_name(host_cpu_device);
3196
3197 for (int64_t i = 0; i < variable_reads.size(); ++i) {
3198 Node* read = variable_reads[i];
3199 int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3200 params_info.NumDistributedArgs() +
3201 params_info.NumBroadcastArgs();
3202 DataType dtype = read->output_type(0);
3203 bool use_xla_broadcast =
3204 EnableXlaParamBroadcast(enable_xla_param_broadcast, mpmd, params_info,
3205 orig_arg_num, dtype) &&
3206 replica_id != 0;
3207 if (index_mapping[i].node == nullptr) {
3208 // Fill index_mapping with the actual IdentityN node.
3209 index_mapping[i].node = id_node;
3210 if (!use_xla_broadcast) {
3211 // Add the variable read edge to id_node.
3212 graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
3213 } else {
3214 // XLA param broadcast mode is enabled. Create zero-valued dummy
3215 // tensors to use as variable args in the TPUExecuteOp, instead of
3216 // original variable reads.
3217 TensorShape var_shape;
3218 auto inferred_shape = arg_shapes[orig_arg_num];
3219 if (!inferred_shape.handle_shape.AsTensorShape(&var_shape) &&
3220 !inferred_shape.shape.AsTensorShape(&var_shape)) {
3221 return errors::FailedPrecondition("Failed to read arg shape.");
3222 }
3223 TF_ASSIGN_OR_RETURN(
3224 Node * dummy_read,
3225 CreateTpuExecuteDummyArg(var_shape, dtype, host_cpu_device,
3226 variable_reads[i], replica_id, graph));
3227 graph->AddEdge(dummy_read, 0, id_node, index_mapping[i].index);
3228 }
3229 }
3230 }
3231
3232 auto result = index_mapping[var_index];
3233 (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
3234 return result;
3235 }
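// Added note (inferred from the code above): the per-host IdentityN node and
// its index mapping are built once per host_cpu_device and cached in
// per_host_var_copies, so subsequent calls for the same host return the
// cached NodeOut for the requested variable index without adding new nodes.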
3236
3237 } // namespace
3238
3239 Status DistributedTPURewritePass::BuildExecuteNodes(
3240 const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
3241 const Node& replicate_node, const std::vector<std::string>& arg_names,
3242 const DataTypeVector& arg_types,
3243 const std::vector<InferredShape>& arg_shapes,
3244 const DataTypeVector& retval_types,
3245 const std::vector<xla::OpSharding>& arg_shardings,
3246 const std::vector<xla::OpSharding>& retval_shardings,
3247 const std::vector<std::vector<string>>& tpu_device_names,
3248 Node* compile_node, const std::vector<Node*>& variable_reads,
3249 Node* control_predecessor, Node* control_successor, Node* multilock_acquire,
3250 std::vector<VariableWrite>* variable_writes, Graph* graph) {
3251 VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
3252 TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
3253
3254 const int num_variables = variable_reads.size();
3255 const int num_retvals_per_replica = retval_types.size();
3256
3257 variable_writes->resize(num_variables);
3258
3259 std::vector<const Edge*> replicate_input_edges;
3260 TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
3261
3262 // Map from replicate input index to the fan-in nodes.
3263 absl::flat_hash_map<int, std::vector<NodeAndPort>>
3264 replicate_input_fan_in_nodes;
3265 absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
3266 absl::flat_hash_map<int, std::vector<int>>
3267 replicate_output_fan_out_dst_inputs;
3268 std::vector<Node*> to_be_removed_nodes;
3269
3270 const bool use_spmd =
3271 UseSpmdForXlaPartitioning(&replicate_node) && allow_xla_spmd_partition_;
3272 const bool mpmd = (num_cores_per_replica > 1) && !use_spmd;
3273
3274 for (const Edge* e : replicate_input_edges) {
3275 if (e->src()->type_string() == kTPUPartitionedInput) {
3276 int num_users = 0;
3277 for (const auto& ue : e->src()->out_edges()) {
3278 if (!ue->IsControlEdge()) ++num_users;
3279 }
3280 if (num_users != 1) {
3281 return tensorflow::errors::InvalidArgument(
3282 e->src()->name(), " must only have one user. Found ", num_users);
3283 }
3284 to_be_removed_nodes.push_back(e->src());
3285 std::vector<NodeAndPort>& nodes =
3286 replicate_input_fan_in_nodes[e->dst_input()];
3287 nodes.resize(num_cores_per_replica, NodeAndPort(nullptr, 0));
3288 VLOG(2) << "allocate " << num_cores_per_replica
3289 << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
3290 std::vector<const Edge*> fan_in_edges;
3291 TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
3292 TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
3293
3294 for (const Edge* fe : fan_in_edges) {
3295 nodes[fe->dst_input()].node = fe->src();
3296 nodes[fe->dst_input()].port = fe->src_output();
3297 VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
3298 << fe->dst_input() << "] = " << fe->src()->name();
3299 }
3300 }
3301 }
3302
3303 // Replicate output edges are sorted by replica id and then by outputs for
3304 // each replica. For example, if TPU Computation has outputs (output_1,
3305 // output_2, and output_3) and number of replicas is 2, then
3306 // replicate_output_edges order would be:
3307 // output_1_replica_1, output_2_replica_1, output_3_replica_1,
3308 // output_1_replica_2, output_2_replica_2, output_3_replica_2.
3309 std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
3310 nullptr);
3311 for (const Edge* edge : replicate_node.out_edges()) {
3312 if (edge->IsControlEdge()) continue;
3313
3314 int num_partitioned_outputs = 0;
3315
3316 for (const Edge* out_edge : edge->dst()->out_edges()) {
3317 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3318 num_partitioned_outputs++;
3319 // Paths between replicate_node and replicate_output_fan_out_nodes:
3320 // ReplicateNode->TpuOutIdentity->kTPUPartitionedOutput->fan-out nodes
3321 TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
3322 to_be_removed_nodes.push_back(edge->dst());
3323 to_be_removed_nodes.push_back(out_edge->dst());
3324 // Get the right replicated id from the replicate_output_edge.
3325 std::vector<Node*>& nodes =
3326 replicate_output_fan_out_nodes[edge->src_output()];
3327 std::vector<int>& dst_inputs =
3328 replicate_output_fan_out_dst_inputs[edge->src_output()];
3329 nodes.resize(num_cores_per_replica, nullptr);
3330 dst_inputs.resize(num_cores_per_replica, 0);
3331 TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
3332 num_cores_per_replica);
3333
3334 for (const Edge* fe : out_edge->dst()->out_edges()) {
3335 nodes[fe->src_output()] = fe->dst();
3336 dst_inputs[fe->src_output()] = fe->dst_input();
3337 VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
3338 << "][" << fe->src_output()
3339 << "] = " << fe->dst()->DebugString() << " with dst_input "
3340 << fe->dst_input();
3341 }
3342 }
3343 }
3344 replicate_output_edges[edge->src_output()] = edge;
3345 if (num_partitioned_outputs > 1) {
3346 return errors::InvalidArgument(
3347 "More than one TPUPartitionedOutput per replicated output.");
3348 }
3349 }
3350
3351 const int num_execute_args =
3352 arg_shardings.size() - params_info.NumGuaranteedConstants();
3353 // Inverts the arg_shardings and retval_shardings mappings to
3354 // form core -> {argument number} maps.
3355 std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
3356 for (int i = 0; i < num_execute_args; ++i) {
3357 const auto& sharding = arg_shardings[i];
3358 if (sharding.type() == xla::OpSharding::MAXIMAL) {
3359 int core = sharding.tile_assignment_devices(0);
3360 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3361 core_arg_nums[core].push_back(i);
3362 } else if (sharding.type() == xla::OpSharding::OTHER) {
3363 for (int64_t core : sharding.tile_assignment_devices()) {
3364 core_arg_nums[core].push_back(i);
3365 }
3366 } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3367 for (int core = 0; core < num_cores_per_replica; ++core) {
3368 core_arg_nums[core].push_back(i);
3369 }
3370 } else {
3371 return tensorflow::errors::InvalidArgument(
3372 "Unsupported argument sharding for arg=", arg_names[i],
3373 " shape=", arg_shapes[i].shape.DebugString(), ": ",
3374 sharding.DebugString());
3375 }
3376 }
3377 std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
3378 for (int i = 0; i < retval_shardings.size(); ++i) {
3379 const auto& sharding = retval_shardings[i];
3380 if (sharding.type() == xla::OpSharding::MAXIMAL) {
3381 int core = sharding.tile_assignment_devices(0);
3382 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3383 core_retval_nums[core].push_back(i);
3384 } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3385 for (int core = 0; core < num_cores_per_replica; ++core) {
3386 core_retval_nums[core].push_back(i);
3387 }
3388 } else if (sharding.type() == xla::OpSharding::OTHER) {
3389 for (int64_t core : sharding.tile_assignment_devices()) {
3390 core_retval_nums[core].push_back(i);
3391 }
3392 } else {
3393 return tensorflow::errors::InvalidArgument(
3394 "Unsupported return value sharding: ", sharding.DebugString());
3395 }
3396 }
3397
3398 // Maps host device name to a list of per-variable pairs (variable_copy_node,
3399 // output_index_of_copy_node).
3400 absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
3401
3402 Node* execute_successor = control_successor;
3403
3404 int num_total_cores = params_info.NumReplicas() * num_cores_per_replica;
3405 if (enable_multicore_locking_ && num_total_cores > 1) {
3406 // Add a node to release exclusive access once all the cores have finished
3407 // execution.
3408 NodeDef lock_def;
3409 lock_def.set_name(graph->NewName(
3410 strings::StrCat(compile_node->name(), "/", "tpu_release_multilock")));
3411 lock_def.set_op("ConsumeTpuMultilock");
3412 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &lock_def);
3413 TF_ASSIGN_OR_RETURN(Node * multilock_release, graph->AddNode(lock_def));
3414 multilock_release->set_assigned_device_name(
3415 compile_node->assigned_device_name());
3416 TF_RET_CHECK(multilock_acquire != nullptr);
3417 graph->AddEdge(multilock_acquire, 0, multilock_release, 0);
3418 graph->AddControlEdge(multilock_release, control_successor);
3419 // Make sure all execute Ops happen before the multilock_release.
3420 execute_successor = multilock_release;
3421 }
3422
3423 // Mapping from original resource arg number to a second level map. Second
3424 // level map is from core id to output index of updated variable value.
3425 absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
3426 orig_arg_num_to_output_index_mapping;
3427 // Mapping from retval index to a second level map. Second level map is from
3428 // core id to output index of sharded output value.
3429 std::unordered_map<int, std::unordered_map<int, int>>
3430 retval_index_to_output_index_mapping;
3431
3432 // Represents mapping of argument index of sharded input to each
3433 // TPUExecute node to its corresponding Split node and its output index
3434 // from which sharded input will be fed into TPUExecute node.
3435 std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
3436
3437 // Additional map of {host, arg_num} to dummy input. Per-task copies of the
3438 // inputs reduce cross-task communication and allow sharing across replicas.
3439 absl::btree_map<ShardedPerHostInputIndex, Node*> sharded_per_host_index;
3440
3441 // Builds one TPUExecute node per core per replica.
3442 std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
3443 for (int core = 0; core < num_cores_per_replica; ++core) {
3444 DataTypeVector core_retval_types;
3445 for (int output : core_retval_nums[core]) {
3446 core_retval_types.push_back(retval_types[output]);
3447 }
3448 DataTypeVector core_arg_types;
3449 std::vector<int> core_variable_writes;
3450 for (int input : core_arg_nums[core]) {
3451 // Resource variables can be passed either by reference (as a DT_RESOURCE
3452 // tensor) or by value (as the variable's current value). Per-replica or
3453 // distributed resource arguments are always passed by reference and
3454 // broadcast variables are always passed by value.
3455 if (arg_types[input] == DT_RESOURCE &&
3456 !params_info.IsPerReplicaArg(input) &&
3457 !params_info.IsDistributedArg(input)) {
3458 DataType handle_type = arg_shapes[input].handle_type;
3459 TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
3460 core_arg_types.push_back(handle_type);
3461 int base = input - params_info.NumPerReplicaArgs() -
3462 params_info.NumDistributedArgs() -
3463 params_info.NumBroadcastArgs();
3464 // Variables passed by value will have a corresponding additional output
3465 // containing an updated value for the variable.
3466 core_variable_writes.push_back(base);
3467 core_retval_types.push_back(handle_type);
3468 } else {
3469 core_arg_types.push_back(arg_types[input]);
3470 }
3471 }
3472
3473 NodeDef def;
3474 def.set_op("TPUExecute");
3475 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
3476 AddNodeAttr("Targs", core_arg_types, &def);
3477 AddNodeAttr("Tresults", core_retval_types, &def);
3478
3479 // If the producer name was set during inference, propagate the information
3480 // to the TPUExecute op so it can be accessed during metric collection.
3481 std::string producer_name;
3482 Status status =
3483 GetNodeAttr(replicate_node.attrs(), "_producer_name", &producer_name);
3484 if (status.ok()) {
3485 AddNodeAttr("_producer_name", producer_name, &def);
3486 }
3487
3488 for (int64_t replica = 0; replica < params_info.NumReplicas(); ++replica) {
3489 def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
3490 "_", core));
3491
3492 TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(def));
3493 execute_nodes[replica].push_back(node);
3494
3495 node->set_assigned_device_name(tpu_device_names[replica][core]);
3496
3497 // Add control edges to ensure that execution happens after
3498 // `control_predecessor`, happens before `execute_successor`, and is
3499 // triggered by evaluating any operator that depends on the original
3500 // TPUReplicate operator. See the comment at the top of the header file
3501 // for more details.
3502 graph->AddControlEdge(control_predecessor, node);
3503 graph->AddControlEdge(node, execute_successor);
3504
3505 // Add data input edges.
3506 for (int64_t i = 0; i < core_arg_nums[core].size(); ++i) {
3507 int64_t orig_arg_num = core_arg_nums[core][i];
3508 VLOG(2) << " replica " << replica << " core " << core << " i " << i
3509 << " orig_arg_num " << orig_arg_num;
3510 const bool is_per_replica_arg =
3511 params_info.IsPerReplicaArg(orig_arg_num);
3512 if (is_per_replica_arg || params_info.IsDistributedArg(orig_arg_num)) {
3513 // Per-replica input and distributed input
3514 const int64_t input_num =
3515 is_per_replica_arg ? replica * params_info.NumPerReplicaArgs() +
3516 core_arg_nums[core][i]
3517 : params_info.NumReplicas() *
3518 params_info.NumPerReplicaArgs() +
3519 core_arg_nums[core][i] -
3520 params_info.NumPerReplicaArgs();
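// Added illustrative arithmetic (not from the original source): with
// 2 replicas and 3 per-replica args, per-replica arg 1 of replica 1 reads
// replicate input 1 * 3 + 1 = 4, while a distributed arg with
// core_arg_nums[core][i] == 3 reads input 2 * 3 + 3 - 3 = 6 regardless of
// the replica.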
3521
3522 const Edge* edge = replicate_input_edges[input_num];
3523 VLOG(2) << "replicate_input_edges[" << input_num << "]";
3524 DataType dtype = edge->src()->output_type(edge->src_output());
3525 if (dtype == DT_RESOURCE) {
3526 DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
3527 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
3528 handle_dtype) == kTpuAllTypes.end()) {
3529 return errors::InvalidArgument(
3530 "Unsupported resource variable data type for TPU: ",
3531 DataTypeString(handle_dtype), ", caused by output ",
3532 edge->src()->name(), ":", edge->src_output());
3533 }
3534 } else {
3535 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3536 kTpuAllTypes.end()) {
3537 return errors::InvalidArgument(
3538 "Unsupported data type for TPU: ", DataTypeString(dtype),
3539 ", caused by output ", edge->src()->name(), ":",
3540 edge->src_output());
3541 }
3542 }
3543 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3544 // Don't automatically add a split node when the input node is
3545 // kTPUPartitionedInput.
3546 if (edge->src()->type_string() == kTPUPartitionedInput) {
3547 VLOG(2)
3548 << "Connect "
3549 << replicate_input_fan_in_nodes[input_num][core].node->name()
3550 << " to " << node->name() << " at " << i;
3551 graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3552 replicate_input_fan_in_nodes[input_num][core].port,
3553 node, i);
3554 } else {
3555 if (dtype == DT_RESOURCE) {
3556 return errors::InvalidArgument(
3557 "Tiled sharding for per-replica DT_RESOURCE input must ",
3558 "be TPUPartitionedInput. Here got ",
3559 edge->src()->type_string());
3560 }
3561 const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3562
3563 ShardedInputInfo sharded_input_info;
3564 if (use_nd_sharding_ops_ && is_per_replica_arg) {
3565 TF_ASSIGN_OR_RETURN(
3566 sharded_input_info,
3567 CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
3568 sharding, replica, orig_arg_num, dtype,
3569 PartialTensorShape(), edge->src(), edge->src_output(),
3570 graph, &input_index_to_sharded_inputs));
3571 } else if (use_nd_sharding_ops_) {
3572 TF_ASSIGN_OR_RETURN(
3573 sharded_input_info,
3574 CreateOrGetXlaSplitNodeForDistributedArg(
3575 sharding, params_info.NumReplicas(), replica,
3576 orig_arg_num, dtype, PartialTensorShape(), edge->src(),
3577 edge->src_output(), graph,
3578 &input_index_to_sharded_inputs));
3579 } else {
3580 TF_ASSIGN_OR_RETURN(
3581 sharded_input_info,
3582 CreateOrGetSplitNodesForInputSharding(
3583 sharding, orig_arg_num, dtype, PartialTensorShape(),
3584 replica, edge->src_output(), edge->src(),
3585 control_predecessor, graph,
3586 &input_index_to_sharded_inputs));
3587 }
3588
3589 NodeOut split_node_and_index =
3590 sharded_input_info.sharded_inputs.at(core);
3591 // Connect with Split node output.
3592 graph->AddEdge(split_node_and_index.node,
3593 split_node_and_index.index, node, i);
3594 }
3595 } else if (edge->src()->type_string() == kTPUPartitionedInput &&
3596 arg_shardings[orig_arg_num].type() ==
3597 xla::OpSharding::REPLICATED) {
3598 graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3599 replicate_input_fan_in_nodes[input_num][core].port,
3600 node, i);
3601 } else {
3602 graph->AddEdge(edge->src(), edge->src_output(), node, i);
3603 }
3604 } else if (params_info.IsBroadcastArg(orig_arg_num)) {
3605 // Broadcast input.
3606 int64_t input_num = params_info.FirstBroadcastArgFromHost() +
3607 core_arg_nums[core][i] -
3608 params_info.NumPerReplicaArgs() -
3609 params_info.NumDistributedArgs();
3610 const Edge* edge = replicate_input_edges[input_num];
3611 DataType dtype = edge->src()->output_type(edge->src_output());
3612 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3613 kTpuAllTypes.end()) {
3614 return errors::InvalidArgument(
3615 "Unsupported data type for TPU: ", DataTypeString(dtype),
3616 ", caused by output ", edge->src()->name(), ":",
3617 edge->src_output());
3618 }
3619 graph->AddEdge(edge->src(), edge->src_output(), node, i);
3620 } else {
3621 // Variable input.
3622 int64_t variable_num =
3623 orig_arg_num - params_info.NumPerReplicaArgs() -
3624 params_info.NumDistributedArgs() - params_info.NumBroadcastArgs();
3625 TF_RET_CHECK(variable_num < num_variables);
3626
3627 Node* variable_read = variable_reads[variable_num];
3628 DataType dtype = variable_read->output_type(0);
3629 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3630 kTpuAllTypes.end()) {
3631 return errors::InvalidArgument(
3632 "Unsupported resource variable data type for TPU: ",
3633 DataTypeString(dtype), ", caused by ReadVariableOp ",
3634 variable_read->DebugString());
3635 }
3636 DeviceNameUtils::ParsedName requested_device;
3637 string requested = variable_read->requested_device();
3638 TF_RET_CHECK(
3639 DeviceNameUtils::ParseFullName(requested, &requested_device));
3640 if (requested_device.type != "TPU") {
3641 // Stage the value via the CPU device on the remote host. The graph
3642 // partitioner will introduce an intermediate copy rather than
3643 // copying the same tensor multiple times across the network, and we
3644 // would prefer that intermediate copy to be in host memory to avoid
3645 // running out of memory if the TPUExecute op on the staging device
3646 // starts running before the _Send ops to the other TPU devices on
3647 // the same host complete. We don't do this if the variables are
3648 // already placed on TPU, otherwise it will cause an unnecessary
3649 // round trip copy.
3650 // TODO(b/79580121): give each replica its own on-device variable
3651 // replica and then delete this code.
3652 string device;
3653 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3654 tpu_device_names[replica][core], &device));
3655 TF_ASSIGN_OR_RETURN(
3656 auto var_data,
3657 CreateOrGetPerHostVariableCopy(
3658 device, variable_num, variable_reads, params_info,
3659 arg_shardings, replicate_node, enable_xla_param_broadcast_,
3660 mpmd, num_cores_per_replica, replica, arg_shapes,
3661 &per_host_var_copies, graph));
3662
3663 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3664 ShardedInputInfo sharded_input_info;
3665
3666 if (EnableXlaParamBroadcast(enable_xla_param_broadcast_, mpmd,
3667 params_info, orig_arg_num, dtype)) {
3668 // Populates the sharded dummy vars for non-zero replicas.
3669 TF_RETURN_IF_ERROR(CreatePartitionedDummyVarArgs(
3670 arg_shardings[orig_arg_num], params_info.NumReplicas(),
3671 replica, arg_shapes[orig_arg_num], var_data.node,
3672 orig_arg_num, dtype, device, graph, tpu_device_names,
3673 &sharded_per_host_index, &input_index_to_sharded_inputs));
3674 }
3675
3676 if (use_nd_sharding_ops_) {
3677 TF_ASSIGN_OR_RETURN(
3678 sharded_input_info,
3679 CreateOrGetXlaSplitNodeForVariableArg(
3680 arg_shardings[orig_arg_num], params_info.NumReplicas(),
3681 replica, orig_arg_num,
3682 arg_shapes[orig_arg_num].handle_type,
3683 arg_shapes[orig_arg_num].handle_shape, var_data.node,
3684 var_data.index, graph, &to_be_removed_nodes,
3685 &input_index_to_sharded_inputs));
3686 } else {
3687 TF_ASSIGN_OR_RETURN(
3688 sharded_input_info,
3689 CreateOrGetSplitNodesForInputSharding(
3690 arg_shardings[orig_arg_num], orig_arg_num,
3691 arg_shapes[orig_arg_num].handle_type,
3692 arg_shapes[orig_arg_num].handle_shape, replica,
3693 var_data.index, var_data.node, control_predecessor,
3694 graph, &input_index_to_sharded_inputs));
3695 }
3696
3697 NodeOut split_node_and_index =
3698 sharded_input_info.sharded_inputs[core];
3699 // Connect with Split node output.
3700 graph->AddEdge(split_node_and_index.node,
3701 split_node_and_index.index, node, i);
3702
3703 } else {
3704 graph->AddEdge(var_data.node, var_data.index, node, i);
3705 }
3706 } else {
3707 graph->AddEdge(variable_reads[variable_num], 0, node, i);
3708 }
3709 }
3710 }
3711
3712 // Adds a program input edge from the compiler.
3713 graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
3714
3715 // Add data output edges.
3716 int num_outputs = core_retval_nums[core].size();
3717 for (int i = 0; i < num_outputs; ++i) {
3718 int output_num =
3719 replica * num_retvals_per_replica + core_retval_nums[core][i];
3720 const auto& sharding = retval_shardings[core_retval_nums[core][i]];
3721 if (sharding.type() == xla::OpSharding::OTHER) {
3722 int retval_index = core_retval_nums[core][i];
3723 retval_index_to_output_index_mapping[retval_index][core] = i;
3724 bool is_last_core =
3725 core ==
3726 *std::max_element(sharding.tile_assignment_devices().begin(),
3727 sharding.tile_assignment_devices().end());
3728 bool isPartitionOutNode = false;
3729
3730 const Edge* e = replicate_output_edges[output_num];
3731 const Edge* e_out;
3732 for (const Edge* out_edge : e->dst()->out_edges()) {
3733 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3734 isPartitionOutNode = true;
3735 e_out = out_edge;
3736 }
3737 }
3738 if (isPartitionOutNode) {
3739 graph->AddEdge(
3740 node, i, replicate_output_fan_out_nodes[output_num][core],
3741 replicate_output_fan_out_dst_inputs[output_num][core]);
3742 VLOG(2) << "Connect " << node->name() << " at " << i << " to "
3743 << replicate_output_fan_out_nodes[output_num][core]->name()
3744 << " at "
3745 << replicate_output_fan_out_dst_inputs[output_num][core];
3746 if (is_last_core) {
3747 graph->RemoveEdge(e);
3748 graph->RemoveEdge(e_out);
3749 }
3750 continue;
3751 }
3752
3753 // Do this in the iteration of the last core in the tile assignment, so
3754 // that all TPUExecute nodes have been created.
3755 if (!is_last_core) {
3756 continue;
3757 }
3758
3759 // Add a Concat node.
3760 std::vector<NodeOut> orig_inputs;
3761 for (int64_t tile_index = 0;
3762 tile_index < sharding.tile_assignment_devices_size();
3763 ++tile_index) {
3764 int64_t last_tile_dim_size =
3765 *sharding.tile_assignment_dimensions().rbegin();
3766 if (sharding.replicate_on_last_tile_dim() &&
3767 tile_index % last_tile_dim_size != 0) {
3768 continue;
3769 }
3770 int64_t core_id = sharding.tile_assignment_devices(tile_index);
3771 int core_retval_index =
3772 retval_index_to_output_index_mapping[retval_index][core_id];
3773 orig_inputs.push_back(
3774 NodeOut{execute_nodes[replica][core_id],
3775 static_cast<int>(
3776 core_retval_nums[core_id][core_retval_index])});
3777 }
3778 DataType dtype = e->src()->output_type(e->src_output());
3779 Node* concat_node = nullptr;
3780 if (use_nd_sharding_ops_) {
3781 TF_ASSIGN_OR_RETURN(
3782 concat_node, CreateXlaConcatNode(
3783 sharding, replica, dtype,
3784 /*partial_tensor_shape=*/PartialTensorShape(),
3785 orig_inputs, /*device=*/"", graph));
3786 } else {
3787 TF_ASSIGN_OR_RETURN(
3788 concat_node,
3789 CreateConcatNodesForRetval(
3790 sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
3791 replica, orig_inputs, graph, /*device=*/""));
3792 }
3793
3794 const Edge* edge = replicate_output_edges[output_num];
3795 Node* dst = edge->dst();
3796 int dst_input = edge->dst_input();
3797 graph->RemoveEdge(edge);
3798 graph->AddEdge(concat_node, 0, dst, dst_input);
3799
3800 continue;
3801 }
3802
3803 // If this is a replicated output, outputs on all cores will be the
3804 // same, and we only take the output from core 0.
3805 if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3806 continue;
3807 }
3808
3809 // If output has maximal sharding, make sure we only use output from
3810 // TPUExecute node with logical core id equal to core id defined by the
3811 // xla sharding.
3812 if (sharding.type() == xla::OpSharding::MAXIMAL &&
3813 core != sharding.tile_assignment_devices(0)) {
3814 continue;
3815 }
3816
3817 const Edge* replicate_edge_to_replace =
3818 replicate_output_edges[output_num];
3819 Node* dst = replicate_edge_to_replace->dst();
3820 int dst_input = replicate_edge_to_replace->dst_input();
3821 graph->RemoveEdge(replicate_edge_to_replace);
3822 graph->AddEdge(node, i, dst, dst_input);
3823 }
3824
3825 // Feed the updated variable values from the first replica to the
3826 // variable write nodes.
3827 if (replica == 0) {
3828 for (int i = 0; i < core_variable_writes.size(); ++i) {
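// core_variable_writes[i] indexes into the variable arguments; adding the
// per-replica, distributed and broadcast argument counts converts it into
// the variable's position in the full argument list used by arg_shardings
// and arg_shapes.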
3829 int orig_arg_num =
3830 core_variable_writes[i] + params_info.NumPerReplicaArgs() +
3831 params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
3832 const auto& sharding = arg_shardings[orig_arg_num];
3833 // If this is a tile-sharded variable, concat the variable updates from
3834 // all cores.
3835 if (sharding.type() == xla::OpSharding::OTHER) {
3836 orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
3837
3838 // Do this in the iteration of the last core in the tile assignment, so
3839 // all TPUExecute nodes have already been created.
3840 if (core !=
3841 *std::max_element(sharding.tile_assignment_devices().begin(),
3842 sharding.tile_assignment_devices().end())) {
3843 continue;
3844 }
3845
3846 // Add a Concat node.
3847 std::vector<NodeOut> orig_inputs;
3848 for (int64_t tile_index = 0;
3849 tile_index < sharding.tile_assignment_devices_size();
3850 ++tile_index) {
3851 int64_t last_tile_dim_size =
3852 *sharding.tile_assignment_dimensions().rbegin();
3853 if (sharding.replicate_on_last_tile_dim() &&
3854 tile_index % last_tile_dim_size != 0) {
3855 continue;
3856 }
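// The update for this variable is emitted by core `core_id`'s TPUExecute
// node after that core's regular retvals, hence the output index
// core_retval_nums[core_id].size() + core_retval_num below.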
3857 int64_t core_id = sharding.tile_assignment_devices(tile_index);
3858 int core_retval_num =
3859 orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
3860 orig_inputs.push_back(
3861 NodeOut{execute_nodes[0][core_id],
3862 static_cast<int>(core_retval_nums[core_id].size() +
3863 core_retval_num)});
3864 }
3865
3866 // Use the variable read's device for the concat. They should both
3867 // be collocated with the variable.
3868 absl::string_view device =
3869 variable_reads[core_variable_writes[i]]->assigned_device_name();
3870 Node* concat_node = nullptr;
3871 if (use_nd_sharding_ops_) {
3872 TF_ASSIGN_OR_RETURN(
3873 concat_node,
3874 CreateXlaConcatNode(sharding, replica,
3875 arg_shapes[orig_arg_num].handle_type,
3876 arg_shapes[orig_arg_num].handle_shape,
3877 orig_inputs, device, graph));
3878 } else {
3879 TF_ASSIGN_OR_RETURN(
3880 concat_node,
3881 CreateConcatNodesForRetval(
3882 sharding, arg_shapes[orig_arg_num].handle_type,
3883 arg_shapes[orig_arg_num].handle_shape, replica,
3884 orig_inputs, graph, device));
3885 }
3886 // Populate VariableWrite.
3887 VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3888 write.value = concat_node;
3889 write.value_output = 0;
3890 write.predicate = compile_node;
3891 write.predicate_output = num_cores_per_replica + core + 1;
3892
3893 continue;
3894 }
3895
3896 // If this is a replicated variable, outputs on all cores will be the
3897 // same, and we only take the output from core 0 for the variable
3898 // update.
3899 if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3900 continue;
3901 }
3902 VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3903 write.value = node;
3904 write.value_output = num_outputs + i;
3905 write.predicate = compile_node;
3906 write.predicate_output = num_cores_per_replica + core + 1;
3907 }
3908 }
3909 }
3910 }
3911
3912 for (Node* node : to_be_removed_nodes) {
3913 graph->RemoveNode(node);
3914 }
3915 return OkStatus();
3916 } // NOLINT(readability/fn_size)
3917
3918 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
3919 int replica_index, const std::vector<Node*>& outside_compilation_nodes,
3920 const DeviceNameUtils::ParsedName& tpu_device,
3921 const DeviceNameUtils::ParsedName& partial_device,
3922 NodeToNodeReplicasMap* node_images, Graph* graph) {
3923 for (Node* node : outside_compilation_nodes) {
3924 NodeDef image_def = node->def();
3925 MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
3926 const string suffix = strings::StrCat("/R", replica_index);
3927 // In addition to the node name, make the frame name unique to avoid
3928 // multiple LoopCond nodes in one frame.
3929 TF_RETURN_IF_ERROR(
3930 AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
3931 TF_ASSIGN_OR_RETURN(Node * image, graph->AddNode(image_def));
3932 image->AddAttr(kXlaReplicaIdAttrName, replica_index);
3933 if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
3934 TF_RETURN_IF_ERROR(
3935 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
3936 } else {
3937 const string& original_device_string =
3938 node->assigned_device_name().empty() ? node->requested_device()
3939 : node->assigned_device_name();
3940 DeviceNameUtils::ParsedName device;
3941 TF_RET_CHECK(
3942 DeviceNameUtils::ParseFullName(original_device_string, &device));
3943 // If the requested device can be merged with the replica's host device,
3944 // then do so. For example, if the requested device is "/CPU:0" or
3945 // "/GPU:0" then it will be placed on the CPU/GPU of the host where this
3946 // replica is running. But if the requested device is
3947 // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
3948 if (DeviceNameUtils::IsSpecification(device, partial_device)) {
3949 TF_RETURN_IF_ERROR(
3950 DeviceNameUtils::MergeDevNames(&device, partial_device));
3951 }
3952 image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
3953 }
3954 std::vector<Node*>& node_image_vector = (*node_images)[node];
3955 node_image_vector.resize(replica_index + 1);
3956 node_image_vector[replica_index] = image;
3957 }
3958 return OkStatus();
3959 }
3960
3961 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
3962 const std::vector<std::vector<string>>& tf_device_assignment,
3963 const HostComputeCoreMap& host_compute_core,
3964 const OutsideCompilationNodeMap& outside_compilation_nodes,
3965 NodeToNodeReplicasMap* node_images, Graph* graph) {
3966 // Iterate over replicas.
3967 for (int i = 0; i < tf_device_assignment.size(); ++i) {
3968 const auto& core_devices = tf_device_assignment[i];
3969 for (const auto& oc_cluster_iter : outside_compilation_nodes) {
3970 const string& oc_cluster_name = oc_cluster_iter.first;
3971 const auto& oc_cluster_nodes = oc_cluster_iter.second;
3972 // We previously validated that host_compute_core contains an entry for
3973 // each cluster.
3974 int core = host_compute_core.at(oc_cluster_name);
3975 TF_RET_CHECK(core >= 0 && core < core_devices.size());
3976 // tpu_device is the device the HostCompute XLA Op for this cluster runs
3977 // on.
3978 DeviceNameUtils::ParsedName tpu_device;
3979 TF_RET_CHECK(
3980 DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
3981 // partial_device contains the replica and task but not the type.
3982 DeviceNameUtils::ParsedName partial_device = tpu_device;
3983 partial_device.has_type = false;
3984 partial_device.has_id = false;
3985
3986 if (tf_device_assignment.size() == 1) {
3987 // With a single replica, don't copy any nodes; just put the original
3988 // nodes into the image map. We leave the device placement alone, except
3989 // that we have to fill in the correct core for the host send and
3990 // receive nodes.
3991 for (Node* node : oc_cluster_nodes) {
3992 (*node_images)[node] = {node};
3993 node->AddAttr(kXlaReplicaIdAttrName, 0);
3994 if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
3995 TF_RETURN_IF_ERROR(
3996 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
3997 }
3998 }
3999 } else {
4000 // Iterate over outside_compilation clusters in this computation, adding
4001 // all the nodes with appropriate device assignments.
4002 TF_RETURN_IF_ERROR(
4003 CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
4004 partial_device, node_images, graph));
4005 }
4006 }
4007 }
4008 return OkStatus();
4009 }
4010
4011 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
4012 const std::vector<Node*>& outside_compilation_nodes,
4013 const NodeToNodeReplicasMap& node_images,
4014 const std::unordered_map<string, Node*> outside_compilation_inputs,
4015 Graph* graph) {
4016 for (Node* node : outside_compilation_nodes) {
4017 const auto& images = node_images.at(node);
4018 // Make a copy of all edges and iterate on "in_edges", because we might
4019 // remove edges when iterating through them.
4020 std::vector<const Edge*> in_edges(node->in_edges().begin(),
4021 node->in_edges().end());
4022 for (const Edge* edge : in_edges) {
4023 Node* src = edge->src();
4024 const auto iter = node_images.find(src);
4025 if (iter == node_images.end()) {
4026 if (images.size() > 1) {
4027 // The source node is a 'normal' node not part of any
4028 // rewrite. Broadcast the value to all replicas. (If images.size() ==
4029 // 1 the cluster is not replicated and we can leave the original edge
4030 // in place.)
4031 for (Node* dst : images) {
4032 graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
4033 }
4034 }
4035 continue;
4036 }
4037
4038 // The source node is a replicated outside_compilation node.
4039 const auto& src_images = iter->second;
4040 if (src_images.size() != images.size()) {
4041 return errors::InvalidArgument(
4042 "Graph contains an edge from node ", src->name(),
4043 " in an outside_compilation block replicated ", src_images.size(),
4044 " ways to node ", node->name(),
4045 " in an outside_compilation block replicated ", images.size(),
4046 " ways. Replication factors must match. Leave a comment on "
4047 "tracking bug b/76419636 if you need this to be supported.");
4048 }
4049 bool is_lifted_arg;
4050 string outside_compilation_cluster;
4051 if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
4052 .ok() &&
4053 GetNodeAttr(src->def(), kOutsideCompilationAttr,
4054 &outside_compilation_cluster)
4055 .ok()) {
4056 const auto input_iter =
4057 outside_compilation_inputs.find(outside_compilation_cluster);
4058 TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4059 TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
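// The IdentityN node recorded for this cluster is expected to expose one
// output per replica; below, its output i is wired to replica i's image of
// the consumer node.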
4060 int dst_input = edge->dst_input();
4061 if (src_images.size() == 1) {
4062 graph->RemoveEdge(edge);
4063 }
4064 for (int i = 0; i < src_images.size(); ++i) {
4065 graph->AddEdge(input_iter->second, i, images[i], dst_input);
4066 }
4067 continue;
4068 }
4069
4070 bool is_placeholder_for_arg;
4071 string outside_compilation_input_attr;
4072 if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
4073 &is_placeholder_for_arg)
4074 .ok() &&
4075 GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
4076 &outside_compilation_input_attr)
4077 .ok()) {
4078 const auto input_iter =
4079 outside_compilation_inputs.find(outside_compilation_input_attr);
4080 TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4081 TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
4082 int dst_input = edge->dst_input();
4083 if (src_images.size() == 1) {
4084 graph->RemoveEdge(edge);
4085 }
4086 for (int i = 0; i < src_images.size(); ++i) {
4087 graph->AddEdge(input_iter->second, i, images[i], dst_input);
4088 }
4089 continue;
4090 }
4091
4092 if (images.size() > 1) {
4093 // If images.size() == 1 neither cluster is replicated and we can
4094 // leave the original edges in place.
4095 for (int i = 0; i < src_images.size(); ++i) {
4096 graph->AddEdge(src_images[i], edge->src_output(), images[i],
4097 edge->dst_input());
4098 }
4099 }
4100 }
4101 for (const Edge* edge : node->out_edges()) {
4102 Node* dst = edge->dst();
4103 const auto iter = node_images.find(dst);
4104 if (iter == node_images.end()) {
4105 // The destination node is a 'normal' node not part of any rewrite.
4106 if (edge->IsControlEdge()) {
4107 // Make the dst node have a control dependency on every replica.
4108 if (images.size() > 1) {
4109 for (int i = 0; i < images.size(); ++i) {
4110 graph->AddControlEdge(images[i], dst);
4111 }
4112 }
4113 // else the cluster is not replicated so we can leave the original
4114 // edge in place.
4115 } else {
4116 // The edge is only valid if the outside_compilation block is not
4117 // replicated.
4118 if (images.size() > 1) {
4119 return errors::InvalidArgument(
4120 "Graph contains an edge from node ", node->name(),
4121 " in an outside_compilation block replicated ", images.size(),
4122 " ways to node ", dst->name(),
4123 " that is not part of an outside_compilation block. Edges from "
4124 "outside_compilation to regular graph nodes are only supported "
4125 "for replication factors of 1. Leave a comment on tracking bug "
4126 "b/76419636 if you need this to be supported.");
4127 }
4128 // else the cluster is not replicated so we can leave the original
4129 // edge in place.
4130 }
4131 }
4132 // The case where src and dst are both in node_images is covered elsewhere
4133 // when iterating over in_edges of dst.
4134 }
4135 }
4136 return OkStatus();
4137 }
4138
4139 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
4140 const OutsideCompilationNodeMap& outside_compilation_nodes,
4141 const NodeToNodeReplicasMap& node_images,
4142 const std::unordered_map<string, Node*> outside_compilation_inputs,
4143 Graph* graph) {
4144 for (const auto& oc_cluster_iter : outside_compilation_nodes) {
4145 TF_RETURN_IF_ERROR(
4146 CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
4147 outside_compilation_inputs, graph));
4148 }
4149 return OkStatus();
4150 }
4151
4152 /* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
4153 const NodeToNodeReplicasMap& node_images, Graph* graph) {
4154 for (const auto& iter : node_images) {
4155 if (iter.second.size() > 1) {
4156 // The cluster was replicated so remove the original node.
4157 Node* node = iter.first;
4158 graph->RemoveNode(node);
4159 }
4160 }
4161 return OkStatus();
4162 }
4163
4164 /* static */ Status
4165 DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
4166 Graph* g, FunctionLibraryDefinition& flib_def,
4167 const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
4168 bool modified = false;
4169 do {
4170 std::vector<Node*> nodes_to_lower;
4171 for (Node* n : g->op_nodes()) {
4172 if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
4173 continue;
4174 }
4175
4176 if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
4177 // Only lower functional ops with a DT_RESOURCE input, because otherwise
4178 // the placer will complain. In normal cases, lowering would cause a
4179 // slowdown when the related functions are huge (b/139037679).
4180 bool has_resource_input = false;
4181 for (const Edge* e : n->in_edges()) {
4182 if (!e->IsControlEdge() &&
4183 e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4184 has_resource_input = true;
4185 break;
4186 }
4187 }
4188 if (has_resource_input) {
4189 nodes_to_lower.push_back(n);
4190 }
4191 }
4192 }
4193
4194 modified = !nodes_to_lower.empty();
4195
4196 auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
4197 // Clear device assignment. Otherwise all lowered nodes will have
4198 // device assignment, which is not what we want.
4199 n->set_requested_device("");
4200
4201 int replica_id;
4202 TF_RETURN_IF_ERROR(
4203 GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4204
4205 string outside_compilation_attr;
4206 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
4207 &outside_compilation_attr));
4208
4209 // There are two different kinds of functional outside compilation nodes:
4210 // 1. Nodes that are in outside compilation blocks already. They are
4211 // generated by FunctionalizeControlFlowForXlaPass, and only have
4212 // attribute kOutsideCompilationAttr.
4213 // 2. Mirrored control flow built for outside compilation in functional
4214 // nodes. They are generated by ExtractOutsideCompilationPass, and have
4215 // both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
4216 // When lowering them, they need to be treated differently.
4217 // For 1), their body functions are always V1 functions written by users,
4218 // and their "control outputs" are control inputs of _Retval nodes. They
4219 // should be lowered as V1 functions.
4220 // For 2), we always add necessary "control outputs"
4221 // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
4222 // FunctionDef's. They should be lowered as V2 functions.
4223 bool is_host_side_mirrored_control_flow =
4224 HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);
4225
4226 int num_node_ids = g->num_node_ids();
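// num_node_ids is the node-id watermark before lowering: nodes created by
// the rewrites below have ids >= num_node_ids and are post-processed in
// the loop that follows.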
4227 bool is_call_node = IsFunctionCall(flib_def, *n);
4228 if (n->IsWhileNode()) {
4229 TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, &flib_def,
4230 /*keep_node_fetchable=*/false));
4231 } else if (n->IsIfNode()) {
4232 TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
4233 } else {
4234 TF_RET_CHECK(is_call_node);
4235 // See comments for "is_host_side_mirrored_control_flow" above.
4236 // If this is a node that's in an outside compilation block, lower it
4237 // as a V1 function. This is controlled by removing
4238 // kLowerAsMultiDeviceFunctionAttr from the node.
4239 if (!is_host_side_mirrored_control_flow) {
4240 n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4241 } else {
4242 n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4243 n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
4244 true);
4245 }
4246 TF_RETURN_IF_ERROR(
4247 RewriteFunctionCallNode(n, g, flib_def,
4248 /*keep_caller_fetchable=*/false));
4249 }
4250
4251 for (int i = num_node_ids; i < g->num_node_ids(); i++) {
4252 Node* node = g->FindNodeId(i);
4253 if (!node) {
4254 continue;
4255 }
4256
4257 if (!is_call_node && is_host_side_mirrored_control_flow &&
4258 IsFunctionCall(flib_def, *node)) {
4259 // For If/While nodes, if they are host side mirrored control flow,
4260 // mark their body function calls with kXlaHasHostTransferAttrName
4261 // attribute to make sure we lower them as V2 function.
4262 node->AddAttr(kXlaHasHostTransferAttrName, true);
4263 }
4264
4265 if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
4266 node->IsIfNode()) {
4267 // Set kOutsideCompilationAttr attribute so we lower these
4268 // nested function call nodes later.
4269 node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
4270 // Set kXlaReplicaIdAttrName attribute so we know replica id when we
4271 // lower this function call node.
4272 node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4273 } else if (node->type_string() == "_XlaRecvAtHost" ||
4274 node->type_string() == "_XlaSendFromHost") {
4275 // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
4276 // have kXlaReplicaIdAttrName attribute so later we know which host
4277 // device to assign.
4278 node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4279 }
4280 }
4281 return OkStatus();
4282 };
4283
4284 for (Node* n : nodes_to_lower) {
4285 TF_RETURN_IF_ERROR(lower_functional_node(n));
4286 }
4287 } while (modified);
4288
4289 // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
4290 for (Node* n : g->op_nodes()) {
4291 if (n->type_string() != "_XlaRecvAtHost" &&
4292 n->type_string() != "_XlaSendFromHost") {
4293 continue;
4294 }
4295
4296 string replicate;
4297 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
4298 auto iter = tpu_replicate_device_names_mapping.find(replicate);
4299 TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
4300 const auto& tpu_device_names = iter->second;
4301
4302 int replica_id;
4303 TF_RETURN_IF_ERROR(
4304 GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4305 TF_RET_CHECK(replica_id < tpu_device_names.size());
4306 const string& tpu_device_name = tpu_device_names[replica_id][0];
4307 string host_device_name;
4308 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4309 tpu_device_name, &host_device_name));
4310 n->set_assigned_device_name(host_device_name);
4311 // We may run TPU rewrite passes again on the subgraphs of the resulting
4312 // graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
4313 // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
4314 // sure that TPU rewrite passes take no effect on host-side subgraphs for
4315 // outside compilation.
4316 n->ClearAttr(kTPUReplicateAttr);
4317 n->ClearAttr(kOutsideCompilationAttr);
4318 }
4319
4320 // Remove IdentityN nodes generated for outside compilation. IdentityN is
4321 // exempt from resource edge colocation, but here we do need the inputs and
4322 // outputs of these IdentityN nodes to be colocated.
4323 std::vector<Node*> identityn_nodes;
4324 for (Node* n : g->op_nodes()) {
4325 if (n->type_string() == "IdentityN" &&
4326 HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
4327 identityn_nodes.push_back(n);
4328 }
4329 }
4330 for (Node* n : identityn_nodes) {
4331 std::vector<const Edge*> out_edges(n->out_edges().begin(),
4332 n->out_edges().end());
4333 for (const Edge* e : out_edges) {
4334 if (e->IsControlEdge()) {
4335 continue;
4336 }
4337
4338 int src_output = e->src_output();
4339 const Edge* input_edge;
4340 TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
4341 Node* dst = e->dst();
4342 int dst_input = e->dst_input();
4343 g->RemoveEdge(e);
4344 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
4345 }
4346 g->RemoveNode(n);
4347 }
4348
4349 return OkStatus();
4350 }
4351
4352 /* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
4353 const Node& replicate_node,
4354 const OutsideCompilationNodeMap& outside_compilation_nodes,
4355 HostComputeCoreMap* host_compute_core) {
4356 std::vector<string> hc_core_string;
4357 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
4358 &hc_core_string));
4359 TF_RETURN_IF_ERROR(
4360 ParseHostComputeCoreList(hc_core_string, host_compute_core));
4361 for (const auto& iter : outside_compilation_nodes) {
4362 const string& oc_cluster_name = iter.first;
4363 if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
4364 // By default, put host compute ops on replicated core 0.
4365 (*host_compute_core)[oc_cluster_name] = 0;
4366 }
4367 }
4368 return OkStatus();
4369 }
4370
4371 /* static */ Status DistributedTPURewritePass::GetDeviceTopology(
4372 const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
4373 int* num_cores_per_replica, int* num_tasks,
4374 std::vector<std::vector<string>>* tf_device_assignment,
4375 std::vector<int>* devices_to_lock,
4376 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
4377 string* tpu_compilation_device) {
4378 TF_RETURN_IF_ERROR(
4379 GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
4380 if (*num_replicas < 1) {
4381 return errors::InvalidArgument("num_replicas must be >= 1, got ",
4382 *num_replicas);
4383 }
4384
4385 // Find the set of TPU devices in the TF job.
4386 // Indexed by [task number][tpu device number].
4387 std::vector<std::vector<Device*>> tpu_devices;
4388 int num_tpus_per_task;
4389 TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
4390 device_set, tpu_compilation_device,
4391 &num_tpus_per_task, &tpu_devices));
4392 *num_tasks = tpu_devices.size();
4393
4394 string topology;
4395 TF_RETURN_IF_ERROR(
4396 GetNodeAttr(replicate_node.attrs(), "topology", &topology));
4397 TF_RETURN_IF_ERROR(GetNodeAttr(
4398 replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
4399 std::vector<int> device_assignment;
4400 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
4401 &device_assignment));
4402
4403 // TODO(cwhipkey): since we can control multiple pods of different shapes
4404 // from a single worker, it may be desirable to propagate the remote device
4405 // information around (e.g., in DeviceAttributes). This can lead to the mesh
4406 // topology proto being leaked to cloud TPU users (e.g. through GetStatus
4407 // calls); this may be okay, but to be conservative, just assume that the
4408 // master session has the proper flags set.
4409
4410 // We do not initialize platform right now, but we can still retrieve the
4411 // TPU topology even with an uninitialized platform.
4412 auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
4413 /*initialize_platform=*/false);
4414 TF_RET_CHECK(tpu_platform);
4415 tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
4416 TF_RET_CHECK(num_tpus_per_task ==
4417 tpu_topology.LogicalDevicesPerHost(kTensorCore));
4418 TF_RETURN_IF_ERROR(BuildDeviceAssignment(
4419 tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
4420 *num_cores_per_replica, topology, device_assignment, tf_device_assignment,
4421 devices_to_lock, xla_device_assignment));
4422
4423 return OkStatus();
4424 }
4425
4426 /* static */ Status DistributedTPURewritePass::GetIOTypes(
4427 int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
4428 Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
4429 std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
4430 DataTypeVector* retval_types, ParameterInfo* params_info) {
4431 DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
4432 TF_RETURN_IF_ERROR(
4433 GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
4434 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
4435 &broadcast_input_types));
4436 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4437 "Tguaranteed_constants",
4438 &guaranteed_constant_types));
4439 int num_distributed_vars;
4440 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4441 "num_distributed_variables",
4442 &num_distributed_vars));
4443 const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
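// For example (illustrative numbers): with 10 entries in Tinputs, 2
// distributed variables and num_replicas = 4, there are 8 per-replica
// inputs in total, i.e. 2 per replica.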
4444
4445 if (num_per_replica_inputs % num_replicas != 0) {
4446 return errors::InvalidArgument(
4447 "Number of inputs to TPUReplicate (", num_per_replica_inputs,
4448 ") is not divisible by the number of replicas (", num_replicas, ").");
4449 }
4450
4451 int num_variables;
4452 TF_RETURN_IF_ERROR(
4453 GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
4454
4455 NameRangeMap output_name_map;
4456 TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
4457 input_name_map, &output_name_map));
4458
4459 TF_RETURN_IF_ERROR(
4460 GetNodeAttr(replicate_node.attrs(), "computation", function));
4461
4462 *computation = absl::make_unique<Graph>(graph->op_registry());
4463 TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
4464 **function, flr, computation->get(), arg_types, retval_types));
4465
4466 *params_info = ParameterInfo(
4467 num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
4468 broadcast_input_types.size(), num_variables,
4469 guaranteed_constant_types.size(), retval_types->size());
4470
4471 if (arg_types->size() != params_info->NumInputsToEachReplica()) {
4472 return errors::InvalidArgument(
4473 "Computation argument to TPUReplicate has wrong number of "
4474 "arguments. Expected ",
4475 params_info->NumInputsToEachReplica(), " inputs, got ",
4476 arg_types->size());
4477 }
4478 if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
4479 return errors::InvalidArgument(
4480 "Wrong number of outputs from TPUReplicate. Expected ",
4481 params_info->NumOutputsToHost(), " outputs, got ",
4482 replicate_node.num_outputs());
4483 }
4484 if (enable_cross_replica_sharding_mirrored_variables_) {
4485 std::vector<int> mirrored_variable_indices;
4486 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4487 TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
4488 &mirrored_variable_indices));
4489 for (int index : mirrored_variable_indices) {
4490 TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
4491 params_info->IsDistributedArg(index))
4492 << "Mirrored variables not categorized as per-replica arguments, "
4493 "index: "
4494 << index;
4495 params_info->mutable_mirrored_variable_indices()->insert(index);
4496 }
4497 }
4498 return OkStatus();
4499 }
4500
4501 /* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
4502 const string& tpu_compilation_device, const Node& replicate_node,
4503 Graph* graph, Node** host_transfer_sequencer, Node** control_before,
4504 Node** control_after) {
4505 *host_transfer_sequencer = nullptr;
4506
4507 TF_RETURN_IF_ERROR(
4508 BuildNoopNode(replicate_node,
4509 graph->NewName(strings::StrCat(replicate_node.name(), "/",
4510 "control_before")),
4511 /*device=*/"", graph, control_before));
4512 for (const Edge* e : replicate_node.in_edges()) {
4513 if (!e->IsControlEdge()) {
4514 continue;
4515 }
4516 Node* predecessor = e->src();
4517 if (predecessor->IsSource()) continue;
4518 if (predecessor->type_string() == "NoOp" &&
4519 predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
4520 // The node is the sequencer for host transfer operations. Its control
4521 // dependency needs to be placed after the execute node, not before.
4522 if (*host_transfer_sequencer != nullptr) {
4523 return errors::Internal("Replicate node ", replicate_node.name(),
4524 " has two transfer sequencer nodes: ",
4525 (*host_transfer_sequencer)->name(), " and ",
4526 predecessor->name());
4527 }
4528 // Set the correct device to match the other sequencing nodes.
4529 predecessor->set_assigned_device_name(tpu_compilation_device);
4530 *host_transfer_sequencer = predecessor;
4531 } else {
4532 graph->AddControlEdge(predecessor, *control_before);
4533 }
4534 }
4535
4536 TF_RETURN_IF_ERROR(
4537 BuildNoopNode(replicate_node,
4538 graph->NewName(strings::StrCat(replicate_node.name(), "/",
4539 "control_after")),
4540 /*device=*/tpu_compilation_device, graph, control_after));
4541 for (Node* successor : replicate_node.out_nodes()) {
4542 if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
4543 graph->AddControlEdge(successor, *control_after);
4544 } else {
4545 graph->AddControlEdge(*control_after, successor);
4546 }
4547 }
4548 return OkStatus();
4549 }
4550
4551 /* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
4552 const Node& replicate_node, const NameRangeMap& input_name_map,
4553 Graph* graph, Node* host_transfer_sequencer, Node* control_before,
4554 Node* control_after, absl::Span<const VariableInput> variable_nodes,
4555 std::vector<Node*>* guaranteed_constant_nodes,
4556 std::vector<Node*>* variable_reads) {
4557 TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
4558 replicate_node, input_name_map, guaranteed_constant_nodes));
4559
4560 TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
4561 variable_reads));
4562 // Add the control dependency from host transfer nodes.
4563 if (host_transfer_sequencer != nullptr) {
4564 graph->AddControlEdge(host_transfer_sequencer, control_after);
4565 }
4566 return OkStatus();
4567 }
4568
4569 /* static */ Status
4570 DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
4571 Node* replicate_node, Node* compile_node,
4572 absl::Span<const int> devices_to_lock, Node** control_after_compilation,
4573 Node** multilock_acquire, Graph* graph) {
4574 const Edge* compilation_edge = nullptr;
4575 for (const auto* e : replicate_node->out_edges()) {
4576 if (e->IsControlEdge() &&
4577 e->dst()->type_string() == "TPUCompilationResult") {
4578 TF_RET_CHECK(compilation_edge == nullptr)
4579 << "Multiple compilation result nodes attached to the same replicate "
4580 "cluster.";
4581 compilation_edge = e;
4582 }
4583 }
4584
4585 // TODO(jpienaar): This should be checked by default; the current tests not
4586 // using it are the ones that use the "abort upon successful compile" flag,
4587 // which will be removed. Leaving this in until then.
4588 if (compilation_edge != nullptr) {
4589 Node* compilation_status = compilation_edge->dst();
4590 const AttrValue* compile_status_cluster_attr =
4591 compilation_status->attrs().Find(kTPUCompilationResultAttr);
4592 TF_RET_CHECK(compile_status_cluster_attr != nullptr);
4593 const string& compile_status_cluster = compile_status_cluster_attr->s();
4594 TF_RET_CHECK(!compile_status_cluster.empty());
4595 const AttrValue* replicate_cluster_attr =
4596 replicate_node->attrs().Find(kTPUReplicateAttr);
4597 TF_RET_CHECK(replicate_cluster_attr != nullptr);
4598 const string& replicate_cluster = replicate_cluster_attr->s();
4599 TF_RET_CHECK(!replicate_cluster.empty());
4600 TF_RET_CHECK(compile_status_cluster == replicate_cluster);
4601
4602 TF_RETURN_IF_ERROR(
4603 ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
4604 graph->AddEdge(compile_node, 0, compilation_status, 0);
4605 }
4606
4607 NodeDef def;
4608 def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
4609 // Create an op to assert that compilation succeeded. The alternative would
4610 // have been to have each execute op check and return an error.
4611 def.set_op("TPUCompileSucceededAssert");
4612 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
4613 TF_ASSIGN_OR_RETURN(Node * compile_succeeded, graph->AddNode(def));
4614 compile_succeeded->set_assigned_device_name(
4615 compile_node->assigned_device_name());
4616 graph->AddEdge(compile_node, 0, compile_succeeded, 0);
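// Both the rewritten TPUCompilationResult node (if present, above) and the
// assert node consume output 0 of the compile node, which carries the
// compilation status.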
4617
4618 Node* last_node_before_sequencer = compile_succeeded;
4619
4620 if (enable_multicore_locking_ && devices_to_lock.size() > 1) {
4621 // Add a lock node to acquire exclusive access to all the cores that will
4622 // execute this program. The lock is required to prevent deadlock or
4623 // incorrect results when running concurrent multi-core programs in the
4624 // same distributed runtime when there is no direct graph dependency
4625 // between the programs (either because they are run from different sessions
4626 // or because they are in the same graph, but have no control or data
4627 // dependencies to sequence them). Consider the case of two multi-core
4628 // computations A and B whose cores overlap and include cores X and Y. With
4629 // no locking and no graph dependencies it is possible that A's program
4630 // gets enqueued before B's on core X, while B's program gets enqueued
4631 // before A's on core Y. This will lead either to deadlock or to
4632 // incorrect results, since the runtime has no mechanism to re-sequence
4633 // the programs on the cores. By adding a multi-lock acquisition for all the
4634 // before any TPUExecute ops are run, and releasing it after they complete,
4635 // we ensure that the programs are enqueued on the cores in a consistent
4636 // order.
4637 //
4638 // There is a risk when computations are in the same graph, and include a
4639 // data dependency, that the lock acquisition could provoke deadlock.
4640 // Suppose that A must happen before B because B's input depends on A's
4641 // output. Then it is obviously necessary that A's lock acquisition must
4642 // happen before B's lock acquisition, and so we must ensure that there is
4643 // a graph dependency causing B's lock acquisition to be sequenced after A's
4644 // lock acquisition. Right now that dependency is satisfied because the
4645 // shape inference code cannot determine the shape of A's outputs, and so
4646 // B's compilation, which precedes B's lock acquisition, is always sequenced
4647 // after A's execution. If the shape inference is improved it will be
4648 // necessary to add an explicit control edge between dependent lock
4649 // acquisition ops.
4650 NodeDef lock_def;
4651 lock_def.set_name(graph->NewName(
4652 strings::StrCat(compile_node->name(), "/", "tpu_acquire_multilock")));
4653 lock_def.set_op("TpuMultilock");
4654 AddNodeAttr("lock_list", devices_to_lock, &lock_def);
4655 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &lock_def);
4656 TF_ASSIGN_OR_RETURN(*multilock_acquire, graph->AddNode(lock_def));
4657 (*multilock_acquire)
4658 ->set_assigned_device_name(compile_node->assigned_device_name());
4659 graph->AddControlEdge(compile_succeeded, *multilock_acquire);
4660 last_node_before_sequencer = *multilock_acquire;
4661 } else {
4662 *multilock_acquire = nullptr;
4663 }
4664
4665 // Build a sequencing node for when compilation has completed.
4666 TF_RETURN_IF_ERROR(
4667 BuildNoopNode(*replicate_node,
4668 graph->NewName(strings::StrCat(compile_node->name(), "/",
4669 "after_compilation")),
4670 /*device=*/"", graph, control_after_compilation));
4671 graph->AddControlEdge(last_node_before_sequencer, *control_after_compilation);
4672
4673 return OkStatus();
4674 }
4675
4676 // Updates the head and tail outside compiled nodes so that nodes have the
4677 // correct device and removes the replication and outside compilation attributes
4678 // so that these nodes do not trigger further graph optimization passes.
4679 /* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
4680 const std::vector<std::vector<string>>& tf_device_assignment,
4681 const std::vector<Node*>& head_tail_outside_compilation_nodes) {
4682 for (Node* node : head_tail_outside_compilation_nodes) {
4683 int replica_id;
4684 TF_RETURN_IF_ERROR(
4685 GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
4686 // Since we set the device, this will now run on a task other than 0. We
4687 // clear the two following attributes so that we don't trigger encapsulation
4688 // again on the remote host (which will fail due to a missing
4689 // _TPUReplicateMetadata node for the cluster).
4690 for (const Edge* e : node->in_edges()) {
4691 // Resource consuming ops should colocate with its resource input.
4692 if (e->src()->IsArg() &&
4693 e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4694 node->set_requested_device(tf_device_assignment[replica_id][0]);
4695 }
4696 }
4697 if (node->requested_device().empty()) {
4698 string cpu_device;
4699 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4700 tf_device_assignment[replica_id][0], &cpu_device));
4701 node->set_requested_device(cpu_device);
4702 }
4703 node->ClearAttr(kTPUReplicateAttr);
4704 node->ClearAttr(kOutsideCompilationAttr);
4705 }
4706 return OkStatus();
4707 }
4708
4709 // Performs the rewrite on a single TPUReplicate node.
4710 /* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
4711 const string& session_handle, const DeviceSet& device_set,
4712 Node* replicate_node, FunctionLibraryDefinition* flib_def,
4713 FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
4714 const OutsideCompilationNodeMap& outside_compilation_nodes,
4715 const std::vector<Node*>& head_tail_outside_compilation_nodes,
4716 NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
4717 const GraphShapeInfo& shape_info,
4718 TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
4719 int64_t autotuner_thresh) {
4720 VLOG(2) << "Rewriting node " << replicate_node->name();
4721
4722 // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
4723 // of the computation) and cores (virtual cores within computations) specified
4724 // by the user. They will be mapped to physical TPU cores below.
4725 int num_replicas;
4726 int num_cores_per_replica;
4727 int num_tasks;
4728 std::vector<std::vector<string>> tf_device_assignment;
4729 std::vector<int> devices_to_lock;
4730 std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
4731 string tpu_compilation_device;
4732 TF_RETURN_IF_ERROR(GetDeviceTopology(
4733 device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
4734 &num_tasks, &tf_device_assignment, &devices_to_lock,
4735 &xla_device_assignment, &tpu_compilation_device));
4736
4737 TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
4738 tf_device_assignment, head_tail_outside_compilation_nodes));
4739
4740 string replicate;
4741 TF_RETURN_IF_ERROR(
4742 GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
4743 tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
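// Record the per-replica device assignment under this cluster's
// _tpu_replicate name; LowerOutsideCompilationFunctionalNodes consults this
// mapping to place host-side _XlaRecvAtHost/_XlaSendFromHost nodes on the
// host that matches each replica's first core.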
4744
4745 NameRangeMap input_name_map;
4746 const NameAttrList* function;
4747 std::unique_ptr<Graph> computation;
4748 DataTypeVector arg_types, retval_types;
4749 ParameterInfo params_info;
4750 TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
4751 &input_name_map, &function, &computation,
4752 &arg_types, &retval_types, ¶ms_info));
4753
4754 std::vector<InferredShape> arg_shapes, retval_shapes;
4755 TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
4756 shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
4757
4758 TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
4759
4760 std::vector<xla::OpSharding> arg_sharding;
4761 std::vector<bool> arg_fast_mem;
4762 std::vector<std::string> arg_names;
4763 std::vector<xla::OpSharding> retval_sharding;
4764 TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
4765 num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
4766 retval_shapes, *computation, replicate_node, flr,
4767 allow_xla_spmd_partition_, &arg_sharding, &arg_fast_mem, &retval_sharding,
4768 &arg_names));
4769
4770 VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
4771 flib_def);
4772
4773 GraphDef graph_def;
4774 graph->ToGraphDef(&graph_def);
4775 FunctionLibraryDefinition reachable_functions =
4776 flib_def->ReachableDefinitions(graph_def);
4777 uint64 library_fingerprint;
4778
4779 TF_RETURN_IF_ERROR(
4780 FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
4781 VLOG(1) << "Fingerprint functions: "
4782 << absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
4783 VLOG(1) << "library_fingerprint: " << library_fingerprint;
4784
4785 // Builds trigger nodes that put barriers around the expansion of
4786 // TPUReplicate. In particular, we must guarantee:
4787 // a) variable reads happen after all predecessors of the original
4788 // TPUReplicate.
4789 // b) variable writes happen before all successors of the original
4790 // TPUReplicate.
4791 // c) all replicas execute, even if output tensors are only requested from
4792 // a subset of replicas. This is necessary both to ensure that variable
4793 // updates happen and because Send/Recv will deadlock if only one half of
4794 // the communicating pair runs.
4795 Node* host_transfer_sequencer;
4796 Node* control_before;
4797 Node* control_after;
4798 TF_RETURN_IF_ERROR(BuildSequencingNodes(
4799 tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
4800 &control_before, &control_after));
4801
4802 // Build a vector of variable nodes that are inputs.
4803 std::vector<VariableInput> variable_inputs;
4804 TF_RETURN_IF_ERROR(
4805 FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
4806
4807 std::vector<Node*> guaranteed_constant_nodes;
4808 std::vector<Node*> variable_reads;
4809 TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
4810 *replicate_node, input_name_map, graph, host_transfer_sequencer,
4811 control_before, control_after, variable_inputs,
4812 &guaranteed_constant_nodes, &variable_reads));
4813
4814 // Builds Shape nodes that compute the dynamic shapes of arguments whose
4815 // shapes are not statically known.
4816 std::vector<Node*> dynamic_shape_nodes;
4817 TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
4818 params_info, variable_reads, graph,
4819 &dynamic_shape_nodes));
4820
4821 // Builds a TPUCompile node that compiles `clusters` on `compile_device`.
4822 Node* compile_node;
4823 TF_RETURN_IF_ERROR(BuildCompileNode(
4824 replicate_node, *function, library_fingerprint, params_info, arg_shapes,
4825 arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
4826 arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
4827 /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
4828 dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
4829
4830 // Compilation must be sequenced after the control node if the TPU computation
4831 // is in a control-flow construct, such as a loop.
4832 graph->AddControlEdge(control_before, compile_node);
4833
4834 Node* control_after_compilation;
4835 Node* multilock_acquire;
4836 TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
4837 replicate_node, compile_node, devices_to_lock, &control_after_compilation,
4838 &multilock_acquire, graph));
4839
4840 std::vector<VariableWrite> variable_writes;
4841 TF_RETURN_IF_ERROR(BuildExecuteNodes(
4842 params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
4843 arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
4844 tf_device_assignment, compile_node, variable_reads,
4845 control_after_compilation, control_after, multilock_acquire,
4846 &variable_writes, graph));
4847 bool contains_resource_write_op =
4848 ContainsResourceWriteOp(*graph, reachable_functions);
4849
4850 VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
4851 // Skip conditional write if there is no resource writing op inside TPU
4852 // computation.
4853 if (contains_resource_write_op) {
4854 TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
4855 variable_writes, graph));
4856 }
4857
4858 if (host_compute_key_placeholder_node != nullptr) {
4859 TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
4860 compile_node, host_compute_key_placeholder_node, graph));
4861 }
4862
4863 HostComputeCoreMap host_compute_core;
4864 TF_RETURN_IF_ERROR(ParseHostComputeCores(
4865 *replicate_node, outside_compilation_nodes, &host_compute_core));
4866 TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
4867 tf_device_assignment, host_compute_core, outside_compilation_nodes,
4868 outside_compilation_node_images, graph));
4869
4870 graph->RemoveNode(replicate_node);
4871 return OkStatus();
4872 }
4873
4874 // Adds sharded weight update optimization for each host training loop.
4875 //
4876 // For any host training loop found in the graph, TPUVariableReshard ops
4877 // are inserted to match the best layout chosen by XLA.
4878 /* static */ Status
4879 DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
4880 Graph* graph, FunctionLibraryDefinition* flib_def,
4881 FunctionLibraryRuntime* flr) {
4882 std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
4883 Status s = tpu::DetectHostTrainingLoop(
4884 /*current_function_name=*/nullptr,
4885 /*current_function_attr=*/nullptr, flib_def, graph, flr,
4886 &host_training_loops_info);
4887 if (!s.ok()) {
4888 VLOG(2) << "No valid host training loop found. Skipping sharded weight "
4889 << "update optimization.";
4890 return OkStatus();
4891 }
4892
4893 for (const auto& host_loop : host_training_loops_info) {
4894 const auto& function_name = host_loop.encapsulating_function_name;
4895 // `function_name` has a value when the host training loop is inside a
4896 // function call node. In that case, in addition to adding
4897 // TPUVariableReshard ops, the function library definition needs to be
4898 // updated as well.
4899 if (function_name.has_value()) {
4900 const auto& function_attr = host_loop.encapsulating_function_attrs;
4901 TF_RET_CHECK(function_attr.has_value())
4902 << "Unable to find function attribute for function: "
4903 << *function_name;
4904
4905 const FunctionDef* function_def = flib_def->Find(*function_name);
4906 TF_RET_CHECK(function_def)
4907 << "Unable to find function : " << *function_name;
4908
4909 std::unique_ptr<FunctionBody> fbody;
4910 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
4911 *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
4912 Graph* function_graph = fbody->graph;
4913 TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
4914 TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
4915 *function_name, flib_def));
4916 } else {
4917 TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
4918 }
4919 }
4920 return OkStatus();
4921 }
4922
4923 Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
4924 Graph* graph) {
4925 PropagateDevices(CanAcceptTPUDevicePropagation, IsTpuDevice, graph);
4926 return OkStatus();
4927 }
4928
4929 Status DistributedTPURewritePass::Run(
4930 const GraphOptimizationPassOptions& options) {
4931 Status status = InternalRun(options);
4932 OkOrSetErrorCounterPayload(
4933 tensorflow::core::platform::ErrorSourceProto::TF_XLA_BRIDGE, status);
4934 return status;
4935 }
4936
4937 Status DistributedTPURewritePass::InternalRun(
4938 const GraphOptimizationPassOptions& options) {
4939 VLOG(1) << "DistributedTPURewritePass::Run";
4940
4941 Graph* graph = options.graph->get();
4942
4943 VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
4944 options.flib_def);
4945
4946 const auto* config = &options.session_options->config;
4947 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
4948 new ProcessFunctionLibraryRuntime(
4949 nullptr, options.session_options->env, config,
4950 graph->versions().producer(), options.flib_def,
4951 config ? config->graph_options().optimizer_options()
4952 : OptimizerOptions()));
4953
4954 FunctionLibraryRuntime* flr =
4955 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
4956
4957 // This pass can only run in the session master, which should fill
4958 // in the device_set field of the options.
4959 TF_RET_CHECK(options.device_set != nullptr);
4960
4961 // Find all the replicate nodes before mutating the graph.
4962 std::vector<Node*> replicate_nodes;
4963 // Map from compiled subgraph cluster name to the outside_compilation nodes in
4964 // that cluster.
4965 std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
4966 std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
4967 TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
4968 &outside_compilation_nodes,
4969 &head_tail_outside_compilation_nodes));
4970
4971 if (replicate_nodes.empty()) {
4972 // Remove unused TPUPartitionedInput nodes.
4973 for (Node* n : graph->nodes()) {
4974 if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
4975 }
4976 VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
4977 options.flib_def);
4978 VLOG(1) << "Replicate nodes are empty. DistributedTPURewritePass::Run() "
4979 "finished";
4980 return OkStatus();
4981 }
4982
4983 std::unordered_map<string, Node*> host_compute_key_placeholder_map;
4984 TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
4985 graph, replicate_nodes, &host_compute_key_placeholder_map));
4986
4987 // This shape inference pass does not compute the shapes of outputs of
4988 // TPU computations. The concurrent multi-core locking implementation
4989 // *relies* on this behavior because it ensures that, if TPU computation B's
4990 // inputs depend on TPU computation A's outputs, then computation B's
4991 // compilation will be sequenced after A's execution, and this ensures that
4992 // locks are acquired in the correct order. If the shape inference is improved
4993 // to compute shapes of TPU computation outputs, it will be necessary to add
4994 // an explicit control edge between lock acquisitions for dependent
4995 // computations in order to avoid deadlock.
4996 GraphShapeInfo shape_info;
4997 TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
4998 flr->GetFunctionLibraryDefinition(),
4999 &shape_info));
5000 int64_t autotuner_thresh = options.session_options->config.experimental()
5001 .xla_fusion_autotuner_thresh();
5002
5003 NodeToNodeReplicasMap outside_compilation_node_images;
5004 TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
5005 for (Node* node : replicate_nodes) {
5006 TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
5007 options.session_handle, *options.device_set, node, options.flib_def,
5008 flr, host_compute_key_placeholder_map[node->name()],
5009 outside_compilation_nodes[node->name()],
5010 head_tail_outside_compilation_nodes[node->name()],
5011 &outside_compilation_node_images, graph, shape_info,
5012 &tpu_replicate_device_names_mapping, autotuner_thresh));
5013 }
5014
5015 // Place the padding nodes generated by the dynamic padder on the correct devices.
5016 // TODO(rxsang): Place padding ops on TPUs in
5017 // PlaceUnassignedDeviceNodesOnTPUIfPossible function.
5018 TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
5019
5020 std::unordered_map<string, Node*> outside_compilation_inputs;
5021 for (Node* n : graph->op_nodes()) {
5022 string lifted_arg_inputs_attr;
5023 if (n->type_string() == "IdentityN" &&
5024 GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
5025 &lifted_arg_inputs_attr)
5026 .ok()) {
5027 outside_compilation_inputs[lifted_arg_inputs_attr] = n;
5028 }
5029 }
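// outside_compilation_inputs maps each lifted-arg attribute value to its
// IdentityN node; it is consulted when replicating outside_compilation
// edges below.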
5030 for (const auto& iter : outside_compilation_nodes) {
5031 TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
5032 iter.second, outside_compilation_node_images,
5033 outside_compilation_inputs, graph));
5034 }
5035 TF_RETURN_IF_ERROR(
5036 RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
5037 TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
5038 graph, *options.flib_def, tpu_replicate_device_names_mapping));
5039
5040 TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
5041 VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
5042 options.flib_def);
5043 VLOG(1) << "DistributedTPURewritePass::Run() finished";
5044
5045 if (enable_cross_replica_sharding_mirrored_variables_) {
5046 VLOG(1) << "Starting host training loop optimization.";
5047 VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
5048 options.flib_def);
5049 TF_RETURN_IF_ERROR(
5050 PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
5051 VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
5052 options.flib_def);
5053 VLOG(1) << "Host training loop optimization finished.";
5054 }
5055
5056 return OkStatus();
5057 }
5058
5059 bool DistributedTPURewritePass::distribute_vars_ = false;
5060 bool DistributedTPURewritePass::allow_xla_spmd_partition_ = true;
5061 bool DistributedTPURewritePass::
5062 replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
5063 bool DistributedTPURewritePass::
5064 enable_cross_replica_sharding_mirrored_variables_ = true;
5065 bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
5066 bool DistributedTPURewritePass::enable_xla_param_broadcast_ = true;
5067 bool DistributedTPURewritePass::enable_multicore_locking_ = false;
5068 bool DistributedTPURewritePass::use_nd_sharding_ops_ = false;
5069
5070 /*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
5071 bool distribute_vars, bool allow_xla_spmd_partition,
5072 bool replicate_inputs_outputs_by_default_for_xla_spmd,
5073 bool enable_cross_replica_sharding_mirrored_variables,
5074 bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
5075 bool enable_multicore_locking, bool use_nd_sharding_ops) {
5076 distribute_vars_ = distribute_vars;
5077 allow_xla_spmd_partition_ = allow_xla_spmd_partition;
5078 replicate_inputs_outputs_by_default_for_xla_spmd_ =
5079 replicate_inputs_outputs_by_default_for_xla_spmd;
5080 enable_cross_replica_sharding_mirrored_variables_ =
5081 enable_cross_replica_sharding_mirrored_variables;
5082 enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
5083 enable_xla_param_broadcast_ = enable_xla_param_broadcast;
5084 enable_multicore_locking_ = enable_multicore_locking;
5085 use_nd_sharding_ops_ = use_nd_sharding_ops;
5086 }
5087
5088 } // namespace tensorflow
5089