1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
16 
17 #include <stddef.h>
18 
19 #include <algorithm>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_join.h"
26 #include "tensorflow/core/common_runtime/device_mgr.h"
27 #include "tensorflow/core/framework/cancellation.h"
28 #include "tensorflow/core/framework/collective.h"
29 #include "tensorflow/core/framework/device_attributes.pb.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/gtl/flatmap.h"
34 #include "tensorflow/core/lib/strings/numbers.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/protobuf/config.pb.h"
41 #include "tensorflow/core/util/device_name_utils.h"
42 
43 namespace tensorflow {
44 
CollectiveParamResolverLocal(const ConfigProto & config,const DeviceMgr * dev_mgr,DeviceResolverInterface * dev_resolver,NcclCommunicatorInterface * nccl_communicator,const string & task_name)45 CollectiveParamResolverLocal::CollectiveParamResolverLocal(
46     const ConfigProto& config, const DeviceMgr* dev_mgr,
47     DeviceResolverInterface* dev_resolver,
48     NcclCommunicatorInterface* nccl_communicator, const string& task_name)
49     : nccl_(config.experimental().collective_nccl()),
50       dev_mgr_(dev_mgr),
51       dev_resolver_(dev_resolver),
52       nccl_communicator_(nccl_communicator),
53       task_name_(task_name),
54       gpu_ring_order_(
55           config.gpu_options().experimental().collective_ring_order()) {}
56 
CompleteGroupAsync(const DeviceAttributes & device,CollGroupParams * group_params,CancellationManager * cancel_mgr,const StatusCallback & done)57 void CollectiveParamResolverLocal::CompleteGroupAsync(
58     const DeviceAttributes& device, CollGroupParams* group_params,
59     CancellationManager* cancel_mgr, const StatusCallback& done) {
60   CompleteGroupLocal(device, group_params, cancel_mgr, done);
61 }
62 
63 namespace {
GetCollectiveName(const CollectiveParams * cp,bool nccl)64 const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
65   switch (cp->instance.type) {
66     case BROADCAST_COLLECTIVE:
67       return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";
68 
69     case REDUCTION_COLLECTIVE:
70       return nccl ? "NcclReduce" : "RingReduce";
71 
72     case GATHER_COLLECTIVE:
73       return nccl ? "NcclGather" : "RingGather";
74 
75     case PERMUTE_COLLECTIVE:
76       return "Permute";
77 
78     case ALL_TO_ALL_COLLECTIVE:
79       return "AllToAll";
80 
81     default:
82       return "undef";
83   }
84 }
85 
// Extracts the task name from a fully specified device name.
// CHECK-fails if `device_name` cannot be parsed or if no task name can be
// derived from the parsed result.
string TaskNameFromDeviceName(const string& device_name) {
  string task_name;
  {
    DeviceNameUtils::ParsedName parsed_device;
    CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
    CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
  }
  return task_name;
}
93 
94 struct RankFormatter {
operator ()tensorflow::__anona90514070111::RankFormatter95   void operator()(std::string* out, CollGroupMember m) const {
96     out->append(std::to_string(m.rank));
97   }
98 };
99 
CheckUserSpecifiedRanks(const std::vector<CollGroupMember> members)100 Status CheckUserSpecifiedRanks(const std::vector<CollGroupMember> members) {
101   absl::flat_hash_set<int> user_ranks = {};
102   bool at_least_one_member_with_no_rank = false;
103   bool at_least_one_member_with_user_rank = false;
104   for (const auto& m : members) {
105     if (m.rank == -1) {
106       at_least_one_member_with_no_rank = true;
107     } else {
108       at_least_one_member_with_user_rank = true;
109       user_ranks.insert(m.rank);
110     }
111   }
112 
113   auto received_ranks = absl::StrJoin(members, ",", RankFormatter());
114   if (at_least_one_member_with_no_rank && at_least_one_member_with_user_rank) {
115     return errors::InvalidArgument(
116         "Only part of the group members have user given rank specified.",
117         "Received ranks: ", received_ranks);
118   }
119 
120   if (at_least_one_member_with_user_rank &&
121       user_ranks.size() < members.size()) {
122     return errors::InvalidArgument(
123         "Duplicate ranks specified for group members. Received ranks: ",
124         received_ranks);
125   }
126   return OkStatus();
127 }
128 }  // namespace
129 
CompleteGroupLocal(const DeviceAttributes & device,CollGroupParams * group_params,CancellationManager * cancel_mgr,StatusCallback done)130 void CollectiveParamResolverLocal::CompleteGroupLocal(
131     const DeviceAttributes& device, CollGroupParams* group_params,
132     CancellationManager* cancel_mgr, StatusCallback done) {
133   VLOG(1) << "CompleteGroup device=" << device.name() << ": "
134           << group_params->ToString();
135   std::vector<StatusCallback> to_be_called;
136 
137   GroupRec* gr = nullptr;
138   Status status;
139   {
140     mutex_lock l(group_mu_);
141     auto it = group_table_.find(group_params->group_key);
142     if (it == group_table_.end()) {
143       gr = new GroupRec;
144       mutex_lock grl(gr->mu);
145       gr->group.group_key = group_params->group_key;
146       gr->group.group_size = group_params->group_size;
147       gr->group.device_type = group_params->device_type;
148       if (nccl_communicator_ != nullptr) {
149         gr->group.runtime_details.communicator_key =
150             nccl_communicator_->GenerateCommunicatorKey();
151       }
152       // Store GroupRec in group_table_ which is shared between all devices on
153       // this worker.
154       group_table_[gr->group.group_key].reset(gr);
155       VLOG(2) << "New group_key=" << gr->group.group_key
156               << " group_size=" << gr->group.group_size
157               << " runtime_details=" << gr->group.runtime_details.ToString();
158     } else {
159       gr = it->second.get();
160     }
161   }
162   {
163     mutex_lock l(status_mu_);
164     status = status_;
165   }
166   if (!status.ok()) {
167     done(status);
168     return;
169   }
170 
171   if (cancel_mgr != nullptr) {
172     CancellationToken token = cancel_mgr->get_cancellation_token();
173     bool is_cancelled = !cancel_mgr->RegisterCallback(
174         token, std::bind(&CollectiveParamResolverLocal::CancelGroup, this,
175                          group_params->group_key));
176     if (is_cancelled) {
177       done(errors::Cancelled("CompleteGroup is cancelled before it starts"));
178       return;
179     }
180     done = [cancel_mgr, token,
181             original_done = std::move(done)](const Status& status) {
182       cancel_mgr->TryDeregisterCallback(token);
183       original_done(status);
184     };
185   }
186 
187   {
188     mutex_lock gr_lock(gr->mu);
189     // If there is ever an error associated with a group key, we store the error
190     // status and invoke all waiting and future callbacks with this error
191     // status.
192     VLOG(2) << "gr device_type=" << gr->group.device_type
193             << " cp device_type=" << group_params->device_type
194             << " current device=" << device.name();
195     if (gr->status.ok()) {
196       // Check for consistency with existing GroupRec.
197       if (group_params->device_type != gr->group.device_type) {
198         gr->status = errors::Internal(
199             "Device ", device.name(),
200             " is joining a group with incompatible device type",
201             gr->group.device_type.type_string(),
202             " (group_key=", gr->group.group_key, ")");
203       } else if (group_params->group_size != gr->group.group_size) {
204         gr->status = errors::Internal(
205             "Device ", device.name(), " is joining a group with size",
206             group_params->group_size, ", but that group has size ",
207             gr->group.group_size, " (group_key=", gr->group.group_key, ")");
208       }
209     }
210     bool new_device = false;
211     if (gr->status.ok()) {
212       // Insert device if not already present.
213       auto it = gr->incarnations_by_device_name.find(device.name());
214       if (it == gr->incarnations_by_device_name.end()) {
215         if (gr->group.members.size() == gr->group.group_size) {
216           // The group is already full.
217           gr->status =
218               errors::Internal("Device ", device.name(),
219                                " is joining a group that is already full",
220                                " (group_key=", gr->group.group_key, ")");
221         } else {
222           // This is a new device that has not yet joined the group.
223           gr->incarnations_by_device_name[device.name()] = device.incarnation();
224           CollGroupMember member;
225           member.device = device;
226           if (group_params->user_specified_rank == -1 ||
227               (group_params->user_specified_rank >= 0 &&
228                group_params->user_specified_rank < gr->group.group_size)) {
229             member.rank = group_params->user_specified_rank;
230           } else {
231             gr->status = errors::InvalidArgument(
232                 "User Provided rank is invalid. It should be between [0, "
233                 "group_size)");
234           }
235           gr->group.members.push_back(std::move(member));
236           new_device = true;
237           if (VLOG_IS_ON(1)) {
238             string dev_buf;
239             for (const auto& m : gr->group.members) {
240               strings::StrAppend(&dev_buf, ",", m.device.name());
241             }
242             VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
243                     << " group_size=" << gr->group.group_size << " (current"
244                     << " devices)=(" << dev_buf << ") (number of"
245                     << " devices pending)="
246                     << (gr->group.group_size - gr->group.members.size());
247           }
248         }
249       } else {
250         // If the device already exists, check if the incarnation matches.
251         if (it->second != device.incarnation()) {
252           gr->status = errors::FailedPrecondition(
253               "Device ", device.name(),
254               " current incarnation doesn't match with one in the group. This "
255               "usually means this worker has restarted but the collective "
256               "leader hasn't, or this worker connects to a wrong cluster.");
257         }
258       }
259     }
260 
261     if (gr->status.ok()) {
262       // If the group is not yet complete, queue to wait for it.
263       VLOG(2) << "group_size " << gr->group.group_size << " set size "
264               << gr->group.members.size() << " gr " << gr;
265 
266       if (gr->group.members.size() < gr->group.group_size) {
267         gr->pending_done.push_back(std::move(done));
268         gr->pending_params.push_back(group_params);
269         return;
270       }
271       CHECK_EQ(gr->group.members.size(), gr->group.group_size);
272       // We get a full group. Fill in remaining fields in gr->group.
273       auto st = CheckUserSpecifiedRanks(gr->group.members);
274       if (!st.ok()) {
275         gr->status = st;
276       }
277       if (new_device) {
278         FinishGroup(gr);
279       }
280       // Copy to all pending CollGroupParams;
281       *group_params = gr->group;
282       for (auto* params : gr->pending_params) {
283         *params = gr->group;
284       }
285     }
286     // At this point, we either have a full group, or an error status.  Ensure
287     // that all callbacks are invoked with the appropriate status.
288     to_be_called.swap(gr->pending_done);
289     gr->pending_params.clear();
290     status = gr->status;
291   }
292   done(status);
293   for (int i = 0; i < to_be_called.size(); ++i) {
294     to_be_called[i](status);
295   }
296 }
297 
298 namespace {
299 struct DevRec {
300   string task;
301   string device;
302   int original_rank;
303   int local_rank;
304   int global_rank;
305   const DeviceLocality* locality;
306 };
307 typedef std::unordered_map<string, DevRec> TaskDeviceMap;
308 typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
309 
310 // Create a populated GlobalDeviceMap from CollInstanceParams and localities.
BuildDevRecs(const CollGroupParams & gp)311 GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp) {
312   GlobalDeviceMap gdm;
313   CHECK_EQ(gp.members.size(), gp.members.size());
314   for (int i = 0; i < gp.members.size(); ++i) {
315     TaskDeviceMap& tdm = gdm[gp.members[i].task];
316     DevRec* dr = &tdm[gp.members[i].device.name()];
317     dr->task = gp.members[i].task;
318     dr->device = gp.members[i].device.name();
319     dr->original_rank = i;
320     dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
321     dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
322     dr->locality = &gp.members[i].device.locality();
323   }
324   return gdm;
325 }
326 
ParseRingOrder(const string & gpu_ring_order_str,TaskDeviceMap * tdm)327 bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
328   std::vector<string> split_gpu_ring_order_str =
329       str_util::Split(gpu_ring_order_str, ',');
330   if (split_gpu_ring_order_str.size() != tdm->size()) return false;
331 
332   // gpu id -> local rank
333   gtl::FlatMap<int32, int32> gpu_ranks;
334   for (int32_t rank = 0;
335        rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
336     int32_t tmp;
337     if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
338       gpu_ranks[tmp] = rank;
339     } else {
340       return false;
341     }
342   }
343 
344   for (auto& tdm_it : *tdm) {
345     DeviceNameUtils::ParsedName parsed_name;
346     DevRec* dr = &tdm_it.second;
347     if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
348       return false;
349     }
350     auto rank_it = gpu_ranks.find(parsed_name.id);
351     if (rank_it == gpu_ranks.end()) return false;
352     dr->local_rank = rank_it->second;
353   }
354   VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
355   return true;
356 }
357 
OrderTaskDeviceMap(const string & gpu_ring_order,TaskDeviceMap * tdm)358 void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
359   CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices
360 
361   // If a valid ring order has been passed in via ConfigProto, use that.
362   if (ParseRingOrder(gpu_ring_order, tdm)) return;
363 
364   // Either no ring order was passed in, or the format was unexpected.
365   // We now assign a ring order based on link strengths.  Note that this
366   // algorithm is not optimal and may not always find the best ring order.
367   int least_rank = -1;
368   string next_device;
369   std::set<string> selected;
370   // Starting device is one with the least initial rank.
371   for (const auto& it : *tdm) {
372     if (least_rank < 0 || it.second.original_rank < least_rank) {
373       least_rank = it.second.original_rank;
374       next_device = it.second.device;
375     }
376   }
377   CHECK_GE(least_rank, 0);
378   DeviceNameUtils::ParsedName parsed_name;
379   CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
380   // NOTE: InterconnectLink has only a device_id, nothing more, so for
381   // the time being if there's more than one device at a task we
382   // assume they're all GPUs.
383 
384   int next_rank = 0;
385   while (true) {
386     selected.insert(next_device);
387     auto next_dev_it = tdm->find(next_device);
388     CHECK(next_dev_it != tdm->end());
389     DevRec* dr = &next_dev_it->second;
390     dr->local_rank = next_rank;
391     ++next_rank;
392     if (selected.size() == tdm->size()) {
393       break;
394     }
395     // For the present time we assume Locality links only cover GPUs.
396     // For multiple CPUs, just take them in order.
397     const InterconnectLink* best_link = nullptr;
398     if (parsed_name.type == "GPU") {
399       for (const InterconnectLink& il : dr->locality->links().link()) {
400         parsed_name.id = il.device_id();
401         string endpoint_device =
402             DeviceNameUtils::ParsedNameToString(parsed_name);
403         // Skip the device if we've already seen it.
404         if (selected.find(endpoint_device) != selected.end()) {
405           continue;
406         }
407         // Skip the device if it is not participating in this collective
408         // instance.
409         if (tdm->find(endpoint_device) == tdm->end()) {
410           continue;
411         }
412         if (best_link == nullptr || il.strength() > best_link->strength()) {
413           best_link = &il;
414         }
415       }
416     }
417     if (best_link != nullptr) {
418       // Follow the best edge
419       parsed_name.id = best_link->device_id();
420       next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
421     } else {
422       // No good edges, alas. Pick the lowest initial rank among remaining
423       // devices.
424       least_rank = -1;
425       for (const auto& it : *tdm) {
426         if (selected.find(it.second.device) != selected.end()) {
427           continue;
428         }
429         if (least_rank < 0 || it.second.original_rank < least_rank) {
430           least_rank = it.second.original_rank;
431           next_device = it.second.device;
432         }
433       }
434       CHECK_GE(least_rank, 0);
435     }
436   }
437 }
438 
439 // The first time a CollGroupParams is established for a group we compute a good
440 // rank order for all the devices in the group, that is appropriate for a ring
441 // algorithm.
EstablishGlobalRank(const CollGroupParams & gp,const string & gpu_ring_order)442 GlobalDeviceMap EstablishGlobalRank(const CollGroupParams& gp,
443                                     const string& gpu_ring_order) {
444   VLOG(1) << "EstablishGlobalRank";
445   GlobalDeviceMap gdm = BuildDevRecs(gp);
446   for (auto& iter : gdm) {
447     TaskDeviceMap& tdm = iter.second;
448     OrderTaskDeviceMap(gpu_ring_order, &tdm);
449   }
450   // Connect the global rank order by the lexicographical order of the tasks.
451   std::set<string> tasks;
452   for (const CollGroupMember& member : gp.members) {
453     tasks.insert(member.task);
454   }
455   int next_rank = 0;
456   for (const string& task : tasks) {
457     TaskDeviceMap* tdm = &gdm[task];
458     for (auto& it : *tdm) {
459       it.second.global_rank = it.second.local_rank + next_rank;
460     }
461     next_rank += tdm->size();
462   }
463   return gdm;
464 }
465 
466 // Count the devices associated with each task and set
467 // gp->same_num_devices_per_task.  Requires gp->task_names
468 // be sorted.
SetDevPerTask(CollGroupParams * gp)469 void SetDevPerTask(CollGroupParams* gp) {
470   gp->num_devices_per_task.clear();
471   for (const CollGroupMember& member : gp->members) {
472     gp->num_devices_per_task[member.task]++;
473   }
474   gp->same_num_devices_per_task = false;
475   int dev_per_task = -1;
476   for (const auto& task_dev : gp->num_devices_per_task) {
477     if (dev_per_task == -1) {
478       dev_per_task = task_dev.second;
479     } else if (dev_per_task != task_dev.second) {
480       return;
481     }
482   }
483   gp->same_num_devices_per_task = true;
484 }
485 
486 }  // namespace
487 
FinishGroup(GroupRec * gr)488 void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
489   // Populate group member task and is_local.
490   for (CollGroupMember& member : gr->group.members) {
491     member.task = TaskNameFromDeviceName(member.device.name());
492     member.is_local = member.task == task_name_;
493   }
494   // Establish the order of the members by considering localities of all
495   // devices.
496   CompleteDefaultRanking(&gr->group);
497   SetDevPerTask(&gr->group);
498   gr->group.num_tasks =
499       static_cast<int32>(gr->group.num_devices_per_task.size());
500 }
501 
CancelGroup(int32 group_key)502 void CollectiveParamResolverLocal::CancelGroup(int32 group_key) {
503   std::vector<StatusCallback> pending_done;
504   GroupRec* gr = nullptr;
505   {
506     mutex_lock l(group_mu_);
507     auto it = group_table_.find(group_key);
508     if (it == group_table_.end()) {
509       return;
510     }
511     gr = it->second.get();
512   }
513   {
514     mutex_lock l(gr->mu);
515     if (gr->group.members.size() == gr->group.group_size) {
516       // The group is already complete. There's no need to cancel.
517       return;
518     }
519     gr->status = errors::Cancelled("group is cancelled");
520     pending_done.swap(gr->pending_done);
521     gr->pending_params.clear();
522   }
523   for (const StatusCallback& done : pending_done) {
524     done(errors::Cancelled("group is cancelled"));
525   }
526 }
527 
SetDefaultRank(const string & device,CollectiveParams * cp)528 void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
529                                                   CollectiveParams* cp) {
530   CHECK_EQ(cp->group.group_size, cp->group.members.size()) << cp->ToString();
531   for (int i = 0; i < cp->group.group_size; ++i) {
532     if (cp->group.members[i].device.name() == device) {
533       cp->default_rank = i;
534     }
535     // Set member rank to default rank if not user specified.
536     if (cp->group.members[i].rank == -1) {
537       cp->group.members[i].rank = i;
538     }
539   }
540 }
541 
InitInstanceSharedParams(const CollectiveParams * cp,InstanceRec * ir)542 void CollectiveParamResolverLocal::InitInstanceSharedParams(
543     const CollectiveParams* cp, InstanceRec* ir) {
544   ir->shared->instance = cp->instance;
545   ir->shared->default_rank = -1;
546 }
547 
548 // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
549 // to all devices that they are physically connected to and visible to the
550 // TensorFlow runtime.  This set of devices may be a superset of the devices
551 // participating in this instance of collectives.
CompleteDefaultRanking(CollGroupParams * gp)552 void CollectiveParamResolverLocal::CompleteDefaultRanking(CollGroupParams* gp) {
553   // Sort gp->member to avoid indeterminism.
554   std::sort(gp->members.begin(), gp->members.end(),
555             [](const CollGroupMember& lhs, const CollGroupMember& rhs) {
556               DeviceNameUtils::ParsedName lhs_device_name, rhs_device_name;
557               if (DeviceNameUtils::ParseFullName(lhs.device.name(),
558                                                  &lhs_device_name) &&
559                   DeviceNameUtils::ParseFullName(rhs.device.name(),
560                                                  &rhs_device_name)) {
561                 if (lhs_device_name.job == rhs_device_name.job) {
562                   if (lhs_device_name.task == rhs_device_name.task) {
563                     return lhs_device_name.id < rhs_device_name.id;
564                   } else {
565                     return lhs_device_name.task < rhs_device_name.task;
566                   }
567                 } else {
568                   return lhs_device_name.job < rhs_device_name.job;
569                 }
570               }
571               return lhs.device.name() < rhs.device.name();
572             });
573   // Establish an instance-specific default rank order for devices
574   // based on localities.  This rank order should be a good ring
575   // order, if possible.
576   GlobalDeviceMap gdm = EstablishGlobalRank(*gp, gpu_ring_order_);
577   // Reflect the new global ranking on shared
578   std::vector<CollGroupMember> new_members(gp->group_size);
579   for (const auto& git : gdm) {
580     const TaskDeviceMap& tdm = git.second;
581     for (const auto& tit : tdm) {
582       const DevRec& dr = tit.second;
583       new_members[dr.global_rank] = std::move(gp->members[dr.original_rank]);
584     }
585   }
586 
587   if (VLOG_IS_ON(2)) {
588     string buf;
589     for (const auto& m : new_members)
590       strings::StrAppend(&buf, "\n", m.device.name());
591     VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
592             << buf;
593   }
594   gp->members = std::move(new_members);
595 }
596 
597 CollectiveParamResolverLocal::InstanceRec*
GetOrCreateInstanceRec(CollectiveParams * cp,bool * created)598 CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp,
599                                                      bool* created) {
600   *created = false;
601   InstanceRec* irec = nullptr;
602   {
603     mutex_lock l(instance_mu_);
604     auto group_it = instance_table_.find(cp->group.group_key);
605     if (group_it != instance_table_.end()) {
606       auto instance_it = group_it->second.find(cp->instance.instance_key);
607       if (instance_it != group_it->second.end()) {
608         irec = instance_it->second.get();
609       }
610     }
611     if (irec == nullptr) {
612       // Create new InstanceRec.
613       irec = new InstanceRec;
614       *created = true;
615       {
616         mutex_lock il(irec->mu);
617         irec->known.resize(cp->group.group_size, false);
618       }
619       InitInstanceSharedParams(cp, irec);
620       instance_table_[cp->group.group_key][cp->instance.instance_key].reset(
621           irec);
622     }
623   }
624   Status status;
625   {
626     mutex_lock l(status_mu_);
627     status = status_;
628   }
629   if (!status.ok()) {
630     mutex_lock l(irec->mu);
631     irec->status = status;
632   }
633   return irec;
634 }
635 
LookupGroup(int32_t group_key,CollGroupParams * group)636 Status CollectiveParamResolverLocal::LookupGroup(int32_t group_key,
637                                                  CollGroupParams* group) {
638   mutex_lock l(group_mu_);
639   auto group_rec = group_table_.find(group_key);
640   if (group_rec == group_table_.end()) {
641     return errors::InvalidArgument("Group ", group_key,
642                                    " is not "
643                                    "initialized. Please call group "
644                                    "initialization op first before invoking "
645                                    "collective op.");
646   }
647   mutex_lock lock(group_rec->second->mu);
648   if (!group_rec->second->status.ok()) {
649     return errors::FailedPrecondition(
650         "Failed to run collective due to "
651         "unsuccessful group initialization. "
652         "Group initialization failed with error ",
653         group_rec->second->status.ToString());
654   }
655   *group = group_rec->second->group;
656   return OkStatus();
657 }
658 
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)659 void CollectiveParamResolverLocal::CompleteParamsAsync(
660     const DeviceAttributes& device, CollectiveParams* cp,
661     CancellationManager* cancel_mgr, const StatusCallback& done) {
662   VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
663           << cp->ToString();
664   if (cp->run_group_initialization) {
665     CompleteGroupLocal(device, &cp->group, cancel_mgr,
666                        [this, device, cp, done](const Status& s) {
667                          if (s.ok()) {
668                            CompleteInstanceLocal(device.name(), cp, done);
669                          } else {
670                            done(s);
671                          }
672                        });
673   } else {
674     // For Collective V3 ops, group is already initialized. Fetch attributes
675     // for the already initialized group to pass to Insitance initialization.
676     const auto s = LookupGroup(cp->group.group_key, &cp->group);
677     if (s.ok()) {
678       CompleteInstanceLocal(device.name(), cp, done);
679     } else {
680       done(s);
681     }
682   }
683 }
684 
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)685 void CollectiveParamResolverLocal::CompleteInstanceAsync(
686     const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
687     CancellationManager* cancel_mgr, const StatusCallback& done) {
688   done(
689       errors::Internal("CompleteInstance is not implemented by "
690                        "CollectiveParamResolverLocal which is "
691                        "intended only for non-distributed deployment."));
692 }
693 
694 // TODO(b/111897089): we need a better way to pick the collective
695 // implementation.  The ideal way would depend upon the topology and link
696 // strength before picking a particular implementation.
AssignCollectiveType(CollectiveParams * cp)697 void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
698   // We use the NCCL implementation if this is an environment which supports
699   // NCCL, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and
700   // also if indicated either in `ConfigProto` or `communication_hint`.
701   //
702   // After enough testing, we may simplify this logic to use NCCL whenever
703   // available.
704   CollectiveImplementationInterface* col_impl;
705   bool use_nccl =
706       (nccl_ || cp->instance.impl_details.communication_hint == "nccl") &&
707       cp->group.device_type == DEVICE_GPU &&
708       CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
709           .ok();
710   cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
711   VLOG(1) << "AssignCollectiveType "
712           << cp->instance.impl_details.collective_name;
713 }
714 
CompleteInstanceLocal(const string & device,CollectiveParams * cp,const StatusCallback & done)715 void CollectiveParamResolverLocal::CompleteInstanceLocal(
716     const string& device, CollectiveParams* cp, const StatusCallback& done) {
717   VLOG(1) << "CompleteInstanceLocal " << device
718           << " instance_key: " << cp->instance.instance_key << " group_key "
719           << cp->group.group_key;
720 
721   bool created_irec;
722   InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
723   if (!created_irec) {
724     // Check that the preexisting IRec is consistent with the params passed into
725     // this invocation.
726     if (ir->shared->instance.type != cp->instance.type ||
727         ir->shared->instance.data_type != cp->instance.data_type) {
728       done(errors::Internal("Collective instance ", cp->instance.instance_key,
729                             " expected type ", ir->shared->instance.type,
730                             " and data_type ", ir->shared->instance.data_type,
731                             " but got type ", cp->instance.type,
732                             " and data_type ", cp->instance.data_type));
733       return;
734     }
735   }
736   CompleteInstanceFromInitializedIRec(device, cp, ir, done);
737 }
738 
CompleteInstanceFromInitializedIRec(const string & device,CollectiveParams * cp,InstanceRec * ir,const StatusCallback & done)739 void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
740     const string& device, CollectiveParams* cp, InstanceRec* ir,
741     const StatusCallback& done) {
742   auto expected_shape = cp->instance.shape;
743   Status status;
744   // Populate the fields common across instance.
745   {
746     mutex_lock l(ir->mu);
747     status = ir->status;
748     if (status.ok()) {
749       // custom operator= does a deep copy.
750       cp->instance = ir->shared->instance;
751     }
752   }
753   if (!status.ok()) {
754     done(status);
755     return;
756   }
757   if (expected_shape != cp->instance.shape) {
758     done(errors::InvalidArgument(
759         "Shape mismatch in the collective instance ", cp->instance.instance_key,
760         ". Op at device ", device, " expected shape ",
761         expected_shape.DebugString(), " but another member in the group ",
762         "expected shape ", cp->instance.shape.DebugString(), ". This is likely",
763         " due to different input shapes at different members of the collective",
764         " op."));
765     return;
766   }
767   // Populate the fields common across task.
768   AssignCollectiveType(cp);
769   SetDefaultRank(device, cp);
770 
771   CollectiveImplementationInterface* col_impl;
772   status = CollectiveRegistry::LookupParamResolverInstance(
773       cp->instance.impl_details.collective_name, &col_impl);
774   if (!status.ok()) {
775     done(status);
776     return;
777   }
778 
779   //  We may need to wait for the group, if this is a broadcast, for source
780   //  discovery.
781   if (cp->instance.type == BROADCAST_COLLECTIVE) {
782     WaitForGroup(ir, cp, [col_impl, ir, device, cp, done](InstanceRec* irec) {
783       Status s;
784       if (ir != irec) {
785         s = errors::Internal("Expected ir ", ir, " and irec ", irec,
786                              " to be equal");
787       } else {
788         mutex_lock l(irec->mu);
789         s = irec->status;
790         cp->source_rank = irec->source_rank;
791       }
792       if (s.ok()) {
793         s = col_impl->InitializeCollectiveParams(cp);
794       }
795       done(s);
796     });
797   } else {
798     done(col_impl->InitializeCollectiveParams(cp));
799   }
800 }
801 
WaitForGroup(InstanceRec * ir,CollectiveParams * cp,const IRConsumer & f)802 void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
803                                                 CollectiveParams* cp,
804                                                 const IRConsumer& f) {
805   std::vector<IRConsumer> ready_waiters;
806   do {
807     mutex_lock l(ir->mu);
808     if (!ir->status.ok()) {
809       break;
810     }
811     CHECK_EQ(cp->group.group_size, ir->known.size());
812     CHECK_GE(cp->default_rank, 0);
813     if (!ir->known[cp->default_rank]) {
814       ir->known[cp->default_rank] = true;
815       ++ir->known_count;
816       if (cp->is_source) {
817         // Initialize source rank.
818         if (ir->source_rank >= 0) {
819           ir->status = errors::Internal("Instance ", cp->instance.instance_key,
820                                         " already has source ", ir->source_rank,
821                                         ", received second claim from ",
822                                         cp->default_rank);
823         } else {
824           ir->source_rank = cp->default_rank;
825         }
826       }
827     }
828     if (ir->known_count < cp->group.group_size) {
829       ir->known_waiters.push_back(f);
830       return;
831     }
832     CHECK_EQ(ir->known_count, cp->group.group_size);
833     if (ir->source_rank < 0) {
834       // NOTE(ayushd): changing the error message below would also require
835       // updating CompleteParamsBroadcastForgotSend test in
836       // CollectiveParamResolverLocalTest.
837       ir->status =
838           errors::Internal("Instance ", cp->instance.instance_key,
839                            " found no source for broadcast.  This "
840                            "could mean that there were group_size=",
841                            ir->known_count, " BcastRecvs but no BcastSend.");
842     }
843     if (!ir->known_waiters.empty()) {
844       ready_waiters = std::move(ir->known_waiters);
845     }
846   } while (false);
847   f(ir);
848   for (auto& f : ready_waiters) {
849     f(ir);
850   }
851 }
852 
StartAbort(const Status & s)853 void CollectiveParamResolverLocal::StartAbort(const Status& s) {
854   {
855     mutex_lock l(status_mu_);
856     if (!status_.ok()) {
857       VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
858                  "subsequent abortion with status: "
859               << s;
860       return;
861     }
862     status_ = s;
863   }
864   StartAbortLocal(s);
865 }
866 
StartAbortLocal(const Status & s)867 void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
868   std::vector<StatusCallback> pending_done;
869   {
870     mutex_lock l(group_mu_);
871     for (const auto& item : group_table_) {
872       GroupRec* gr = item.second.get();
873       {
874         mutex_lock gl(gr->mu);
875         gr->status = s;
876         for (auto& done : gr->pending_done) {
877           pending_done.push_back(std::move(done));
878         }
879         gr->pending_done.clear();
880         gr->pending_params.clear();
881       }
882     }
883   }
884   for (const StatusCallback& done : pending_done) {
885     done(s);
886   }
887   std::vector<InstanceRec*> instances;
888   {
889     mutex_lock l(instance_mu_);
890     for (const auto& group_entry : instance_table_) {
891       for (const auto& item : group_entry.second) {
892         instances.push_back(item.second.get());
893       }
894     }
895   }
896   for (InstanceRec* ir : instances) {
897     std::vector<IRConsumer> known_waiters;
898     {
899       mutex_lock il(ir->mu);
900       ir->status = s;
901       known_waiters.swap(ir->known_waiters);
902     }
903     for (const IRConsumer& done : known_waiters) {
904       done(ir);
905     }
906   }
907 }
908 
909 }  // namespace tensorflow
910