1 #ifdef USE_C10D_NCCL
2 
3 #include <exception>
4 #include <fstream>
5 #include <map>
6 #include <mutex>
7 #include <sstream>
8 #include <stdexcept>
9 #include <tuple>
10 #include <unordered_set>
11 #include <utility>
12 
13 #include <ATen/cuda/CUDAContext.h>
14 #include <ATen/cuda/CUDAGraph.h>
15 #include <c10/core/DeviceType.h>
16 #include <c10/cuda/CUDAAllocatorConfig.h>
17 #include <c10/cuda/CUDAGraphsC10Utils.h>
18 #include <c10/cuda/CUDAGuard.h>
19 #include <c10/util/CallOnce.h>
20 #include <c10/util/Exception.h>
21 #include <c10/util/Logging.h>
22 #include <c10/util/WaitCounter.h>
23 #include <c10/util/irange.h>
24 #include <c10/util/thread_name.h>
25 #include <torch/csrc/cuda/nccl.h>
26 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
27 #include <torch/csrc/distributed/c10d/NanCheck.hpp>
28 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
29 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
30 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
31 #include <torch/csrc/distributed/c10d/TraceUtils.h>
32 #include <torch/csrc/distributed/c10d/Utils.hpp>
33 #include <torch/csrc/distributed/c10d/logger.hpp>
34 #include <torch/torch.h>
35 #include <optional>
36 
37 namespace c10d {
38 
39 constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM";
40 
41 namespace {
42 
43 #if defined(NCCL_MAJOR) && \
44     ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
45 #define NCCL_HAS_AVG 1
46 #endif
47 
48 // NCCL op mapping
49 const std::map<ReduceOp::RedOpType, ncclRedOp_t> ncclOp = {
50     {ReduceOp::MIN, ncclMin},
51     {ReduceOp::MAX, ncclMax},
52     {ReduceOp::SUM, ncclSum},
53     {ReduceOp::PRODUCT, ncclProd},
54 #ifdef NCCL_HAS_AVG
55     {ReduceOp::AVG, ncclAvg},
56 #endif
57 };
58 
59 // NCCL type typing
60 std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
61     {at::kChar, ncclInt8},
62     {at::kByte, ncclUint8},
63     {at::kFloat, ncclFloat},
64     {at::kDouble, ncclDouble},
65     {at::kInt, ncclInt32},
66     {at::kLong, ncclInt64},
67     {at::kHalf, ncclHalf},
68     {at::kBool, ncclUint8},
69     {at::kFloat8_e5m2, ncclUint8},
70     {at::kFloat8_e4m3fn, ncclUint8},
71     {at::kFloat8_e4m3fnuz, ncclUint8},
72     {at::kFloat8_e5m2fnuz, ncclUint8},
73 #if HAS_NCCL_BF16_DATATYPE
74     {at::kBFloat16, ncclBfloat16},
75 #endif
76 };
77 
78 // Helper function that gets the data type and issues error if not supported
79 ncclDataType_t getNcclDataType(at::ScalarType type) {
80   auto it = ncclDataType.find(type);
81   TORCH_CHECK_WITH(
82       TypeError,
83       it != ncclDataType.end(),
84       "Input tensor data type is not supported for NCCL process group: ",
85       type);
86   return it->second;
87 }
88 
89 bool complexViewAsRealAllowed(const ReduceOp reduceOp) {
90   switch (reduceOp) {
91     case ReduceOp::SUM:
92       return true;
93     case ReduceOp::AVG:
94       return true;
95     case ReduceOp::PREMUL_SUM:
96       return true;
97     case ReduceOp::UNUSED:
98       return true;
99     default:
100       return false;
101   }
102   return false;
103 }
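// Rationale sketch (not part of the original source): a complex tensor can be
// reduced through its real view only when the op acts elementwise on real and
// imaginary parts, e.g. (a+bi) + (c+di) == (a+c) + (b+d)i, which is why
// MIN/MAX/PRODUCT fall through to the `return false` above.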
104 
105 #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT
106 template <typename T, ncclDataType_t dataType>
107 ncclRedOpRAII unpackPreMulSum(
108     const ReduceOp& reduceOp,
109     const ncclComm_t& comm) {
110   const auto* preMulSupplement =
111       reinterpret_cast<NCCLPreMulSumSupplement*>(reduceOp.supplement_.get());
112   ncclRedOp_t preMulSum;
113   bool has_tensor = preMulSupplement->tensor_factor.defined();
114   auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate;
115   const T* ptr_factor = has_tensor
116       ? preMulSupplement->tensor_factor.const_data_ptr<T>()
117       : nullptr;
118   T scalar_factor = T(preMulSupplement->double_factor);
119   ncclRedOpCreatePreMulSum(
120       &preMulSum,
121       // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/ops.html#ncclredopcreatepremulsum
122       // tells us that the scalar input is strictly a multiplier.
123       /*scalar=*/has_tensor ? const_cast<T*>(ptr_factor) : &scalar_factor,
124       dataType,
125       residence,
126       comm);
127   return ncclRedOpRAII(preMulSum, comm);
128 }
129 #endif
130 
131 ncclRedOpRAII getNcclReduceOp(
132     const ReduceOp& reduceOp,
133     at::Tensor& input,
134     const ncclDataType_t& dataType,
135     const ncclComm_t& comm) {
136   try {
137     if (input.scalar_type() == at::kBool) {
138       if (reduceOp == ReduceOp::SUM) {
139         // For bool tensors, map sum to max, which both represent a bitwise or.
140         // This is to prevent overflow issues with sum, since we use uint8 to
141         // represent a bool (see ncclDataType mapping).
142         return ncclMax;
143       }
144 #ifdef NCCL_HAS_AVG
145       if (reduceOp == ReduceOp::AVG) {
146         C10_THROW_ERROR(
147             TypeError, "Cannot use ReduceOp.AVG with boolean inputs");
148       }
149 #endif
150     }
151     if (reduceOp == ReduceOp::PREMUL_SUM) {
152 #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT
153       switch (dataType) {
154         case ncclHalf:
155           return unpackPreMulSum<at::Half, ncclHalf>(reduceOp, comm);
156         case ncclFloat:
157           return unpackPreMulSum<float, ncclFloat>(reduceOp, comm);
158         case ncclDouble:
159           return unpackPreMulSum<double, ncclDouble>(reduceOp, comm);
160         default:
161           C10_THROW_ERROR(
162               TypeError, "PreMulSum Data type must be half, float, or double");
163           ncclRedOp_t unused;
164           return unused;
165       }
166 #else
167       C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1");
168 #endif
169     }
170     return ncclOp.at(reduceOp);
171   } catch (const std::out_of_range&) {
172     switch (reduceOp) {
173       case ReduceOp::AVG:
174         C10_THROW_ERROR(
175             ValueError,
176             c10::str(
177                 "AVG requires NCCL 2.10+. The current version is ",
178                 NCCL_MAJOR,
179                 ".",
180                 NCCL_MINOR));
181         break;
182       case ReduceOp::BAND:
183         C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL");
184         break;
185       case ReduceOp::BOR:
186         C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL");
187         break;
188       case ReduceOp::BXOR:
189         C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL");
190         break;
191       default:
192         C10_THROW_ERROR(ValueError, "Unhandled ReduceOp");
193         break;
194     }
195   }
196 }
197 
198 // Get a key string from device
199 inline std::string getKeyFromDevice(at::Device& device) {
200   return std::to_string(device.index());
201 }
202 
203 inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) {
204   // initialize the device index to -1, which is an invalid value.
205   int index = -1;
206   try {
207     index = std::stoi(deviceKey);
208   } catch (const std::invalid_argument& e) {
209     LOG(ERROR) << c10::str(
210         "Invalid deviceKey: ", deviceKey, ",", e.what(), ".");
211   } catch (const std::out_of_range& e) {
212     LOG(ERROR) << "Out of range: " << e.what();
213   }
214   return static_cast<at::DeviceIndex>(index);
215 }
216 
217 std::string getKeySendRecv(int myRank, int peer) {
218   int lowRank = myRank < peer ? myRank : peer;
219   int highRank = myRank < peer ? peer : myRank;
220   std::string sendRecvPair =
221       std::to_string(lowRank) + ":" + std::to_string(highRank);
222   return sendRecvPair;
223 }
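// Illustrative example (not part of the original source): the key is
// order-independent, so both peers of a P2P pair derive the same cache key,
// e.g. getKeySendRecv(3, 1) == getKeySendRecv(1, 3) == "1:3", letting both
// ranks look up the same communicator entry.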
224 
225 // Get device from tensor
226 inline at::Device getDevice(at::Tensor& tensor) {
227   return tensor.device();
228 }
229 
230 // [Sync Streams] Helper that lets the input ncclStreams wait for the current
231 // stream. NCCL communications run on ncclStreams, but input tensors are
232 // allocated on different streams (i.e., current streams). Communications on
233 // ncclStreams cannot start before pending input tensor ops on current streams
234 // finish. Otherwise, ops on two streams might read/write same tensors
235 // concurrently.
236 //
237 // The synchronization above alone is not enough. We also need to make sure
238 // input tensors are not freed before their usages on ncclStreams finish. This
239 // can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream,
240 // which remembers the usage stream (ncclStream), creates an event on the usage
241 // stream when GC attempts to free the input tensor, and delays GC until that
242 // event is done.
243 void syncStream(
244     at::Device& device,
245     at::cuda::CUDAEvent& ncclEvent,
246     at::cuda::CUDAStream& ncclStream) {
247   ncclEvent.record(at::cuda::getCurrentCUDAStream(device.index()));
248   ncclEvent.block(ncclStream);
249 }
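// Usage sketch (illustrative assumption, not part of the original source): a
// collective typically records an event on the current stream, blocks the NCCL
// stream on it, and then records the input's usage on the NCCL stream so the
// caching allocator keeps the tensor alive until the communication is done:
//
//   at::cuda::CUDAEvent ev;
//   syncStream(device, ev, ncclStream);  // ncclStream now waits on the current stream
//   c10::cuda::CUDACachingAllocator::recordStream(
//       input.storage().data_ptr(), ncclStream);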
250 
251 // Given a ncclUniqueId, convert it to a string representation that can be put
252 // in the store.
253 std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) {
254   const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclID);
255   std::ostringstream oss;
256   for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) {
257     oss << std::hex << static_cast<int>(bytes[i]);
258   }
259   return oss.str();
260 }
261 
262 std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) {
263   return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr;
264 }
265 
266 // Returns exception's what() given an exception_ptr instance.
267 std::string getExceptionMsgFromExceptionPtr(
268     const std::exception_ptr& exceptionPtr) {
269   TORCH_CHECK(exceptionPtr != nullptr);
270   try {
271     std::rethrow_exception(exceptionPtr);
272   } catch (const std::exception& e) {
273     return e.what();
274   } catch (...) {
275     return "Unknown exception type";
276   }
277 }
278 
279 inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
280   // parentheses avoid some compiler warnings
281   static const uint64_t min_version =
282       (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6);
283   static const uint64_t cur_version = torch::cuda::nccl::version();
284   if (cur_version < min_version) {
285     TORCH_CHECK_WITH(
286         NotImplementedError,
287         status == c10::cuda::CaptureStatus::None,
288         "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6");
289   }
290 }
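// Worked example (illustration only): the packed format used above encodes the
// version as (major << 32) + (minor << 16) + patch, so the 2.9.6 floor is
// (2ULL << 32) + (9ULL << 16) + 6 = 0x200090006; any reported version below
// that value makes the check fail while a CUDA graph capture is in progress.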
291 
292 } // namespace
293 
294 // Map from each communicator to its device index.
295 // This map is used when register/deregister cache segments from cache
296 // allocator. See design notes below:
297 // - Each segment should be registered only to the communicator on the
298 //   same device.
299 // - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be
300 //   ranks rather than devices in the point-to-point case.
301 // - This map also has to be maintained as a global variable since the register
302 //   hooks are called outside the scope of any PG, thus we need to traverse
303 //   communicators in all PGs.
304 static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
305 static std::mutex ncclCommDevIdxMapMutex;
306 static bool allocatorHooksAttached = false;
307 
308 std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
309 
310 void cacheAllocatorRegisterHook(
311     const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
312   // Register after SEGMENT_ALLOC
313   if (te.action_ !=
314       c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) {
315     return;
316   }
317 
318   std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
319   for (auto& it : ncclCommDevIdxMap) {
320     auto& ncclComm = it.first;
321     auto& devIdx = it.second;
322     if (te.device_ == devIdx) {
323       ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
324     }
325   }
326 }
327 
328 void cacheAllocatorDeregisterHook(
329     const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
330   // deregister before SEGMENT_FREE
331   if (te.action_ !=
332       c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) {
333     return;
334   }
335 
336   std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
337   for (auto& it : ncclCommDevIdxMap) {
338     auto& ncclComm = it.first;
339     auto& devIdx = it.second;
340     if (te.device_ == devIdx) {
341       ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
342     }
343   }
344 }
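// Attachment sketch (illustration; the actual attachment happens once in the
// ProcessGroupNCCL constructor further below in this file):
//
//   c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
//       &cacheAllocatorRegisterHook);
//   c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
//       &cacheAllocatorDeregisterHook);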
345 
346 std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
347 getNCCLCommDumpMap() {
348 #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
349   std::unordered_map<
350       std::string /* ncclUniqueID */,
351       std::unordered_map<std::string, std::string> /* dump from this comm */>
352       ncclDumpMap;
353   // dump_nccl_trace is only called from the default PG (local_id_=0), but we
354   // want to dump from all comms so we need to iterate over ncclCommDevIdxMap,
355   // which is static
356   std::vector<std::shared_ptr<NCCLComm>> allNCCLComms;
357   // Copy the comms out under the lock; we don't want to dump while holding
358   // the lock, as the dump might hang.
359   ncclCommDevIdxMapMutex.lock();
360   for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
361     allNCCLComms.push_back(ncclComm);
362   }
363   ncclCommDevIdxMapMutex.unlock();
364   for (auto& ncclComm : allNCCLComms) {
365     std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
366     ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
367   }
368   return ncclDumpMap;
369 #else
370   return std::unordered_map<
371       std::string,
372       std::unordered_map<std::string, std::string>>();
373 #endif
374 }
375 
376 std::string dump_nccl_trace(
377     bool includeCollectives,
378     bool includeStackTraces,
379     bool onlyActive) {
380   auto ncclDumpMap = getNCCLCommDumpMap();
381   return NCCLTraceBuffer::get()->dump(
382       ncclDumpMap, includeCollectives, includeStackTraces, onlyActive);
383 }
384 
385 std::string dump_nccl_trace_json(bool includeCollectives, bool onlyActive) {
386   auto ncclDumpMap = getNCCLCommDumpMap();
387   return NCCLTraceBuffer::get()->dump_json(
388       ncclDumpMap, includeCollectives, onlyActive);
389 }
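// Caller sketch (assumed usage, not part of the original source): a debug
// handler can serialize the flight-recorder state and hand it to the
// registered DebugInfoWriter, mirroring dumpDebuggingInfo() further below:
//
//   auto trace = dump_nccl_trace(
//       /*includeCollectives=*/true,
//       /*includeStackTraces=*/true,
//       /*onlyActive=*/false);
//   DebugInfoWriter::getWriter(rank).write(trace);  // `rank` is hypothetical here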
390 
391 std::optional<std::function<void(std::function<void(const std::string&)>)>>&
392 get_cpp_trace_dumper() {
393   static std::optional<
394       std::function<void(std::function<void(const std::string&)>)>>
395       dumper(std::nullopt);
396   return dumper;
397 }
398 
399 gil_checker_t& get_gil_checker() {
400   static gil_checker_t gil_checker = nullptr;
401   return gil_checker;
402 }
403 
404 std::future<bool> launchAsyncGilCheck() {
405   std::promise<bool> resultPromise;
406   std::future<bool> resultFuture = resultPromise.get_future();
407   TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker");
408   std::thread workerThread([promise = std::move(resultPromise)]() mutable {
409     c10::setThreadName("pt_nccl_gil_chk");
410 
411     try {
412       auto& gil_checker = get_gil_checker();
413       promise.set_value((*gil_checker)());
414     } catch (...) {
415       promise.set_exception(std::current_exception());
416     }
417   });
418 
419   // Detach the thread to allow it to run independently
420   workerThread.detach();
421 
422   return resultFuture;
423 }
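// Usage sketch (illustration only): a caller can bound how long it waits for
// the GIL probe and treat a timeout as a likely GIL deadlock:
//
//   auto gilFuture = launchAsyncGilCheck();
//   if (gilFuture.wait_for(std::chrono::seconds(1)) !=
//       std::future_status::ready) {
//     LOG(ERROR) << "GIL checker did not return; the GIL may be held elsewhere.";
//   }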
424 
425 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100;
426 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
427 thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;
428 
429 std::ostream& operator<<(
430     std::ostream& output,
431     const ProcessGroupNCCL::WorkNCCL& workNCCL) {
432   std::string workInfo;
433   workInfo = c10::str(
434       "WorkNCCL(",
435       "SeqNum=",
436       workNCCL.seq_,
437       ", OpType=",
438       opTypeToString(workNCCL.opType_),
439       ", NumelIn=",
440       workNCCL.numelIn_,
441       ", NumelOut=",
442       workNCCL.numelOut_,
443       ", Timeout(ms)=",
444       workNCCL.opTimeout_.count(),
445       ")");
446   return output << workInfo;
447 }
448 
449 ProcessGroupNCCL::WorkNCCL::WorkNCCL(
450     const std::string& pgUID,
451     const std::string& pgDesc,
452     at::Device& device,
453     int rank,
454     OpType opType,
455     uint64_t seq,
456     const char* profilingTitle,
457     const std::optional<std::vector<at::Tensor>>& inputs,
458     bool desyncDebug,
459     bool enableTiming,
460     bool cudaEventCacheEnabled,
461     DebugLevel distDebugLevel)
462     : Work(rank, opType, profilingTitle, inputs),
463       pgUID_(pgUID),
464       pgDesc_(pgDesc),
465       device_(device),
466       workStartTime_(std::chrono::steady_clock::now()),
467       seq_(seq),
468       timingEnabled_(enableTiming),
469       distDebugLevel_(distDebugLevel) {
470   // Creates the CUDA event wrappers
471   // Note: The actual events are lazily created when first recorded to with
472   // DEFAULT_FLAGS = cudaEventDisableTiming.
473   if (cudaEventCacheEnabled) {
474     ncclStartEvent_ = enableTiming
475         ? ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming)
476         : nullptr;
477     ncclEndEvent_ =
478         ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming);
479   } else {
480     ncclStartEvent_ = enableTiming
481         ? std::make_shared<at::cuda::CUDAEvent>(cudaEventDefault)
482         : nullptr;
483     ncclEndEvent_ = std::make_shared<at::cuda::CUDAEvent>(
484         enableTiming ? cudaEventDefault : cudaEventDisableTiming);
485   }
486 }
487 
488 ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
489     : Work(w.rank_, w.opType_),
490       std::enable_shared_from_this<WorkNCCL>(w),
491       pgUID_(w.pgUID_),
492       pgDesc_(w.pgDesc_),
493       device_(w.device_),
494       ncclStartEvent_(w.ncclStartEvent_),
495       ncclEndEvent_(w.ncclEndEvent_),
496       ncclComm_(w.ncclComm_),
497       blockingWait_(w.blockingWait_),
498       opTimeout_(w.opTimeout_),
499       ownedEphermeralTimeout_(w.ownedEphermeralTimeout_),
500       workStartTime_(w.workStartTime_),
501       seq_(w.seq_),
502       startTraceUpdated_(w.startTraceUpdated_),
503       numelIn_(w.numelIn_),
504       numelOut_(w.numelOut_),
505       store_(w.store_),
506       timingEnabled_(w.timingEnabled_),
507       trace_id_(w.trace_id_),
508       distDebugLevel_(w.distDebugLevel_) {
509   exception_ = w.exception_;
510 }
511 
512 ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default;
513 
514 bool ProcessGroupNCCL::WorkNCCL::isCompleted() {
515   if (!ncclComm_->isAborted()) {
516     checkAndSetException();
517   }
518   return exception() || finishedGPUExecutionInternal();
519 }
520 
521 bool ProcessGroupNCCL::WorkNCCL::isStarted() {
522   if (!ncclComm_->isAborted()) {
523     checkAndSetException();
524   }
525   return exception() || startedGPUExecutionInternal();
526 }
527 
528 bool ProcessGroupNCCL::WorkNCCL::isSuccess() const {
529   C10_THROW_ERROR(NotImplementedError, "WorkNCCL::isSuccess() is deprecated");
530 }
531 
532 void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
533   if (exception()) {
534     // We already have an exception.
535     return;
536   }
537 
538   auto exception_ptr = checkForNCCLErrors();
539   std::unique_lock<std::mutex> lock(mutex_);
540   exception_ = exception_ptr;
541   if (exception_) {
542     LOG(ERROR) << logPrefix() << "Collective " << *this
543                << " raised the following async exception: "
544                << getExceptionMsgFromExceptionPtr(exception_);
545   }
546 }
547 
548 const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {
549   static std::string prefix = c10::str("[Rank ", rank_, "] ");
550   return prefix;
551 }
552 
553 void ProcessGroupNCCL::WorkNCCL::setException(
554     std::exception_ptr exception_ptr) {
555   std::unique_lock<std::mutex> lock(mutex_);
556   exception_ = exception_ptr;
557 }
558 
559 // Helper that checks if the NCCL kernels are completed on the GPUs
560 bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() {
561   checkAndSetException();
562   return finishedGPUExecutionInternal();
563 }
564 
565 bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const {
566   // if timing is disabled we won't have allocated start events
567   if (!timingEnabled_) {
568     return false;
569   }
570   // Checking the work's corresponding CUDA event's status
571   if (!ncclStartEvent_->query()) {
572     return false;
573   }
574   return true;
575 }
576 
577 bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
578   // Checking the work's corresponding CUDA event's status
579   // It calls `cudaEventQuery` eventually. Although this seems to be a
580   // non-blocking call, we did notice hangs in the past. It can
581   // hang if another thread is holding the CUDA global context lock. For
582   // example, when doing a `cudaDeviceSynchronize` or even
583   // `cudaStreamSynchronize`.
584   if (!ncclEndEvent_->query()) {
585     return false;
586   }
587   return true;
588 }
589 
590 bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
591     std::optional<std::chrono::milliseconds> timeout) {
592   STATIC_SCOPED_WAIT_COUNTER(
593       pytorch.wait_counter.ProcessGroupNCCL__checkTimeout);
594   auto currentTimepoint = std::chrono::steady_clock::now();
595   auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
596       currentTimepoint - workStartTime_);
597   auto workTimeout = timeout ? *timeout : opTimeout_;
598 
599   if (timeElapsed < workTimeout)
600     return false;
601 
602   // Timed out
603 
604   // There is already an error, we don't override it
605   if (exception())
606     return true;
607 
608   std::string exceptionMsg = c10::str(
609       logPrefix(),
610       "Watchdog caught collective operation timeout: ",
611       *this,
612       " ran for ",
613       timeElapsed.count(),
614       " milliseconds before timing out.");
615 
616   LOG(ERROR) << exceptionMsg;
617   std::exception_ptr exception_ptr =
618       std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exceptionMsg));
619   setException(exception_ptr);
620   return true;
621 }
622 
623 void ProcessGroupNCCL::WorkNCCL::handleException(
624     ErrorHandlingMode errorHandling) {
625   if (exception_) {
626     auto exceptionMsg = c10::str(
627         "Some NCCL operations have failed or timed out. Due to the ",
628         "asynchronous nature of CUDA kernels, subsequent GPU operations ",
629         "might run on corrupted/incomplete data.");
630     LOG(ERROR) << logPrefix() << exceptionMsg;
631     C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException");
632 
633     if (SHOULD_TEAR_DOWN(errorHandling)) {
634       auto tearDownMsg = c10::str(
635           "To avoid data inconsistency, we are taking the entire process down.");
636       LOG(ERROR) << logPrefix() << tearDownMsg;
637       std::rethrow_exception(exception_);
638     }
639   }
640 }
641 
642 void ProcessGroupNCCL::WorkNCCL::synchronize() {
643   // Call Synchronize without a timeout. We use this method to avoid adding a
644   // timeout argument to the public synchronize API.
645   synchronizeInternal(kNoTimeout);
646 }
647 
648 void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
649   auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
650   // Block the current stream on the NCCL stream
651   ncclEndEvent_->block(currentStream);
652 
653   if (avoidRecordStreams_) {
654     stashed_for_allocator_safety_->clear();
655   }
656 }
657 
658 // Waiting on the work's corresponding CUDA events
659 void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
660     std::chrono::milliseconds timeout) {
661   synchronizeStream();
662 
663   // In case of blocking, wait for the operation to complete.
664   if (blockingWait_) {
665     while (!isCompleted()) {
666       bool timedOut = checkTimeout(
667           timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout));
668       // Explicitly abort ncclComms here before throwing this timed out
669       // exception to users.
670       // If we throw the timed-out exception without aborting the NCCL
671       // communicators here, the CUDA GPU was observed to sit at 100%
672       // utilization and could not run new events successfully.
673       if (timedOut) {
674         std::string exceptionMsg = c10::str(
675             logPrefix(),
676             "Work ",
677             (*this),
678             " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1).");
679         LOG(ERROR) << exceptionMsg;
680         break;
681       }
682       // Yield
683       std::this_thread::sleep_for(
684           std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
685     }
686     // exception() includes timeout and error during blocking wait
687     if (exception()) {
688       // Abort NCCL communicators
689       abort();
690       // Throw exception (from main thread here)
691       handleException(TearDown);
692     }
693   }
694 
695   // Device synchronize only after we've completed timeout checks.
696   if (barrierTensor_.defined()) {
697     // If we use the work to do barrier, we should block here
698     // `dist.barrier()` only requires all CPU processes to enter this
699     // function, hence we only need to make sure the dummy all-reduce has
700     // completed. So we would only need to sync the **current stream** back to
701     // host, and do not need to synchronize the entire device (which may have
702     // kernels running on other streams).
703     // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can:
704     // - lower chance of hang;
705     // - CurrentCUDAStream is usually the context of the next operation in
706     // Python, thus blocking current stream would already block the next
707     // compute kernel;
708     // - achieve better barrier performance.
709     auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
710     // CUDAStream wrapper will correctly use a DeviceGuard here
711     currentStream.synchronize();
712   }
713 }
714 
715 // Same as calling synchronize().
716 bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
717   RECORD_PARAM_COMMS(
718       static_cast<int>(this->seq_), // seq
719       std::make_tuple(pgUID_, pgDesc_), // PG name tuple
720       rank_, // rank
721       "wait", // collective name
722       0, // inNelems
723       0, // outNelems
724       at::kByte, // dType
725       std::vector<int64_t>(), // inSplitSizes
726       std::vector<int64_t>(), // outSplitSizes
727       -1,
728       -1,
729       static_cast<int>(1)); // number of device?
730   synchronizeInternal(timeout);
731   // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL
732   // upgrade. Once a NCCL version is qualified, this code should not be needed
733   // at runtime.
734 #ifdef PGNCCL_ENABLE_HASH
735   if (distDebugLevel_ >= DebugLevel::Detail) {
736     auto numel = getTensorsNumel(*outputs_);
737     auto hashValue = hashTensors(*outputs_);
738     PRINT_COLLECTIVE_HASH_SIGNATURE(
739         "output", opTypeToString(opType_), numel, hashValue);
740   }
741 #endif
742   // Always return true, because abort API is not implemented.
743   return true;
744 }
745 
746 void ProcessGroupNCCL::WorkNCCL::abort() {
747   // Abort all communicators of this work
748   ncclComm_->ncclCommAbort();
749 
750   ncclCommDevIdxMapMutex.lock();
751   ncclCommDevIdxMap.erase(ncclComm_);
752   ncclCommDevIdxMapMutex.unlock();
753 }
754 
755 ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {}
756 
757 // CUDA event is used to record the start/end of one Work.
758 // Instead of letting the CUDA event get destroyed, we now reuse it after the
759 // Work has been erased from workMetaList_.
760 // This is to avoid the potential deadlock caused by cudaEventDestroy.
761 std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
762     bool timing) {
763   auto deleter = [this, timing](at::cuda::CUDAEvent* event) {
764     std::lock_guard<std::mutex> lock(this->cacheMutex_);
765     this->eventsArray_[timing ? 1 : 0].push_back(event);
766   };
767   at::cuda::CUDAEvent* event = nullptr;
768   {
769     std::lock_guard<std::mutex> lock(cacheMutex_);
770     // Take a reference so pop_back below actually removes the event from the
770     // cache; copying the vector would hand out the same event twice.
770     auto& events = eventsArray_[timing ? 1 : 0];
771     if (!events.empty()) {
772       event = events.back();
773       events.pop_back();
774     }
775   }
776   if (!event) {
777     event = new at::cuda::CUDAEvent(
778         timing ? cudaEventDefault : cudaEventDisableTiming);
779   }
780   return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
781 }
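// Usage sketch (illustration only): the caller holds the returned shared_ptr
// for the lifetime of a Work; when the last reference drops, the custom
// deleter returns the event to the pool instead of destroying it:
//
//   auto ev = ProcessGroupNCCL::CUDAEventCache::get().create(/*timing=*/false);
//   ev->record(stream);  // used like any at::cuda::CUDAEvent
//   ev.reset();          // event is recycled into the cache, not destroyed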
782 
783 ProcessGroupNCCL::CUDAEventCache& ProcessGroupNCCL::CUDAEventCache::get() {
784   static ProcessGroupNCCL::CUDAEventCache cache;
785   return cache;
786 }
787 
788 static std::atomic<size_t> process_group_id = 0;
789 
790 constexpr const char* MULTI_DEVICE_ERROR_MSG =
791     "Expecting one tensor only but got multiple. You are probably using multiple "
792     "devices under one thread. The support for such usage has been deprecated. "
793     "For details, please refer to "
794     "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. "
795     "ProcessGroupNCCL continues supporting multi-process and multi-thread modes.";
796 
797 ProcessGroupNCCL::ProcessGroupNCCL(
798     const c10::intrusive_ptr<Store>& store,
799     int rank,
800     int size,
801     c10::intrusive_ptr<Options> options)
802     : Backend(rank, size),
803       store_(store),
804       options_(options),
805       ncclCommCounter_(0),
806       traceKeyStart_(getTraceStartKey("NCCL", rank)),
807       traceKeyEnd_(getTraceEndKey("NCCL", rank)),
808       terminateProcessGroup_(false),
809       terminateHeartbeatMonitorThread_(false),
810       collectiveDebugInfoMode_(false),
811       local_id_(process_group_id++),
812       intraNodeComm_(initIntraNodeComm()) {
813   TORCH_CHECK_WITH(
814       ValueError,
815       at::cuda::getNumGPUs() != 0,
816       "ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
817 
818   // getNcclVersion needs to get called before launching threads which can
819   // potentially call getenv. getNcclVersion internally calls setenv to set some
820   // environment variables from config file, which can race with getenv from
821   // other threads and cause segfaults.
822   const auto ncclVersion = getNcclVersion();
823   this->setGroupUid(options_->group_name);
824   this->localDeviceCount_ = at::cuda::getNumGPUs();
825   logPrefix_ = createLogPrefix();
826   blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
827   asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
828       getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/));
829   desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
830       (dist_debug_level_ >= DebugLevel::Detail);
831   rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true);
832   // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
833   // or change its name to reflect that dump happens on exception including
834   // both timeout and other errors.
835   dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
836       (dist_debug_level_ >= DebugLevel::Detail);
837   // logging C++ stack isn't safe. Introduce a variable to control it.
838   logCppStackOnUncleanShutdown_ =
839       getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true);
840   enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
841   heartbeat_ = 1ULL;
842   monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
843   cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, false));
844   heartbeatTimeoutInSec_ =
845       getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
846   waitTimeoutDumpInMilSec_ =
847       getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/);
848   coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
849   ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
850   enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
851   // store_ usually is wrapped with PrefixStore and the prefix is different
852   // across different ProcessGroupNCCL(PG) instances. We need to get the
853   // underlying non-PrefixStore for sharing global information shared across
854   // different PGs.
855   PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
856   globalStore_ =
857       prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
858 #ifdef ENABLE_NCCL_ERROR_CHECKING
859   enableTiming_.store(
860       getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
861 #endif
862   avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false);
863 #ifdef NCCL_HAS_COMM_REGISTER
864   useTensorRegisterAllocatorHook_ =
865       getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);
866   if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
867           expandable_segments()) {
868     useTensorRegisterAllocatorHook_ = false;
869     LOG(INFO)
870         << logPrefix()
871         << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
872   }
873 #endif
874 
875   if (blockingWait_) {
876     if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
877       LOG(INFO)
878           << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and "
879           << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG"
880           << "should not both be enabled. "
881           << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process.";
882       asyncErrorHandling_ = NoHandling;
883       desyncDebug_ = false;
884     }
885   } else {
886     if (desyncDebug_ && asyncErrorHandling_ == NoHandling) {
887       LOG(INFO)
888           << logPrefix()
889           << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING "
890           << "must both be enabled. "
891           << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING.";
892       asyncErrorHandling_ = SkipCleanUp;
893     }
894   }
895 
896 #ifdef ENABLE_NCCL_ERROR_CHECKING
897   ncclCommWatchdogThread_ =
898       std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
899 #endif
900 
901   init();
902   const std::string OFF = "OFF";
903   std::string torch_distributed_debug =
904       getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
905   LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: "
906             << "size: " << size << ", global rank: " << globalRank()
907             << ", TIMEOUT(ms): " << options_->timeout.count()
908             << ", USE_HIGH_PRIORITY_STREAM: "
909             << options_->is_high_priority_stream
910             << ", SPLIT_FROM: " << options_->split_from
911             << ", SPLIT_COLOR: " << options_->split_color
912             << ", PG Name: " << options_->group_name;
913 
914   LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: "
915             << "NCCL version: " << ncclVersion
916             << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
917             << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_
918             << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
919             << waitTimeoutDumpInMilSec_
920             << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
921             << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
922             << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
923             << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
924 #ifdef NCCL_HAS_COMM_REGISTER
925             << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
926             << useTensorRegisterAllocatorHook_
927 #endif
928             << ", TORCH_NCCL_ENABLE_MONITORING: "
929             << monitorThreadEnabled_.load()
930             << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
931             << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
932             << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
933             << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
934             << ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_
935             << ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: "
936             << logCppStackOnUncleanShutdown_;
937 
938   if (options_->global_ranks_in_group.empty()) {
939     this->globalRankStart = 0;
940   } else {
941     this->globalRankStart = options_->global_ranks_in_group[0];
942   }
943 
944   if (options_->global_ranks_in_group.empty()) {
945     this->globalRankStride = 1;
946   } else if (options_->global_ranks_in_group.size() == 1) {
947     this->globalRankStride = 0;
948   } else {
949     bool ranksAreStrided = true;
950     int startRank = options_->global_ranks_in_group[0];
951     int stride =
952         options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0];
953     for (std::vector<uint64_t>::size_type i = 0;
954          i < options_->global_ranks_in_group.size();
955          i++) {
956       if (options_->global_ranks_in_group[i] != startRank + i * stride) {
957         ranksAreStrided = false;
958         break;
959       }
960     }
961 
962     if (ranksAreStrided) {
963       this->globalRankStride = options_->global_ranks_in_group[1] -
964           options_->global_ranks_in_group[0];
965     } else {
966       this->globalRankStride = -1;
967     }
968   }
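  // Worked example (illustration only): global_ranks_in_group = {2, 4, 6, 8}
  // yields globalRankStart = 2 and globalRankStride = 2, while an irregular
  // set such as {0, 1, 3} is recorded with globalRankStride = -1 (non-strided).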
969 
970   // Attach hooks to cache allocator to trigger the hooks whenever a traced
971   // action is called. In the following hooks, we register a newly allocated
972   // segment when SEGMENT_ALLOC action occurs, and deregister a segment when
973   // SEGMENT_FREE action occurs.
974   // We attach hooks only once at the first PG creation.
975   // Attaching hooks fails if CUDACachingAllocator is not initialized, so
976   // lazyInitCUDA is called (and is a no-op if CUDA is already initialized).
977   if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) {
978     at::globalContext().lazyInitCUDA();
979     c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
980         &cacheAllocatorRegisterHook);
981     c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
982         &cacheAllocatorDeregisterHook);
983     allocatorHooksAttached = true;
984   }
985 }
986 
987 void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
988   const auto key = getKeyFromDevice(device);
989   LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device "
990             << device;
991   getNCCLComm(key, device, OpType::ALLREDUCE);
992 }
993 
994 void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
995   // If our backend doesn't support splitting, this is a no-op for
996   // ranks not in the new subgroup (and ranks that would be in it will
997   // just use a new communicator rather than split).
998 #ifdef NCCL_HAS_COMM_SPLIT
999   const auto key = getKeyFromDevice(device);
1000   LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
1001             << device << ", key " << key << ", i am " << this;
1002   auto comm = getNCCLComm(key, device, OpType::ALLREDUCE);
1003   NCCLComm::split(
1004       comm.get(),
1005       NCCL_SPLIT_NOCOLOR,
1006       rank_,
1007       options_->config,
1008       options_->global_ranks_in_group);
1009 #endif
1010 }
1011 
1012 c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
1013     initIntraNodeComm() {
1014   using IntraNodeComm = intra_node_comm::IntraNodeComm;
1015   if (!IntraNodeComm::isEnabled()) {
1016     return nullptr;
1017   }
1018   auto prefixStore = c10::make_intrusive<PrefixStore>("IntraNodeComm", store_);
1019   auto comm = c10::make_intrusive<IntraNodeComm>(prefixStore, rank_, size_);
1020   if (comm->rendezvous()) {
1021     return comm;
1022   } else {
1023     return nullptr;
1024   }
1025 }
1026 
1027 void ProcessGroupNCCL::setSequenceNumberForGroup() {
1028 } // NCCL just starts sequence numbers at 0.
1029 
1030 uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() {
1031   return seqCollective_;
1032 }
1033 
1034 void ProcessGroupNCCL::registerOnCompletionHook(
1035     std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
1036   TORCH_CHECK_WITH(
1037       DistBackendError,
1038       onCompletionHook_ == nullptr,
1039       "ProcessGroupNCCL OnCompletion hook already registered");
1040 
1041   TORCH_CHECK_WITH(
1042       ValueError,
1043       enableTiming_.load(),
1044       "ProcessGroupNCCL OnCompletion hook requires recording start and end "
1045       "events which require setting TORCH_NCCL_ENABLE_TIMING environment variable. "
1046       "This is only available for NCCL version >= 2.4.");
1047   onCompletionHook_ = std::move(hook);
1048   onCompletionHookThread_ = std::thread(&ProcessGroupNCCL::runHookLoop, this);
1049 }
1050 
1051 // must release GIL when calling this method
1052 void ProcessGroupNCCL::waitForPendingWorks() {
1053   // Reasoning about hook completion:
1054   // 1. waitForPendingWorks should be called after user code has finished
1055   // calling
1056   //    all collectives. This means that when we get here, all of the collectives
1057   //    are either in workMetaList_ or have been erased from workMetaList_.
1058   // 2. The watchdog thread grabs both locks to move Work object from the
1059   //    workMetaList_ to the completedWorkList_, and the hook thread only erases
1060   //    a Work object after the hook is returned. Therefore, after user code
1061   //    calls a collective, its Work object is either in workMetaList_ or in
1062   //    completedWorkList_ before it finishes.
1063   // 3. We have three threads and two locks.
1064   //      a. main thread (this function) grabs two locks atomically
1065   //      b. watchdog thread (watchdogHandler function) always grabs
1066   //      workMetaListMutex_
1067   //         first and then grabs completedWorkListMutex_.
1068   //      c. hook thread (runHookLoop function) only grabs
1069   //      completedWorkListMutex_. Therefore, locks are always acquired in the
1070   //      same order and hence no deadlocks.
1071   while (true) {
1072     {
1073       std::lock(workMetaListMutex_, completedWorkListMutex_);
1074       std::lock_guard<std::mutex> lockWork(workMetaListMutex_, std::adopt_lock);
1075       std::lock_guard<std::mutex> lockHook(
1076           completedWorkListMutex_, std::adopt_lock);
1077 
1078       if (workMetaList_.empty() && completedWorkList_.empty()) {
1079         return;
1080       }
1081     }
1082 
1083     std::this_thread::sleep_for(
1084         std::chrono::milliseconds(kWatchdogThreadSleepMillis));
1085   }
1086 }
1087 
1088 void ProcessGroupNCCL::enableCollectivesTiming() {
1089   enableTiming_.store(true);
1090 }
1091 
1092 void ProcessGroupNCCL::waitForFutureOrTimeout(
1093     std::future<bool>& fut,
1094     const std::chrono::milliseconds& timeOutMilSec,
1095     const std::string& futDescription,
1096     bool throwException,
1097     bool log) {
1098   std::string errorMsg;
1099 
1100   ::c10d::C10dLoggingData data;
1101   if (log) {
1102     data.integers["pg_id"] = local_id_;
1103     data.integers["rank"] = rank_;
1104     data.integers["global_rank"] = globalRank();
1105     data.strings["flight_recorder_version"] = c10d::version_val_str;
1106   }
1107 
1108   TORCH_CHECK(fut.valid(), "Expected a valid future");
1109   std::future_status status = fut.wait_for(timeOutMilSec);
1110   if (status == std::future_status::ready) {
1111     // Calling .get() will re-raise any exception from the future, and we don't
1112     // care about the retval
1113     try {
1114       bool result = fut.get();
1115       if (result) {
1116         LOG(INFO) << logPrefix()
1117                   << "future is successfully executed for: " << futDescription;
1118         if (log) {
1119           data.strings["status"] = "SUCCESS";
1120         }
1121       }
1122     } catch (const std::exception& e) {
1123       errorMsg = c10::str(
1124           logPrefix(),
1125           "Exception thrown when waiting for future ",
1126           futDescription,
1127           ": ",
1128           e.what());
1129       if (log) {
1130         data.strings["status"] = "EXCEPTION";
1131         data.strings["exception"] = e.what();
1132       }
1133       LOG(ERROR) << errorMsg;
1134     } catch (...) {
1135       errorMsg = c10::str(
1136           logPrefix(),
1137           "Unknown exception thrown when waiting for future ",
1138           futDescription);
1139       if (log) {
1140         data.strings["status"] = "EXCEPTION";
1141         data.strings["exception"] = "Unknown exception";
1142       }
1143       LOG(ERROR) << errorMsg;
1144     }
1145   } else {
1146     errorMsg = c10::str(
1147         logPrefix(),
1148         "Future for ",
1149         futDescription,
1150         " timed out after ",
1151         timeOutMilSec.count(),
1152         " ms");
1153     data.strings["status"] = "TIMEOUT";
1154     LOG(ERROR) << errorMsg;
1155   }
1156   if (log) {
1157     auto logger = c10d::C10dLogger::getLogger();
1158     if (logger) {
1159       logger->log(data);
1160     }
1161   }
1162   if (throwException && !errorMsg.empty()) {
1163     C10_THROW_ERROR(DistBackendError, errorMsg);
1164   }
1165 }
1166 
1167 void ProcessGroupNCCL::abortCommsFromMap(
1168     std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
1169     std::optional<std::string> abortReason) {
1170   // The process may control multiple devices; loop through the communicators on
1171   // each device
1172   for (auto& it : ncclCommsMap) {
1173     auto& devName = it.first;
1174     auto& ncclComm = it.second;
1175     at::cuda::OptionalCUDAGuard gpuGuard;
1176     at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName);
1177     if (deviceIndex >= 0) {
1178       // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in
1179       // the map may not be device indices but rank-to-rank pair strings. So we
1180       // indeed need to check that deviceIndex >= 0.
1181       // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey`
1182       gpuGuard.set_index(deviceIndex);
1183     }
1184     LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
1185               << ncclComm->ncclComm_ << " on CUDA device: " << devName;
1186     ncclComm->ncclCommAbort(abortReason);
1187     // Note that we don't remove the aborted communicators from the
1188     // cache. The reason is that if we do remove the communicator
1189     // from the cache, it is possible that a new collective operation
1190     // calls `ncclCommInitRank` to create a new communicator whereas
1191     // other ranks might have failed/timed out and didn't enter
1192     // `ncclCommInitRank`. As a result, when there is a failure on
1193     // a communicator, the application receives an exception and it is
1194     // its responsibility to destroy the process group and recreate
1195     // it to recover from errors.
1196 
1197     LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed "
1198               << " communicator on CUDA device: " << devName;
1199   }
1200 }
1201 
1202 // Abort all communicators on this rank
1203 bool ProcessGroupNCCL::abort(std::optional<std::string> abortReason) {
1204   // Remove records from the global ncclCommDevIdxMap before aborting,
1205   // so that a new cache segment would not register to already aborted
1206   // communicators. Note that ncclCommDevIdxMap is a global container which may
1207   // contain other PG's communicators, thus we need to only erase communicators
1208   // for the current PG.
1209   ncclCommDevIdxMapMutex.lock();
1210   for (auto& it : devNCCLCommMap_) {
1211     auto& ncclComm = it.second;
1212     ncclCommDevIdxMap.erase(ncclComm);
1213   }
1214   ncclCommDevIdxMapMutex.unlock();
1215 
1216   std::lock_guard<std::mutex> lock(mutex_);
1217   abortCommsFromMap(devNCCLCommMap_, abortReason);
1218   abortCommsFromMap(inInitializationCommMap_, abortReason);
1219   return true;
1220 }
1221 
1222 void ProcessGroupNCCL::shutdown(std::optional<std::string> reason) {
1223   // Don't join threads here since the purpose of this method is to abort all
1224   // communicators and signal the threads to exit. Joining on the threads could
1225   // potentially block and hence avoid it in this method.
1226   terminateProcessGroup_.store(true);
1227   workMetaListCV_.notify_one();
1228 
1229   // Launch abort asynchronously and wait for it to complete or time out.
1230   LOG(INFO) << logPrefix()
1231             << "Launching ProcessGroupNCCL abort asynchrounously.";
1232   std::future<bool> fut = std::async(
1233       std::launch::async, [this, &reason]() { return this->abort(reason); });
1234 
1235   waitForFutureOrTimeout(
1236       fut, options_->timeout, "ProcessGroup abort", true, false);
1237   LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully.";
1238 
1239   // We need to wait for abort to finish before we can safely shut down
1240   // heartbeat monitoring thread.
1241   terminateHeartbeatMonitorThread_.store(true);
1242   monitorWakeUpCV_.notify_one();
1243 }
1244 
1245 ProcessGroupNCCL::~ProcessGroupNCCL() {
1246   LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
1247 
1248   if (!terminateProcessGroup_.load()) {
1249     if (rank_ % localDeviceCount_ == 0) {
1250       TORCH_WARN_ONCE(
1251           "WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. ",
1252           "On normal program exit, the application should call destroy_process_group to ",
1253           "ensure that any pending NCCL operations have finished in this process. "
1254           "In rare cases this process can exit before this point and block the progress of "
1255           "another member of the process group. This constraint has always been present, "
1256           " but this warning has only been added since PyTorch 2.4");
1257     }
1258     // If the user hasn't explicitly destroyed/shut down the process group,
1259     // the destructor needs to do so.
1260     shutdown();
1261   }
1262 
1263   // Wait for all threads to finish before returning
1264 #ifdef ENABLE_NCCL_ERROR_CHECKING
1265   if (ncclCommWatchdogThread_.joinable()) {
1266     ncclCommWatchdogThread_.join();
1267     LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
1268   }
1269   if (ncclHeartbeatMonitorThread_.joinable()) {
1270     ncclHeartbeatMonitorThread_.join();
1271     LOG(INFO) << logPrefix()
1272               << "ProcessGroupNCCL heart beat monitor thread joined.";
1273   }
1274 #endif
1275   if (onCompletionHookThread_.joinable()) {
1276     onCompletionHookThread_.join();
1277     LOG(INFO) << logPrefix()
1278               << "ProcessGroupNCCL onCompletionHookThread thread joined.";
1279   }
1280 }
1281 
1282 bool ProcessGroupNCCL::dumpDebuggingInfo() {
1283   // Serialize all calls to this function to avoid corrupting data, but allow
1284   // multiple calls in one runtime. User is responsible for preserving the
1285   // output file from an earlier call before a later call overwrites it.
1286   static std::mutex writeDebugInfoMutex;
1287   std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
1288   LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
1289   if (ncclTraceBufferSize_ > 0) {
1290     // We dump nccl trace into local disk by default and users can register
1291     // their customized writer by inheriting `DebugInfoWriter` via
1292     // `registerDebugInfoWriter`.
1293     auto ncclTrace = dump_nccl_trace(true, true, false);
1294     DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank());
1295     LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to "
1296               << writer.getWriterTarget();
1297     writer.write(ncclTrace);
1298     return true;
1299   }
1300   return false;
1301 }
1302 
1303 void ProcessGroupNCCL::terminateProcess(std::string errMsg) {
1304   // Logging with `FATAL`, after errMsg printed, it calls `std::abort()`
1305   // to terminate the program execution.
1306   LOG(FATAL) << logPrefix() << errMsg;
1307 }
1308 
1309 int computeDeltaMS(
1310     std::chrono::time_point<std::chrono::steady_clock> start,
1311     std::chrono::time_point<std::chrono::steady_clock> end) {
1312   return std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
1313       .count();
1314 }
1315 
1316 std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg(
1317     const std::string& extraMsg) {
1318   return c10::str(
1319       logPrefix(),
1320       "Received a dump signal due to a collective timeout from ",
1321       extraMsg,
1322       " and we will try our best to dump the debug info. ",
1323       "Last enqueued NCCL work: ",
1324       pgStatus_->lastEnqueuedSeq,
1325       ", last completed NCCL work: ",
1326       pgStatus_->lastCompletedSeq,
1327       ".",
1328       "This is most likely caused by incorrect usages of collectives, e.g., wrong ",
1329       "sizes used across ranks, the order of collectives is not same for all ranks ",
1330       "or the scheduled collective, for some reason, didn't run. Additionally, ",
1331       "this can be caused by GIL deadlock or other reasons such as network errors or ",
1332       "bugs in the communications library (e.g. NCCL), etc. ");
1333 }
1334 
1335 std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg(
1336     const std::string& exitReason) {
1337   return c10::str(
1338       logPrefix(),
1339       "Terminating the process after attempting to dump debug info, due to ",
1340       exitReason,
1341       ".");
1342 }
1343 
1344 void ProcessGroupNCCL::heartbeatMonitor() {
1345   c10::setThreadName("pt_nccl_heartbt");
1346 
1347   uint64_t heartBeatCounter = 0ULL;
1348   std::string errorMsg;
1349   std::string exitReason;
1350   bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0);
1351   int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_
1352                                             : heartbeatTimeoutInSec_ * 1000;
1353   auto lastTimePollStore = std::chrono::steady_clock::now();
1354   auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
1355   std::optional<DumpPipe> dumpPipe = std::nullopt;
1356   if (local_id_ == 0) {
1357     // DumpPipe is one per trainer process, and it's convenient to name them
1358     // after 'global' ranks in the system, so we assume process group (uid)==0 is
1359     // the global PG and has globally unique rank ids across trainers.
1360     dumpPipe.emplace(rank_);
1361   }
1362   while (true) {
1363     // There is no contention on this lock since it is only used here.
1364     // Please be aware that mutex `monitorMutex_` should not be used
1365     // anywhere else, to avoid deadlock.
1366     std::unique_lock<std::mutex> lock(monitorMutex_);
1367     if (monitorWakeUpCV_.wait_for(
1368             lock, std::chrono::milliseconds(monitorPollInterval), [&] {
1369               return terminateHeartbeatMonitorThread_.load();
1370             })) {
1371       // On normal completion or user interception, monitorWakeUpCV_
1372       // gets notified; we return early and exit heartbeatMonitor.
1373       return;
1374     }
1375     auto currentTime = std::chrono::steady_clock::now();
1376 
1377     // We put extra functionality in the thread for the default PG (aka,
1378     // local_id_=0) because the signal is the same across different PGs. We only
1379     // need to run it once per process to avoid duplicating work across too
1380     // many separate threads. For example, we check a global flag on the
1381     // TCPStore periodically to see if any PG on any rank observed a timeout and
1382     // signaled peers to dump debugging info, and we avoid hammering the
1383     // TCPStore from all PGs on the same rank.
1384     if (checkDumpSignal) {
1385       // There are two scenarios where the monitor thread will dump on timeout:
1386       // 1. The current rank is the first to observe a timeout in the watchdog
1387       // (shouldDump_ was set to true by the watchdog thread).
1388       // 2. Other ranks detected the timeout and signaled the current rank to
1389       // dump. In addition, the monitor thread will dump if the watchdog thread
1390       // has no heartbeat or the dumpPipe is not empty.
1391       if (shouldDump_.load()) {
1392         errorMsg = getNCCLWatchdogTimeoutErrorMsg("this local rank");
1393         exitReason = "collective timeout or exception";
1394         break;
1395       }
1396       // We poll the store to see if some ranks have flagged a timeout when we
1397       // haven't polled for `coordCheckIntervalMilSec_` ms and no work has been
1398       // added or removed for `kWatchdogThreadSleepMillis` ms.
1399       if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >=
1400               kWatchdogThreadSleepMillis &&
1401           computeDeltaMS(lastTimePollStore, currentTime) >=
1402               coordCheckIntervalMilSec_) {
1403         lastTimePollStore = currentTime;
1404         // Wrap globalStore_->check() in a try-catch block to avoid crashing if
1405         // the store is not available.
1406         bool checkExceptionDump = false;
1407         try {
1408           checkExceptionDump =
1409               globalStore_->check({std::string(EXCEPTION_DUMP)});
1410         } catch (const std::exception& e) {
1411           LOG(WARNING)
1412               << logPrefix()
1413               << "Failed to check the \"should dump\" flag on TCPStore, "
1414               << "(maybe TCPStore server has shut down too early), with error: "
1415               << e.what();
1416           // We give up for now assuming TCPStore has been torn down.
1417           return;
1418         }
1419 
1420         if (checkExceptionDump) {
1421           int timeOutRank = -1;
1422           if (!shouldDump_.load()) {
1423             LOG(ERROR)
1424                 << logPrefix()
1425                 << "Observed flight recorder dump signal from another rank via TCPStore.";
1426           }
1427           shouldDump_.store(true);
1428           try {
1429             auto vec = globalStore_->get(std::string(EXCEPTION_DUMP));
1430             TORCH_CHECK_WITH(
1431                 DistBackendError,
1432                 vec.size() == sizeof(int),
1433                 "Invalid size for the timeout rank ID");
1434             std::memcpy(&timeOutRank, vec.data(), vec.size());
1435           } catch (const std::exception& e) {
1436             LOG(ERROR) << logPrefix()
1437                        << "Failed to get timeout rank ID from TCPStore."
1438                        << e.what();
1439           }
1440           errorMsg =
1441               getNCCLWatchdogTimeoutErrorMsg(c10::str(" rank ", timeOutRank));
1442           exitReason = "collective timeout or exception";
1443           break;
1444         }
1445       }
1446     }
1447 
1448     if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >=
1449         heartbeatTimeoutInSec_ * 1000) {
1450       // Check the heart beat of watchdog thread.
1451       lastTimeHeartBeatCheck = currentTime;
1452       auto heartbeat = heartbeat_.load();
1453       if (heartbeat != heartBeatCounter) {
1454         heartBeatCounter = heartbeat;
1455       } else {
1456         shouldDump_.store(true);
1457         // Watchdog heartbeat timeout.
1458         errorMsg = c10::str(
1459             logPrefix(),
1460             "ProcessGroupNCCL's watchdog got stuck for ",
1461             heartbeatTimeoutInSec_,
1462             " seconds without making progress in monitoring enqueued collectives. ",
1463             "This typically indicates an NCCL/CUDA API (e.g., cudaEventDestroy) hang blocking the watchdog, ",
1464             "and could be triggered by another thread holding the GIL inside a ",
1465             "CUDA API (for example, cudaEventDestroy), or other deadlock-prone behaviors. ",
1466             "If you suspect the watchdog is not actually stuck and a longer timeout would help, ",
1467             "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value "
1468             "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0). "
1469             "If either of the aforementioned helps, feel free to file an issue to PyTorch about the short timeout "
1470             "or false positive abort; otherwise, please attempt to debug the hang. ");
1471         exitReason = "ProcessGroupNCCL watchdog hang";
1472         break;
1473       }
1474     }
1475     // Process a request to dump the trace. Only PG uid 0 will respond to dump
1476     // requests, but this is fine since all PGs feed into the same flight
1477     // recorder and dump. After the dump, training should continue.
1478     if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
1479       // best effort dump, not waiting for the dump here
1480       std::future<bool> fut = std::async(
1481           std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
1482     }
1483   }
1484   LOG(ERROR) << errorMsg;
1485 
1486   auto& cpp_dumper = get_cpp_trace_dumper();
1487   if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) {
1488     LOG(INFO) << "Dumping c++ stacktraces:";
1489     cpp_dumper.value()([](const std::string& line) { LOG(INFO) << line; });
1490   }
1491 
1492   if (checkDumpSignal && shouldDump_.load()) {
1493     // Store debug info to storage if no other thread does it. (By default to
1494     // local disk)
1495     std::future<bool> asyncDebugDump = std::async(
1496         std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
1497 
1498     // wait for the dump to complete until timeout, then log the result
1499     waitForFutureOrTimeout(
1500         asyncDebugDump,
1501         std::chrono::milliseconds(waitTimeoutDumpInMilSec_),
1502         "Flight recorder dump in heartbeatMonitor",
1503         false,
1504         true);
1505   }
1506 
1507   if (get_gil_checker() != nullptr) {
1508     auto fut = launchAsyncGilCheck();
1509     auto kGilCheckTimeout = std::chrono::milliseconds(300);
1510     auto futStatus = fut.wait_for(kGilCheckTimeout);
1511     if (futStatus != std::future_status::ready) {
1512       TORCH_CHECK(
1513           futStatus != std::future_status::deferred,
1514           "Expected the future to have been launched eagerly.");
1515       LOG(ERROR)
1516           << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang";
1517     }
1518   } else {
1519     LOG(INFO)
1520         << "GIL checker was not registered, perhaps this is a no-python build?";
1521   }
1522 
1523   // There are two possible cases for the watchdog thread exit:
1524   // Case one: the desync report runs quickly, and it follows the steps:
1525   // collective timeout -> desync -> exception handling -> destructors
1526   // -> set terminateHeartbeatMonitorThread_ -> notify monitorWakeUpCV_.
1527   // So the code either early returns above or will skip the sleep below.
1528   // Case two: the desync report might be slow or get stuck, or we get stuck
1529   // in destructors; in that case we sleep for some time before calling
1530   // std::abort() to kill the whole process.
1531   if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() ||
1532        shouldDump_.load()) &&
1533       !terminateHeartbeatMonitorThread_.load()) {
1534     // Leave another `heartbeatTimeoutInSec_` seconds for desync report
1535     // generation or process group destroy.
1536     std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_));
1537     LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_
1538               << " seconds waiting for the desync report or process group destroy.";
1539   }
1540 
1541   // At this point, we have either already slept for another
1542   // `heartbeatTimeoutInSec_` or the thread has finished. Because we don't want
1543   // to block the monitor thread, we mark the dump thread as detached and the
1544   // dump of debug info becomes "best effort". If the process exits normally,
1545   // marking it detached also makes sense because we don't really care about
1546   // dumping the debug info.
1546 
1547   // We already log completion inside the thread, so it may not be necessary to
1548   // check the return value here.  We mainly use a future so we can exit early
1549   // if done.
1550 
1551   if (!terminateHeartbeatMonitorThread_.load()) {
1552     // Create an error message reported from the monitor thread, then
1553     // throw an exception to kill the whole process.
1554     // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki
1555     // URL here.
1556     if (monitorThreadEnabled_.load()) {
1557       terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason));
1558     } else {
1559       // Ideally we want to merge this one with the above one, but we are going
1560       // to remove the kill switch for monitor thread soon, so we keep this one
1561       // for now.
1562       LOG(ERROR)
1563           << logPrefix()
1564           << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process "
1565           << "after attempting to dump debug info, due to " << exitReason
1566           << ".";
1567     }
1568   }
1569 }
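// Wake-up sketch for the monitor loop above (illustrative, assuming the
// teardown path owns these members): setting the terminate flag and notifying
// monitorWakeUpCV_ makes the wait_for() predicate in heartbeatMonitor() return
// true, so the thread exits early instead of sleeping out its poll interval.
//
//   terminateHeartbeatMonitorThread_.store(true);
//   monitorWakeUpCV_.notify_one();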
1570 
1571 void ProcessGroupNCCL::ncclCommWatchdog() {
1572   c10::setThreadName("pt_nccl_watchdg");
1573 
1574   try {
1575     VLOG(2) << logPrefix() << "Process group watchdog thread started!";
1576     ncclHeartbeatMonitorThread_ =
1577         std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
1578     watchdogHandler();
1579     VLOG(2) << logPrefix()
1580             << "Process group watchdog thread terminated normally";
1581   } catch (std::exception& e) {
1582     if (std::string(e.what()).find("driver shutting down") !=
1583         std::string::npos) {
1584       LOG(INFO)
1585           << logPrefix()
1586           << "main process destroyed cuda before watchdog loop exited, terminating watchdog."
1587           << " (Watchdog caught exception: " << e.what() << ")";
1588 
1589     } else {
1590       // Append error message reported from watchdogHandler
1591       const auto exitMsg = c10::str(
1592           logPrefix(),
1593           "Process group watchdog thread terminated with exception: ",
1594           e.what());
1595       LOG(ERROR) << exitMsg;
1596       if (C10_LIKELY(rethrowCUDAErrors_) ||
1597           !(std::string(e.what()).find("CUDA Error"))) {
1598         // TODO(whc) clean up the rethrow - why is it stored in a class var and
1599         // rethrown?
1600         watchDogException_ =
1601             std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
1602         std::rethrow_exception(watchDogException_);
1603       }
1604     }
1605   } catch (...) {
1606     const auto exitMsg = c10::str(
1607         logPrefix(),
1608         "Process group watchdog thread terminated with exception: unknown");
1609     LOG(ERROR) << exitMsg;
1610     watchDogException_ =
1611         std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
1612     std::rethrow_exception(watchDogException_);
1613   }
1614 }
1615 
1616 void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
1617   if (work.startTraceUpdated_)
1618     return;
1619 
1620   if (terminateProcessGroup_.load() || storeError_)
1621     return;
1622 
1623   work.startTraceUpdated_ = true;
1624   storeError_ = !c10d::traceUpdate(
1625       store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
1626 }
1627 
1628 void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
1629   if (terminateProcessGroup_.load() || storeError_)
1630     return;
1631 
1632   // In case the start of the work hasn't been logged
1633   if (!work.startTraceUpdated_) {
1634     logWorkStart(work);
1635   }
1636 
1637   storeError_ = !c10d::traceUpdate(
1638       store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
1639 }
1640 
1641 std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() {
1642   return retrieveDesyncReport(store_, "NCCL", rank_, size_);
1643 }
1644 
1645 // We want to have both PG ID and global unique ID (guid) for the logging
1646 // prefix. PG ID records how many ProcessGroupNCCL objects were created on a
1647 // specific rank and is a stable index across ranks, which lets users reason
1648 // about, for example, the second PG we initialized on this rank is for FSDP,
1649 // and corresponds to PG ID = 1 on other ranks as well. Unlike PG ID, guid (or
1650 // group name) is a globally unique ID across ranks. The guid is either a hash
1651 // of all the ranks in the group or a counter of how many times
1652 // `_process_group_name` has been called, which is essentially how many PGs
1653 // users have created. Before using split_group, even if
1654 // we are creating a new sub-PG, all ranks have to call the API at the same
1655 // time, and this makes `group_name` a unique identifier for a group (PG).
1656 std::string ProcessGroupNCCL::createLogPrefix() const {
1657   if (!pg_desc_.empty() && pg_desc_ != "undefined") {
1658     return c10::str(
1659         "[PG ID ",
1660         local_id_,
1661         " PG GUID ",
1662         pg_uid_,
1663         "(",
1664         pg_desc_,
1665         ") Rank ",
1666         rank_,
1667         "] ");
1668   }
1669   return c10::str(
1670       "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] ");
1671 }
1672 
1673 const std::string& ProcessGroupNCCL::logPrefix() const {
1674   return logPrefix_;
1675 }
1676 
1677 const int& ProcessGroupNCCL::globalRank() const {
1678   static int globalRank = rank_;
1679   return globalRank;
1680 }
1681 
1682 const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
1683   if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
1684     static std::vector<uint64_t> globalRanks(size_);
1685     std::iota(globalRanks.begin(), globalRanks.end(), 0);
1686     return globalRanks;
1687   }
1688   return options_->global_ranks_in_group;
1689 }
1690 
1691 void ProcessGroupNCCL::addEphemeralTimeout(
1692     const std::chrono::milliseconds& timeout) {
1693   std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
1694   ephemeralTimeoutActive_ += timeout;
1695 }
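// Usage sketch (hypothetical caller code): temporarily extend the timeout of
// subsequently issued collectives on this PG, e.g. around a known-slow step.
// The extra budget is folded into each new work's opTimeout_ (see
// assignTimeoutToWork below) and is released once those works complete.
//
//   pg->addEphemeralTimeout(std::chrono::minutes(5));
//   auto work = pg->allreduce(tensors);  // opTimeout_ = base timeout + 5 min
//   work->wait();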
1696 
1697 bool ProcessGroupNCCL::verifyWorkTimeoutForTest(
1698     const c10::intrusive_ptr<Work> work,
1699     const std::chrono::milliseconds& timeout) {
1700   // Since collective returns a c10d::Work, we need to cast it to WorkNCCL.
1701   if (auto workNCCL = c10::dynamic_intrusive_pointer_cast<WorkNCCL>(work)) {
1702     // workNCCL is now a c10::intrusive_ptr<WorkNCCL>
1703     return workNCCL->opTimeout_ == timeout;
1704   }
1705   C10_THROW_ERROR(
1706       DistBackendError, "Non c10d::WorkNCCL object returned from collective");
1707 }
1708 
1709 void ProcessGroupNCCL::watchdogHandler() {
1710   bool done = false;
1711   lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
1712   auto lastStatusUpdateTime = std::chrono::steady_clock::now();
1713   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
1714 
1715   while (!done || !terminateProcessGroup_.load()) {
1716     std::unique_lock<std::mutex> lock(workMetaListMutex_);
1717     // We busy-poll the work vector every kWatchdogThreadSleepMillis
1718     // milliseconds as long as the atomic is True.
1719     workMetaListCV_.wait_for(
1720         lock,
1721         std::chrono::milliseconds(kWatchdogThreadSleepMillis),
1722         [&]() -> bool { return terminateProcessGroup_.load(); });
1723     // Bump up heart beat by one.
1724     heartbeat_++;
1725 
1726 // Some versions of GLOG support a less-spammy version of LOG_EVERY_MS,
1727 // in which case we use it so we don't spam the logs.
1728 #ifdef LOG_EVERY_MS
1729     // Log the progress of this PG periodically
1730     C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str(
1731         logPrefix(),
1732         "NCCL Work update periodically: ",
1733         "last enqueued NCCL work: ",
1734         pgStatus_->lastEnqueuedSeq,
1735         ", last completed NCCL work: ",
1736         pgStatus_->lastCompletedSeq,
1737         ".");
1738 #endif
1739     auto logger = ::c10d::C10dLogger::getLogger();
1740     if (logger &&
1741         computeDeltaMS(
1742             lastStatusUpdateTime, std::chrono::steady_clock::now()) >=
1743             kWorkStatusUpdatePeriodMs) {
1744       ::c10d::C10dLoggingData data;
1745       // logging integers
1746       data.integers["pg_id"] = local_id_;
1747       data.integers["rank"] = rank_;
1748       data.integers["global_rank"] = globalRank();
1749       data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq;
1750       data.integers["last_started_work"] = pgStatus_->lastStartedSeq;
1751       data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq;
1752       data.integers["last_enqueued_numel_in"] = pgStatus_->lastEnqueuedNumelIn;
1753       data.integers["last_enqueued_numel_out"] =
1754           pgStatus_->lastEnqueuedNumelOut;
1755       data.integers["last_completed_numel_in"] =
1756           pgStatus_->lastCompletedNumelIn;
1757       data.integers["last_completed_numel_out"] =
1758           pgStatus_->lastCompletedNumelOut;
1759       // logging strings
1760       data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName;
1761       data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName;
1762       data.strings["last_completed_work_name"] =
1763           pgStatus_->lastCompletedWorkName;
1764       data.strings["pg_name"] = pg_uid_;
1765       data.strings["pg_desc"] = pg_desc_;
1766       logger->log(data);
1767       lastStatusUpdateTime = std::chrono::steady_clock::now();
1768     }
1769 
1770     for (auto it = workMetaList_.begin(); it != workMetaList_.end();
1771          /* no increment */) {
1772       auto& work = *it;
1773       // When terminateProcessGroup_ is true, communicators have already been
1774       // aborted, so we cannot check exceptions based on them. But the watchdog
1775       // needs to finish checking the works that have already been enqueued to
1776       // workMetaList_.
1777       if (!terminateProcessGroup_.load()) {
1778         work.checkAndSetException();
1779       }
1780       bool timedOut = work.checkTimeout();
1781 
1782       // If work hits an exception (either an error or timeout)
1783       if (work.exception()) {
1784         // log as soon as exception is detected
1785         LOG(ERROR) << c10::str(
1786             logPrefix(),
1787             "Exception (either an error or timeout) detected by watchdog at work: ",
1788             work.seq_,
1789             ", last enqueued NCCL work: ",
1790             pgStatus_->lastEnqueuedSeq,
1791             ", last completed NCCL work: ",
1792             pgStatus_->lastCompletedSeq,
1793             ".");
1794         // try to notify other ranks via global TCPStore to dump the flight
1795         // recorder when a collective timeout or exception happens. Flight
1796         // recorder behavior is independent of desync Debug.
1797         if (dumpOnTimeoutOrEx_) {
1798           try {
1799             auto rank = globalRank();
1800             auto vec = std::vector<uint8_t>(
1801                 reinterpret_cast<uint8_t*>(&rank),
1802                 reinterpret_cast<uint8_t*>(&rank) + sizeof(rank));
1803             globalStore_->set(std::string(EXCEPTION_DUMP), vec);
1804             if (!shouldDump_.load()) {
1805               LOG(ERROR)
1806                   << logPrefix()
1807                   << "Broadcasting flight-recorder dump signal to other processes via TCPStore.";
1808             }
1809             // signal the monitor thread on PG0 to start dumping
1810             shouldDump_.store(true);
1811             // This sleep is used to give time for dumping before throwing
1812             // exception
1813             std::this_thread::sleep_for(
1814                 std::chrono::seconds(heartbeatTimeoutInSec_));
1815             LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_
1816                       << " seconds, giving time for flight recorder dumps to finish.";
1817           } catch (const std::exception& e) {
1818             LOG(ERROR) << logPrefix()
1819                        << "Failed to set dump signal in tcpstore. "
1820                        << "Error: " << e.what();
1821           }
1822         }
1823 
1824         if (SHOULD_CLEAN_UP(asyncErrorHandling_)) {
1825           // Abort work and corresponding communicators
1826           work.abort();
1827           // PG level abort, which would abort all other communicators on this
1828           // rank
1829           abort();
1830         }
1831 
1832         // Report desync state in case of timeout
1833         if (timedOut) {
1834           LOG(ERROR) << c10::str(
1835               logPrefix(),
1836               "Timeout at NCCL work: ",
1837               work.seq_,
1838               ", last enqueued NCCL work: ",
1839               pgStatus_->lastEnqueuedSeq,
1840               ", last completed NCCL work: ",
1841               pgStatus_->lastCompletedSeq,
1842               ".");
1843           if (desyncDebug_) {
1844             try {
1845               collectiveDebugInfoMode_.store(true);
1846               auto desyncMsg = getNCCLWatchdogDebugInfo();
1847               LOG(ERROR) << logPrefix() << desyncMsg;
1848             } catch (const std::exception& e) {
1849               LOG(ERROR)
1850                   << logPrefix()
1851                   << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
1852                   << " Please file an issue. Error: " << e.what();
1853             } catch (...) {
1854               LOG(ERROR)
1855                   << logPrefix()
1856                   << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report with an unknown error."
1857                   << " Please file an issue.";
1858             }
1859           }
1860         }
1861         // Throw exception
1862         work.handleException(asyncErrorHandling_);
1863       }
1864 
1865       // Work status logging for desync debug
1866       if (desyncDebug_) {
1867         if (work.isStarted()) {
1868           logWorkStart(work);
1869         }
1870         if (work.isCompleted()) {
1871           logWorkEnd(work);
1872         }
1873       }
1874 
1875       // a work could be started but not completed, so we should not update
1876       // lastStartedSeq and lastStartedOpName if the work state is checked
1877       // multiple times after the start
1878       if (pgStatus_->lastStartedSeq < static_cast<int64_t>(work.seq_) &&
1879           work.isStarted()) {
1880         pgStatus_->lastStartedSeq = work.seq_;
1881         pgStatus_->lastStartedWorkName = opTypeToString(work.opType_);
1882       }
1883 
1884       // Clean up completed work
1885       if (work.isCompleted()) {
1886         {
1887           // Reset the timeout and first work if the work is completed.
1888           std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
1889           if (work.ownedEphermeralTimeout_.count() > 0) {
1890             ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
1891             ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
1892           }
1893         }
1894         pgStatus_->lastCompletedSeq = work.seq_;
1895         pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
1896         pgStatus_->lastCompletedNumelIn = work.numelIn_;
1897         pgStatus_->lastCompletedNumelOut = work.numelOut_;
1898         NCCLTraceBuffer::get()->retire_id(work.trace_id_, true);
1899         if (onCompletionHook_) {
1900           // Move Work object to completedWorkList_ to be consumed by the hook
1901           // thread
1902           {
1903             const std::lock_guard<std::mutex> lock(completedWorkListMutex_);
1904             completedWorkList_.splice(
1905                 completedWorkList_.end(), workMetaList_, it++);
1906           }
1907           completedWorkListCV_.notify_one();
1908         } else {
1909           it = workMetaList_.erase(it);
1910           lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
1911         }
1912         at::cuda::CUDAGraph::dec_pending_event_queries();
1913       } else {
1914         // Increment the iterator if the current WorkNCCL object is not
1915         // completed.
1916         ++it;
1917       }
1918       // Increment heartbeat after each work processed,
1919       // in case processing is slowed down (but not hung) by cuda api contention
1920       heartbeat_++;
1921     }
1922     done = workMetaList_.empty();
1923   }
1924 }
1925 
1926 void ProcessGroupNCCL::runHookLoop() {
1927   c10::setThreadName("pt_nccl_runhook");
1928 
1929   bool done = false;
1930   while (!done || !terminateProcessGroup_.load()) {
1931     std::unique_lock<std::mutex> lock(completedWorkListMutex_);
1932     // We busy-poll the work vector every kWatchdogThreadSleepMillis
1933     // milliseconds as long as the atomic is True.
1934     completedWorkListCV_.wait_for(
1935         lock,
1936         std::chrono::milliseconds(kWatchdogThreadSleepMillis),
1937         [&]() -> bool {
1938           return !completedWorkList_.empty() || terminateProcessGroup_.load();
1939         });
1940 
1941     try {
1942       for (auto it = completedWorkList_.begin(); it != completedWorkList_.end();
1943            /* no increment */) {
1944         const WorkNCCL& work = *it;
1945         // Hook might grab GIL, unlock first to prevent deadlock
1946         lock.unlock();
1947 
1948         auto timeStarted =
1949             std::chrono::system_clock::now() +
1950             std::chrono::duration_cast<std::chrono::system_clock::duration>(
1951                 work.workStartTime_ - std::chrono::steady_clock::now());
1952         onCompletionHook_(std::make_shared<WorkInfo>(
1953             work.retrieveOpType(), // OpType
1954             work.getSequencenumber(), // seq
1955             timeStarted, // timeStarted
1956             std::chrono::system_clock::now(), // timeFinished
1957             std::chrono::duration<float, std::milli>(
1958                 work.getDuration()) // activeDuration
1959             ));
1960 
1961         lock.lock();
1962         it = completedWorkList_.erase(it);
1963       }
1964     } catch (std::exception& e) {
1965       if (std::string(e.what()).find("driver shutting down") !=
1966           std::string::npos) {
1967         LOG(INFO)
1968             << logPrefix()
1969             << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop."
1970             << " (runHookLoop caught exception: " << e.what() << ")";
1971 
1972       } else {
1973         // PythonOnCompletionHook has already extracted Python exception message
1974         // and wrapped it with a cpp one. So we no longer need to acquire GIL
1975         // here.
1976         const auto errorStr = c10::str(
1977             "Caught exception on rank ",
1978             rank_,
1979             " while running onCompletion hook for ProcessGroupNCCL: ",
1980             e.what(),
1981             ". Aborting all communicators.");
1982 
1983         // No need to call abort() on WorkNCCL here as that collective has
1984         // already finished successfully at this point. We just need to abort
1985         // the process group: abort all NCCL communicators on this
1986         // ProcessGroupNCCL instance.
1987         abort(errorStr);
1988       }
1989     }
1990 
1991     // Lock is still acquired at this point
1992     done = completedWorkList_.empty();
1993   }
1994 }
1995 
1996 std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors() {
1997   return checkForNCCLErrorsInternal(ncclComm_);
1998 }
1999 
2000 std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors(
2001     std::shared_ptr<NCCLComm>& ncclComm) {
2002   return checkForNCCLErrorsInternal(ncclComm);
2003 }
2004 
2005 std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
2006     std::shared_ptr<NCCLComm>& ncclComm) {
2007   // Prioritize commFailureReason over checkForNcclError() result if
2008   // commFailureReason is set.
2009   auto commFailureReason = ncclComm->getNcclCommFailureReason();
2010   if (commFailureReason != std::nullopt) {
2011     return std::make_exception_ptr(C10_BUILD_ERROR(
2012         DistBackendError,
2013         c10::str(
2014             "NCCL communicator encountered error set by ProcessGroupNCCL: ",
2015             *commFailureReason)));
2016   }
2017   ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
2018   // When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING,
2019   // ncclInProgress could be returned when there are pending NCCL calls.
2020   // In this case, no exception should be thrown
2021 #ifdef NCCL_HAS_COMM_NONBLOCKING
2022   // ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined
2023   if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) {
2024 #else
2025   if (ncclAsyncErr != ncclSuccess) {
2026 #endif
2027     return std::make_exception_ptr(C10_BUILD_ERROR(
2028         DistBackendError,
2029         "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" +
2030             getNcclErrorDetailStr(ncclAsyncErr)));
2031   }
2032 
2033   return nullptr;
2034 }
2035 
2036 void ProcessGroupNCCL::broadcastUniqueNCCLID(
2037     ncclUniqueId* ncclID,
2038     bool isSingleP2POp,
2039     const std::string& p2pKey,
2040     int p2pRank) {
2041   // For collective operations:
2042   // For every NCCL communicator that we create we need to broadcast
2043   // a unique ID from rank 0 to all other ranks. This broadcast is
2044   // done by rank 0 setting a key in the store and all other ranks
2045   // retrieving the contents of that key. A single process group
2046   // may create multiple NCCL communicators, so we use a sequence
2047   // number to differentiate between them.
2048   // For single point-to-point operations:
2049   // The sequence number will only be increased on 2 out of all the
2050   // processes in a Process Group. So all following collective
2051   // operations will see different sequence numbers which will cause
2052   // runtime errors. To avoid that, use the src:target pair instead
2053   // of sequence number for p2p communications.
2054 
2055   std::string storeKey;
2056   if (!isSingleP2POp) {
2057     storeKey = std::to_string(ncclCommCounter_++);
2058   } else {
2059     storeKey = p2pKey;
2060   }
2061   if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) {
2062     auto vec = std::vector<uint8_t>(
2063         reinterpret_cast<uint8_t*>(ncclID),
2064         reinterpret_cast<uint8_t*>(ncclID) + NCCL_UNIQUE_ID_BYTES);
2065     store_->set(storeKey, vec);
2066   } else {
2067     try {
2068       auto vec = store_->get(storeKey);
2069       TORCH_CHECK_WITH(
2070           DistBackendError,
2071           vec.size() == NCCL_UNIQUE_ID_BYTES,
2072           "Invalid size for ncclUniqueId");
2073       std::memcpy(ncclID, vec.data(), vec.size());
2074     } catch (const std::exception& e) {
2075       std::string exceptionMsg = c10::str(
2076           "[",
2077           rank_,
2078           "] is setting up NCCL communicator and "
2079           "retrieving ncclUniqueId from [0] via c10d key-value store by key '",
2080           storeKey,
2081           "', but store->get('",
2082           storeKey,
2083           "') got error: ");
2084       C10_THROW_ERROR(
2085           DistBackendError,
2086           exceptionMsg + e.what() +
2087               ". This may indicate a possible application crash on rank 0 or a network set up issue.");
2088     } catch (...) {
2089       C10_THROW_ERROR(
2090           DistBackendError,
2091           c10::str(
2092               "Unknown exception while [",
2093               rank_,
2094               "] is setting up NCCL communicator and "
2095               "retrieving ncclUniqueId from [0] via c10d key-value store by key '",
2096               storeKey,
2097               "'",
2098               ". This may indicate a possible application crash on rank 0 or a network set up issue."));
2099     }
2100   }
2101 }
2102 
2103 void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
2104   std::lock_guard<std::mutex> lock(mutex_);
2105   if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
2106     TORCH_INTERNAL_ASSERT(
2107         false,
2108         "Expected to find key ",
2109         devNCCLCommMapKey,
2110         " in NCCL communicator map.");
2111   }
2112   std::shared_ptr<NCCLComm>& ncclComm = devNCCLCommMap_[devNCCLCommMapKey];
2113   // ncclCommDestroy(comm->getNcclComm()) results in a segfault when the PG is
2114   // being destroyed, so we use ncclCommAbort here.
2115   ncclComm->ncclCommAbort();
2116   // Remove communicators from the cache.
2117   devNCCLCommMap_.erase(devNCCLCommMapKey);
2118   // Clear used device indices.
2119   usedDeviceIdxs_.clear();
2120 
2121   ncclCommDevIdxMapMutex.lock();
2122   ncclCommDevIdxMap.erase(ncclComm);
2123   ncclCommDevIdxMapMutex.unlock();
2124 }
2125 
2126 std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
2127     const std::string& deviceKey,
2128     at::Device& device,
2129     OpType opType,
2130     int p2pRank,
2131     bool isSendRecvSelf) {
2132   // Sanity check
2133   if (deviceKey.empty()) {
2134     C10_THROW_ERROR(
2135         DistBackendError,
2136         "Not able to create/get the NCCL Communicator since "
2137         "the GPU devices are not known");
2138   }
2139   if (bound_device_id_) {
2140     if (*bound_device_id_ != device) {
2141       LOG(ERROR) << logPrefix() << "Tensor found on device " << device
2142                  << " but backend constrained to " << *bound_device_id_;
2143       C10_THROW_ERROR(
2144           DistBackendError,
2145           "Attempt to perform collective on tensor not on device passed to init_process_group");
2146     }
2147   }
2148 
2149   usedDeviceIdxs_.insert(device.index());
2150 
2151   {
2152     std::lock_guard<std::mutex> lock(mutex_);
2153     if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) {
2154       // Reuse the cached communicator if there is one.
2155       return devNCCLCommMap_[deviceKey];
2156     }
2157   }
2158 
2159   // NCCL communicator not cached, create a new entry
2160   std::shared_ptr<NCCLComm> ncclComm;
2161 
2162   // Create the unique NCCL ID and broadcast it
2163   ncclUniqueId ncclID;
2164 
2165   // reset log prefix to include group_desc
2166   logPrefix_ = createLogPrefix();
2167 
2168 #ifdef NCCL_COMM_DESCRIPTION
2169   // Pass process group name and description to NCCL communicator
2170   std::string commDesc = pg_desc_ + ':' + pg_uid_;
2171   options_->config.commDesc = strdup(commDesc.c_str());
2172 #endif
2173 
2174   // For batch_isend_irecv, ncclGroupStart() would be called upfront
2175   bool batchP2P = ncclActiveGroupCounter_ > 0;
2176   bool singleP2POp = isP2POp(opType, batchP2P);
2177 
2178   // Get the device index
2179   auto deviceIndex = device.index();
2180   at::cuda::OptionalCUDAGuard gpuGuard(device);
2181 
2182   // [Group Start/End Note] This is used to ensure that nccl communicator will
2183   // be created before communication primitives are called. Let's look at this
2184   // example: Using the batch_isend_irecv to send a tensor to a target process.
2185   // On the sender side, the corresponding underlying NCCL calls will look like
2186   //   ncclGroupStart() // This is in batch_isend_irecv
2187   //   ncclCommInitRank() // Inside NCCLComm::create
2188   //   ncclSend()
2189   //   ncclGroupEnd() // This is in batch_isend_irecv
2190   // With this pattern, the nccl communicator will be created in the last
2191   // ncclGroupEnd which means when ncclSend is processed, the passed
2192   // communicator argument is NULL, which will lead to a runtime error. So we need
2193   // to "close" all active nccl groups to ensure nccl communicator is actually
2194   // created before encountering any communication calls. This is why we need
2195   // the following for loop.
2196   for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
2197     (void)i;
2198     // comms have not been initialized yet, so we can only check in a blocking way
2199     C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
2200   }
2201 
2202   // GPU world size and GPU rank
2203   int numRanks, rank;
2204 
2205   if (!singleP2POp) {
2206     // Collective, all-to-all, or batch P2P
2207     numRanks = getSize();
2208     rank = getRank();
2209   } else if (isSendRecvSelf) {
2210     // Same process send and recv.
2211     numRanks = 1;
2212     rank = 0;
2213   } else {
2214     // For single point-to-point operation, there are only 2 processes
2215     // involved so the GPU rank is either 0 or 1.
2216     numRanks = 2;
2217     rank = p2pRank;
2218   }
2219 
2220 #ifdef NCCL_HAS_COMM_SPLIT
2221   if (options_->split_from) {
2222     TORCH_CHECK(
2223         options_->split_color != 0,
2224         "Must specify a non-zero color when splitting");
2225     // Find a valid, healthy communicator to split from if possible.
2226     std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
2227     auto& other_comms = options_->split_from->devNCCLCommMap_;
2228     auto dit = other_comms.find(getKeyFromDevice(device));
2229     if (dit != other_comms.end()) {
2230       auto& parentComm = dit->second;
2231       if (parentComm != nullptr && !parentComm->isAborted()) {
2232         ncclComm = NCCLComm::split(
2233             parentComm.get(),
2234             options_->split_color,
2235             rank,
2236             options_->config,
2237             options_->global_ranks_in_group);
2238       }
2239     }
2240   }
2241 #endif
2242 
2243   // To simplify conditional nesting, just create the ncclComms[i]
2244   // entry if it hasn't been created yet, rather than untangling the
2245   // conditions that might have resulted in a split above.
2246   if (!ncclComm) {
2247     if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) {
2248       // For point-to-point communication, the lower rank of the two will get
2249       // the unique id.
2250       if (rank_ == 0 || (singleP2POp && p2pRank == 0)) {
2251         C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt);
2252       }
2253 
2254       // Broadcast so that each process can have a unique NCCL ID
2255       auto timeStarted = std::chrono::steady_clock::now();
2256       broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank);
2257       auto timerDeltaMs =
2258           std::chrono::duration_cast<std::chrono::duration<double>>(
2259               std::chrono::steady_clock::now() - timeStarted)
2260               .count() *
2261           1000;
2262       LOG(INFO) << logPrefix()
2263                 << "ProcessGroupNCCL broadcast unique ID through store took "
2264                 << timerDeltaMs << " ms";
2265     }
2266 
2267 #ifdef NCCL_HAS_COMM_NONBLOCKING
2268     ncclComm = NCCLComm::create(numRanks, rank, ncclID, options_->config);
2269 #else
2270     ncclComm = NCCLComm::create(numRanks, rank, ncclID);
2271 #endif
2272   }
2273 
2274   // Creates the NCCL streams
2275   bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
2276   auto streamVal = at::cuda::getStreamFromPool(
2277       options_->is_high_priority_stream || force_high);
2278 
2279   {
2280     std::lock_guard<std::mutex> lock(mutex_);
2281     inInitializationCommMap_.emplace(deviceKey, ncclComm);
2282   }
2283 
2284   NCCLTraceBuffer::get()->record_pg_ranks(
2285       std::make_tuple(pg_uid_, pg_desc_), groupRanks());
2286 
2287   RECORD_PARAM_COMMS(
2288       0, // seq
2289       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
2290       rank, // rank
2291       "init", // collective name
2292       0, // inNelems
2293       0, // outNelems
2294       at::kByte, // dType
2295       std::vector<int64_t>(), // inSplitSizes
2296       std::vector<int64_t>(), // outSplitSizes
2297       globalRankStart, // globalRankStart
2298       globalRankStride, // globalRankStride
2299       size_); // worldSize
2300 
2301   LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ "
2302             << ncclComm->ncclComm_ << " on CUDA device: " << deviceIndex;
2303 
2304   // At this point NCCL should have been initialized, hence we can accurately
2305   // get the env value even if NCCL sets it by reading from nccl.conf file
2306   LOG(INFO) << logPrefix()
2307             << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A");
2308 
2309   // See [Group Start/End Note]
2310   for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
2311     (void)i;
2312     C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt);
2313   }
2314 
2315   ncclStreams_.emplace(deviceKey, std::move(streamVal));
2316 
2317   // Note: these events are created with the (default) cudaEventDisableTiming
2318   // flag. This flag provides the best performance when used with
2319   // cudaStreamWaitEvent() and cudaEventQuery(). Since we don't measure the
2320   // performance using cudaEvents here, this flag should be set.
2321   // TODO(kwen2501): is ncclEvents_ used anywhere else?
2322   ncclEvents_.emplace(deviceKey, at::cuda::CUDAEvent(cudaEventDisableTiming));
2323 
2324   // Move the NCCL resource to cache
2325   auto it = inInitializationCommMap_.find(deviceKey);
2326   // A previous thread could've already removed deviceKey from
2327   // inInitializationCommMap_ and added it to devNCCLCommMap_
2328   if (it != inInitializationCommMap_.end()) {
2329     devNCCLCommMap_.emplace(deviceKey, std::move(it->second));
2330     inInitializationCommMap_.erase(deviceKey);
2331 
2332     // Now ncclComms are fully initialized.
2333     // Register all active CUDA memory segments in cache allocator to
2334     // the new NCCL communicators
2335     if (useTensorRegisterAllocatorHook_) {
2336       auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
2337       // Register the segment to a new NCCL communicator if on the same device
2338       for (const auto& segmentInfo : snapshot.segments) {
2339         TORCH_INTERNAL_ASSERT(
2340             segmentInfo.device == device.index(),
2341             "Mismatch between CUDA memory segment device and current device");
2342         ncclComm->registerSegment(
2343             reinterpret_cast<void*>(segmentInfo.address),
2344             segmentInfo.total_size);
2345       }
2346     }
2347     // Record the mapping between ncclComm and device index so that later
2348     // register hook can register a newly allocated segment to communicators
2349     // on the same device.
2350     // NOTE: we need to remove the communicator from this map when it is
2351     // destroyed; otherwise we may register onto an invalid communicator.
2352     ncclCommDevIdxMapMutex.lock();
2353     ncclCommDevIdxMap.emplace(ncclComm, device.index());
2354     ncclCommDevIdxMapMutex.unlock();
2355   }
2356 
2357   it = devNCCLCommMap_.find(deviceKey);
2358   TORCH_INTERNAL_ASSERT(
2359       it != devNCCLCommMap_.end(), "Communicators not populated in cache!");
2360 
2361   return it->second;
2362 }
2363 
2364 uint64_t ProcessGroupNCCL::getCommSplitCounter() const {
2365   uint64_t ret = 0;
2366   for (const auto& i : devNCCLCommMap_) {
2367     auto& ncclComm = i.second;
2368     ret += ncclComm->getCommSplitCounter();
2369   }
2370   return ret;
2371 }
2372 
2373 namespace {
2374 
2375 // Check validity of tensor
2376 void check_gpu_single_tensor(
2377     const at::Tensor& tensor,
2378     const bool p2p = false // whether operation is a P2P operation
2379 ) {
2380   if (!tensor.is_cuda() || tensor.is_sparse()) {
2381     C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense");
2382   }
2383   // Skip the following requirements for P2P operations
2384   if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
2385     if (p2p) {
2386       TORCH_WARN_ONCE(
2387           "Detected non-contiguous tensor in P2P operations. It is the user's "
2388           "responsibility to guarantee that source and destination tensors have "
2389           "the same contiguity format.");
2390     } else {
2391       C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
2392     }
2393   }
2394 }
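// Example (illustrative): a tensor that trips the contiguity check above. A
// column slice of a CUDA tensor is not contiguous in its suggested memory
// format, so a collective rejects it with ValueError while a P2P op only
// emits the one-time warning.
//
//   auto t = torch::ones({4, 4}, torch::kCUDA).slice(/*dim=*/1, 0, 2);
//   check_gpu_single_tensor(t);                // throws ValueError
//   check_gpu_single_tensor(t, /*p2p=*/true);  // warns once and returns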
2395 
2396 // Checks that all `tensors' have the same type and shape and reside on the same
2397 // GPU.
2398 // TODO: test_c10d_nccl.py should consider adding tests for the error conditions
2399 // here, ie, that deliberately pass invalid tensors and check the right
2400 // exception is thrown. The "Expected list of tensors on the same device"
2401 // condition may be a challenge because the test would need to pass tensors on
2402 // different devices in the same process.
2403 int64_t check_gpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
2404   if (tensors.size() == 0) {
2405     C10_THROW_ERROR(ValueError, "Tensor list must be nonempty");
2406   }
2407 
2408   const auto& first = tensors.front();
2409 
2410   int64_t total_numel = 0;
2411   for (const auto& t : tensors) {
2412     if (!t.is_cuda() || t.is_sparse()) {
2413       C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense");
2414     }
2415     if (t.scalar_type() != first.scalar_type()) {
2416       C10_THROW_ERROR(TypeError, "Tensors must have identical type");
2417     }
2418     if (!t.is_non_overlapping_and_dense()) {
2419       C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense");
2420     }
2421     // If we're in this function, the user called a _coalesced collective
2422     // on a set of tensors with potentially different sizes and strides.
2423     // Therefore, we don't check for matching sizes and strides,
2424     // but we do double-check tensors are on the same device.
2425     TORCH_CHECK_WITH(
2426         ValueError,
2427         t.get_device() == tensors[0].get_device(),
2428         "Expected list of tensors on the same device");
2429     total_numel += t.numel();
2430   }
2431 
2432   return total_numel;
2433 }
2434 
2435 bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
2436   for (const auto& input_tensor : input_tensors) {
2437     if (!input_tensors[0].is_same_size(input_tensor)) {
2438       return false;
2439     }
2440   }
2441   return true;
2442 }
2443 
2444 } // namespace
2445 
2446 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
2447     at::Device& device,
2448     int rank,
2449     OpType opType,
2450     const char* profilingTitle,
2451     const std::vector<at::Tensor>& inputs,
2452     const std::vector<at::Tensor>& outputs, // TODO(kwen2501): necessary?
2453     bool record) {
2454   auto r = c10::make_intrusive<ProcessGroupNCCL::WorkNCCL>(
2455       pg_uid_,
2456       pg_desc_,
2457       device,
2458       rank,
2459       opType,
2460       seqCollective_,
2461       profilingTitle,
2462       profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
2463                                 : std::nullopt,
2464       desyncDebug_,
2465       enableTiming_.load(),
2466       cudaEventCacheEnabled_.load(),
2467       dist_debug_level_);
2468   if (record) {
2469     bool isP2P = isP2POp(opType);
2470     // Ideally record every work that we enqueue, rather than every work we
2471     // create.
2472     // - at the time of this PR we do not currently enqueue every created work
2473     // - but it is unsafe to steal refs to start/end cuda events from Works that
2474     //   may go out of scope before flight recorder has retired them,
2475     //   so we must ensure that any work that is initialized via initWork will
2476     //   be enqueued
2477     // - initially, moved record() into workEnqueue(), but found that makes it
2478     //   hard to get access to profilingTitle,
2479     //   inputs, and outputs for metadata recording, and we don't want to attach
2480     //   these objects to the Work because it has implications for keeping those
2481     //   tensors alive longer and adds overhead when copying Work objects
2482     //   between threads
2483     r->trace_id_ = NCCLTraceBuffer::get()->record(
2484         local_id_,
2485         std::make_tuple(pg_uid_, pg_desc_),
2486         seqCollective_,
2487         seqP2P_,
2488         op_id_,
2489         profilingTitle ? profilingTitle : "",
2490         inputs,
2491         outputs,
2492         r->ncclStartEvent_.get(),
2493         r->ncclEndEvent_.get(),
2494         options_->timeout,
2495         pgStatus_,
2496         isP2P);
2497   }
2498   return r;
2499 }
2500 
2501 // TODO(kwen2501): deprecate
2502 std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() {
2503   return *outputs_;
2504 }
2505 
2506 c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
2507     getFuture() {
2508   return future_;
2509 }
2510 
2511 float ProcessGroupNCCL::WorkNCCL::getDuration() const {
2512   TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
2513   TORCH_CHECK(
2514       ncclStartEvent_,
2515       "getDuration only works if ncclStartEvents_ is populated, true if timing enabled");
2516   TORCH_CHECK(
2517       ncclEndEvent_,
2518       "getDuration only works if ncclEndEvents_ is populated, which should always be true");
2519   return ncclStartEvent_->elapsed_time(*ncclEndEvent_);
2520 }
2521 
2522 uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
2523   return seq_;
2524 }
2525 
2526 void ProcessGroupNCCL::assignTimeoutToWork(
2527     const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
2528     const c10::intrusive_ptr<ProcessGroupNCCL::Options>& option) {
2529   std::chrono::milliseconds timeout = option->timeout;
2530   std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
2531   if (ephemeralTimeoutActive_.count() > 0) {
2532     timeout += ephemeralTimeoutActive_;
2533   }
2534   work->opTimeout_ = timeout;
2535   work->ownedEphermeralTimeout_ =
2536       ephemeralTimeoutActive_ - ephemeralTimeoutInflight_;
2537   ephemeralTimeoutInflight_ = ephemeralTimeoutActive_;
2538 }
2539 
2540 void ProcessGroupNCCL::workEnqueue(
2541     c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) {
2542   if (!terminateProcessGroup_.load()) {
2543     std::lock_guard<std::mutex> lock(workMetaListMutex_);
2544     // Avoid view tensors being processed in the cleanup thread.
2545     // A view tensor's destruction invokes autograd_meta, which
2546     // needs to be destructed in the user thread; otherwise we
2547     // will deadlock. Here we enqueue the work without outputs_.
2548     workMetaList_.emplace_back(*work);
2549     // update the PG status related to the last enqueued work
2550     pgStatus_->lastEnqueuedSeq = work->seq_;
2551     pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_);
2552     pgStatus_->lastEnqueuedNumelIn = work->numelIn_;
2553     pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
2554     lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
2555   }
2556 }
2557 
2558 ProcessGroupNCCL::Options::Options(bool is_high_priority_stream)
2559     : Backend::Options(NCCL_BACKEND_NAME, kProcessGroupNCCLDefaultTimeout),
2560       is_high_priority_stream(is_high_priority_stream) {}
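// Construction sketch (hypothetical caller code; member names as used in this
// file): the high-priority-stream flag is the constructor argument, while the
// timeout lives on the base Backend::Options.
//
//   auto opts = c10::make_intrusive<ProcessGroupNCCL::Options>(
//       /*is_high_priority_stream=*/true);
//   opts->timeout = std::chrono::milliseconds(30 * 60 * 1000);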
2561 
2562 static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
2563 
2564 void ProcessGroupNCCL::startCoalescing() {
2565   // Other collective ops bump seq_ before creating a work. Thus, if coalesced
2566   // ops bump seq_ only after initing a work they will collide with (reuse) the
2567   // seq_ of the last non-coalesced collective.  Previously, seq_ was bumped
2568   // inside endCoalescing, but before initWork. Since we now record individual
2569   // ops from a coalesce group into the flight recorder, we want to have the
2570   // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during
2571   // start, which has one minor downside: we burn a seq_ if someone ever does a
2572   // 'start' and 'end' coalescing region without doing an operation in between.
2573 
2574   // Don't bump op_id_ here, because startCoalescing isn't a logical operation.
2575   // Bump it for each logical op inside the coalescing group.
2576   if (coalescing_state_ & CoalP2P) {
2577     seqP2P_++;
2578   } else {
2579     seqCollective_++;
2580   }
2581 
2582   coalescedDevice_.set_index(-1);
2583   coalescedComm_ = nullptr;
2584   coalescing_state_ |= CoalActive;
2585   groupStart();
2586 }
2587 
2588 // `optype` is for specifying a composite optype, such as ALLGATHER and
2589 // REDUCE_SCATTER
2590 c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
2591   if (coalescedComm_ == nullptr) {
2592     // There is no actual work being coalesced, return here
2593     groupEnd();
2594     coalescing_state_ = 0;
2595     return nullptr;
2596   }
2597   TORCH_CHECK(
2598       coalescedDevice_.index() >= 0,
2599       "Something went wrong. Did you call end_coalescing before start_coalescing?");
2600 
2601   // `coalescedComm_` should have same set of comms across collectives
2602   auto comm = coalescedComm_;
2603   // `coalescedDevice_` should have same set of devices across collectives
2604   auto device = coalescedDevice_;
2605 
2606   // `getKeyFromDevice` is how we get keys for both collectives and batch P2P
2607   const auto key = getKeyFromDevice(device);
2608   auto ncclStream = ncclStreams_.at(key);
2609 
2610   // Create Work object
2611   c10::cuda::CaptureStatus capture_status =
2612       c10::cuda::currentStreamCaptureStatusMayInitCtx();
2613   bool enqueue =
2614       (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None;
2615   auto work =
2616       initWork(device, rank_, optype, "nccl:coalesced", {}, {}, enqueue);
2617   work->ncclComm_ = comm;
2618   work->blockingWait_ = blockingWait_;
2619   work->avoidRecordStreams_ = avoidRecordStreams_;
2620   work->store_ = store_;
2621   assignTimeoutToWork(work, options_);
2622 
2623   // Record start before ncclGroupEnd
2624   if (work->timingEnabled_) {
2625     work->ncclStartEvent_->record(ncclStream);
2626   }
2627 
2628   if (nccl_use_nonblocking()) {
2629     groupEndNonblocking(comm);
2630   } else {
2631     groupEnd();
2632   }
2633 
2634   // Record end after ncclGroupEnd
2635   // TODO(eqy): is this still necessary if avoidRecordStreams_ is set?
2636   work->ncclEndEvent_->record(ncclStream);
2637 
2638   if (avoidRecordStreams_) {
2639     // other functions expect an initialized ptr if avoidRecordStreams_ is set
2640     work->stashed_for_allocator_safety_ =
2641         std::make_shared<std::vector<at::Tensor>>();
2642   }
2643 
2644   // Notify graphs before we check the capture status preemptively
2645   at::cuda::CUDAGraph::inc_pending_event_queries();
2646 
2647   if (enqueue) {
2648     workEnqueue(work);
2649   } else {
2650     at::cuda::CUDAGraph::dec_pending_event_queries();
2651   }
2652 
2653   coalescing_state_ = 0;
2654   coalescedComm_ = nullptr;
2655   return work;
2656 }
2657 
2658 c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
2659   // Default OpType to COALESCED if not specified
2660   return endCoalescing(OpType::COALESCED);
2661 }
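// A minimal caller-side sketch of the coalescing API above (hypothetical
// tensors t0/t1 on the same CUDA device and an already-constructed process
// group `pg`; not code from this file). Every collective issued inside the
// region shares one seq_ and is folded into the single Work returned by
// endCoalescing():
//
//   pg.startCoalescing();
//   std::vector<at::Tensor> v0{t0}, v1{t1};
//   pg.allreduce(v0);
//   pg.allreduce(v1);
//   auto work = pg.endCoalescing();  // defaults to OpType::COALESCED
//   if (work) {
//     work->wait();
//   }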
2662 
2663 template <typename Fn, typename PreProcess, typename PostProcess>
2664 c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
2665     std::vector<at::Tensor>& inputs,
2666     std::vector<at::Tensor>& outputs,
2667     Fn fn,
2668     PreProcess pre,
2669     PostProcess post,
2670     OpType opType,
2671     const char* profilingTitle,
2672     bool avoidRecordStreams,
2673     bool nanCheck) {
2674   // Environment setting by the user may add onto collective call's option
2675   avoidRecordStreams |= avoidRecordStreams_;
2676   nanCheck &= enableNanCheck_;
2677 
2678   c10::cuda::CaptureStatus capture_status =
2679       c10::cuda::currentStreamCaptureStatusMayInitCtx();
2680   errorIfCapturingNonCapturableNCCL(capture_status);
2681 
2682   // Bump collective counter
2683   seqCollective_++;
2684   op_id_++;
2685 
2686   auto device = getDevice(inputs[0]);
2687   const auto key = getKeyFromDevice(device);
2688   auto ncclComm = getNCCLComm(key, device, opType);
2689 
2690   if (coalescing_state_ & CoalActive) {
2691     coalescing_state_ |= CoalColl;
2692     if (coalescedDevice_.index() < 0) {
2693       coalescedDevice_ = device;
2694     } else {
2695       TORCH_CHECK(
2696           coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
2697     }
2698     if (coalescedComm_ == nullptr) {
2699       coalescedComm_ = ncclComm;
2700     } else {
2701       TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
2702     }
2703   }
2704 
2705   // Used many times below, so we stash the unordered_map lookup
2706   auto ncclStream = ncclStreams_.at(key);
2707 
2708   // First let NCCL streams wait for input tensors allocation streams
2709   syncStream(device, ncclEvents_[key], ncclStream);
2710 
2711   bool enqueue =
2712       !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
2713   auto work =
2714       initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue);
2715 
2716   // Store references to outputs to be used by WorkNCCL::result and operator<<.
2717   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
2718 
2719   if (avoidRecordStreams) {
2720     work->stashed_for_allocator_safety_ =
2721         std::make_shared<std::vector<at::Tensor>>(inputs);
2722   }
2723 
2724   at::cuda::OptionalCUDAGuard gpuGuard(device);
2725 
2726   if (nanCheck) {
2727     for (const auto& input : inputs) {
2728       checkForNan(input, ncclStream);
2729     }
2730   }
2731 
2732   // Start event should only be recorded before the ncclGroupStart()
2733   if (work->timingEnabled_) {
2734     work->ncclStartEvent_->record(ncclStream);
2735   }
2736 
2737   pre(ncclStream, work);
2738 
2739   ncclComm_t comm = ncclComm->getNcclComm();
2740 
2741   // Both `inputs' and `outputs' are created on a worker stream and used in
2742   // different ncclStreams.  Hence, both must record the ncclStream to
2743   // prevent being freed before the collective finishes.
2744   //
2745   // We only record `inputs' here, and leave recording `outputs' to `fn' for
2746   // operations where `inputs' and `outputs' are not the same.
2747   //
2748   // See [Sync Streams].
2749   if (!avoidRecordStreams) {
2750     for (const auto& input : inputs) {
2751       if (!input.is_sparse()) {
2752         c10::cuda::CUDACachingAllocator::recordStream(
2753             input.storage().data_ptr(), ncclStream);
2754       } else {
2755         // for sparse input case record streams on both index and value
2756         // tensors
2757         c10::cuda::CUDACachingAllocator::recordStream(
2758             input.values().storage().data_ptr(), ncclStream);
2759         c10::cuda::CUDACachingAllocator::recordStream(
2760             input.indices().storage().data_ptr(), ncclStream);
2761       }
2762     }
2763   }
2764 
2765 // Not all collectives have the same signature; e.g., all-reduce takes a single
2766 // Tensor as input and output, while all-to-all takes a vector of Tensors as
2767 // input and output. Because we define the signature of `fn` to take only a
2768 // single tensor as input and output, we resort to a hack: take the first
2769 // element of the vector and pass it to fn.
2770 // TODO: we should clean this up in the future (by either removing the lambdas
2771 // entirely or removing input and output from the lambda signature).
2772 #ifndef NCCL_HAS_COMM_NONBLOCKING
2773   C10D_NCCL_CHECK(
2774       fn(inputs[0], outputs[0], comm, ncclStream),
2775       ncclComm->getNcclCommFailureReason());
2776 #else
2777   C10D_NCCL_CHECK_TIMEOUT(
2778       fn(inputs[0], outputs[0], comm, ncclStream),
2779       comm,
2780       ncclComm->getNcclCommFailureReason());
2781 #endif
2782 
2783   post(ncclStream, work);
2784 
2785   // End event should only be recorded after the ncclGroupEnd()
2786   if (!coalescing_state_) {
2787     work->ncclEndEvent_->record(ncclStream);
2788   }
2789   work->ncclComm_ = ncclComm;
2790 
2791   {
2792     c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
2793     std::vector<at::Device> devices{device};
2794     work->future_ = c10::make_intrusive<at::ivalue::Future>(
2795         c10::ListType::create(c10::TensorType::get()), devices);
2796 
2797     // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
2798     // future blocks the stream this callback runs on behind the corresponding
2799     // ncclEndEvents_, ensuring appropriate synchronization.
2800     if (work->recordFunctionEndCallback_) {
2801       work->future_->addCallback(
2802           [work](at::ivalue::Future& /* unused */) {
2803             work->recordFunctionEndCallback_();
2804           },
2805           // uses_future = false allows us to skip synchronization in
2806           // ivalue::Future, but is only valid as long as the lambda doesn't use
2807           // the "Future" argument.
2808           /*uses_future=*/false);
2809     }
2810     work->future_->markCompleted(at::IValue(*work->outputs_));
2811   }
2812 
2813   // Set appropriate work parameters.
2814   work->blockingWait_ = blockingWait_;
2815   work->avoidRecordStreams_ = avoidRecordStreams;
2816   work->store_ = store_;
2817   assignTimeoutToWork(work, options_);
2818   // Record size info for debug. We only record the size on the first device as
2819   // multi-device per process is deprecated
2820   work->numelIn_ = 0;
2821   work->numelOut_ = 0;
2822   for (const auto& input : inputs) {
2823     work->numelIn_ += input.numel();
2824   }
2825   for (const auto& output : outputs) {
2826     work->numelOut_ += output.numel();
2827   }
2828 
2829   // Notify graphs before we check the capture status preemptively
2830   at::cuda::CUDAGraph::inc_pending_event_queries();
2831   if (enqueue) {
2832     workEnqueue(work);
2833   } else {
2834     at::cuda::CUDAGraph::dec_pending_event_queries();
2835   }
2836 
2837   return work;
2838 }
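// The `fn` callable handed to collective() above is invoked as
// fn(inputs[0], outputs[0], comm, ncclStream) and is expected to return an
// ncclResult_t. A sketch of its shape, mirroring the concrete collectives
// below (a sum all-reduce is used here purely for illustration):
//
//   auto fn = [&](at::Tensor& input,
//                 at::Tensor& output,
//                 ncclComm_t comm,
//                 at::cuda::CUDAStream& stream) {
//     return ncclAllReduce(
//         input.data_ptr(),
//         output.data_ptr(),
//         input.numel(),
//         getNcclDataType(input.scalar_type()),
//         ncclSum,
//         comm,
//         stream.stream());
//   };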
2839 
2840 template <typename Fn>
2841 c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
2842     std::vector<at::Tensor>& inputs,
2843     std::vector<at::Tensor>& outputs,
2844     Fn fn,
2845     OpType opType,
2846     const char* profilingTitle,
2847     bool avoidRecordStreams) {
2848   // Environment setting by the user may add onto collective call's option
2849   avoidRecordStreams |= avoidRecordStreams_;
2850   c10::cuda::CaptureStatus capture_status =
2851       c10::cuda::currentStreamCaptureStatusMayInitCtx();
2852   errorIfCapturingNonCapturableNCCL(capture_status);
2853 
2854   // Bump collective counter
2855   seqCollective_++;
2856 
2857   // For coalescingManager collectives, there is no individual c++ call per
2858   // collective so there is no flight record and we increment seq*_ and op_id_
2859   // together. Compare this to the startCoalescing/endCoalescing flow, where we
2860   // increment seq_ once per group and increment op_id_ once per individual
2861   // operation within the group.
2862   op_id_++;
2863 
2864   // Currently, the API permits one scenario where inputs.size() and
2865   // outputs.size() are > 0.
2866   // 1. If the call was a _coalesced call, all inputs must be on the same
2867   // device.
2868   //    The group of nccl calls applies the collective separately to each input,
2869   //    but the group as a whole should be efficient, and might even execute as
2870   //    a single fused kernel.
2871   auto device = getDevice(inputs[0]);
2872   const auto key = getKeyFromDevice(device);
2873   auto ncclComm = getNCCLComm(key, device, opType);
2874 
2875   if (coalescing_state_ & CoalActive) {
2876     coalescing_state_ |= CoalColl;
2877     if (coalescedDevice_.index() < 0) {
2878       coalescedDevice_ = device;
2879     } else {
2880       TORCH_CHECK(
2881           coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
2882     }
2883     if (coalescedComm_ == nullptr) {
2884       coalescedComm_ = ncclComm;
2885     } else {
2886       TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
2887     }
2888   }
2889 
2890   // Used many times below, so we stash the unordered_map lookup
2891   auto ncclStream = ncclStreams_.at(key);
2892 
2893   // First let NCCL streams wait for input tensors allocation streams
2894   syncStream(device, ncclEvents_[key], ncclStream);
2895 
2896   auto work = initWork(
2897       device, rank_, opType, profilingTitle, inputs, outputs, /*record=*/true);
2898 
2899   // Store references to outputs to be used by WorkNCCL::result and operator<<.
2900   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
2901 
2902   if (avoidRecordStreams) {
2903     work->stashed_for_allocator_safety_ =
2904         std::make_shared<std::vector<at::Tensor>>(inputs);
2905   }
2906 
2907   at::cuda::OptionalCUDAGuard gpuGuard(device);
2908 
2909   // Start event should only be recorded before the ncclGroupStart() (which
2910   // happens inside AutoNcclGroup guard below)
2911   if (work->timingEnabled_) {
2912     work->ncclStartEvent_->record(ncclStream);
2913   }
2914 
2915   ncclComm_t comm = ncclComm->getNcclComm();
2916 
2917 // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL
2918 // upgrade. Once a NCCL version is qualified, this code should not be needed at
2919 // runtime.
2920 #ifdef PGNCCL_ENABLE_HASH
2921   if (enableCollecticeHashDebug_.load()) {
2922     auto numel = getTensorsNumel(inputs);
2923     auto hashValue = hashTensors(inputs);
2924     PRINT_COLLECTIVE_HASH_SIGNATURE(
2925         "input", opTypeToString(opType), numel, hashValue);
2926   }
2927 #endif
2928 
2929   {
2930     torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
2931         comm, nccl_use_nonblocking());
2932     for (const auto i : c10::irange(inputs.size())) {
2933       // Both `inputs' and `outputs' are created on a worker stream and used in
2934       // different ncclStreams.  Hence, both must record the ncclStream to
2935       // prevent being freed before the collective finishes.
2936       //
2937       // We only record `inputs' here, and leave recording `outputs' to `fn' for
2938       // operations where `inputs' and `outputs' are not the same.
2939       //
2940       // See [Sync Streams].
2941       if (!avoidRecordStreams) {
2942         if (!inputs[i].is_sparse()) {
2943           c10::cuda::CUDACachingAllocator::recordStream(
2944               inputs[i].storage().data_ptr(), ncclStream);
2945         } else {
2946           // for sparse input case record streams on both index and value
2947           // tensors
2948           c10::cuda::CUDACachingAllocator::recordStream(
2949               inputs[i].values().storage().data_ptr(), ncclStream);
2950           c10::cuda::CUDACachingAllocator::recordStream(
2951               inputs[i].indices().storage().data_ptr(), ncclStream);
2952         }
2953       }
2954 #ifndef NCCL_HAS_COMM_NONBLOCKING
2955       C10D_NCCL_CHECK(
2956           fn(inputs[i], outputs[i], comm, ncclStream),
2957           ncclComm->getNcclCommFailureReason());
2958 #else
2959       C10D_NCCL_CHECK_TIMEOUT(
2960           fn(inputs[i], outputs[i], comm, ncclStream),
2961           comm,
2962           ncclComm->getNcclCommFailureReason());
2963 #endif
2964     }
2965   }
2966 
2967   work->ncclEndEvent_->record(ncclStream);
2968   work->ncclComm_ = ncclComm;
2969 
2970   {
2971     c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
2972     std::vector<at::Device> devices{device};
2973     work->future_ = c10::make_intrusive<at::ivalue::Future>(
2974         c10::ListType::create(c10::TensorType::get()), devices);
2975 
2976     // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
2977     // future blocks the stream this callback runs on behind the corresponding
2978     // ncclEndEvents_, ensuring appropriate synchronization.
2979     if (work->recordFunctionEndCallback_) {
2980       work->future_->addCallback(
2981           [work](at::ivalue::Future& /* unused */) {
2982             work->recordFunctionEndCallback_();
2983           },
2984           // uses_future = false allows us to skip synchronization in
2985           // ivalue::Future, but is only valid as long as the lambda doesn't use
2986           // the "Future" argument.
2987           /*uses_future=*/false);
2988     }
2989     work->future_->markCompleted(at::IValue(*work->outputs_));
2990   }
2991 
2992   // Set appropriate work parameters.
2993   work->blockingWait_ = blockingWait_;
2994   work->avoidRecordStreams_ = avoidRecordStreams;
2995   work->store_ = store_;
2996   assignTimeoutToWork(work, options_);
2997   // Record size info for debug. We only record the size on the first device as
2998   // multi-device per process is deprecated
2999   work->numelIn_ = inputs[0].numel();
3000   work->numelOut_ = outputs[0].numel();
3001 
3002   /* Note [cuda graph capture and workEnqueue]
3003 
3004   Normal behavior of the C10D watchdog is to query cuda events on work objects
3005   periodically, but when cuda graph recording is active these event queries
3006   would crash or mess up the recording.
3007 
3008   To ensure we do not enqueue a work object to the watchdog when cuda graph
3009   capture is active, we use a one-way sync. We increment a flag pre-emptively,
3010   indicating our intent to enqueue a work object. Then we check capture_status
3011   to see whether (a) capturing is already in progress (we cannot enqueue in this
3012   case), or (b) capturing hasn't started yet, in which case we can trust that no
3013   capture will start (since a precondition of starting a capture is that the
3014   event query count is 0).
3015 
3016   If we are not able to enqueue the work due to capture-in-progress, we finally
3017   decrement the counter.
3018 
3019   For this reason we cannot easily move the increment inside workEnqueue unless
3020   we also change the semantic of workEnqueue to 'maybeWorkEnqueue'.
3021 
3022   TODO:
3023    - Is our design for the flight recorder safe in this context? Are we
3024   recording any FR events during cudagraph capture? If so, they won't be safe
3025   to poll for completion status.
3026   */
3027   at::cuda::CUDAGraph::inc_pending_event_queries();
3028   if (capture_status == c10::cuda::CaptureStatus::None) {
3029     workEnqueue(work);
3030   } else {
3031     at::cuda::CUDAGraph::dec_pending_event_queries();
3032   }
3033   // TODO(whc) if the work isn't enqueued, I don't feel great about returning
3034   // it, since interactions with it by usercode won't behave normally - they
3035   // won't observe work completion, for instance.  Will this lead to silent
3036   // problems during capture?
3037   return work;
3038 }
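// Caller-side sketch (illustrative, assuming two hypothetical tensors t0/t1 on
// the same CUDA device): collectiveCoalesced() applies `fn` once per
// inputs[i]/outputs[i] pair inside a single AutoNcclGroup, so the group is
// issued to NCCL as one batch. allreduce_coalesced() below is a real call
// site that follows this pattern:
//
//   std::vector<at::Tensor> ts{t0, t1};
//   pg.allreduce_coalesced(ts);  // internally -> collectiveCoalesced(...)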
3039 
3040 template <typename Fn, typename PreProcess, typename PostProcess>
3041 c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
3042     at::Tensor& tensor,
3043     Fn fn,
3044     int peer,
3045     OpType opType,
3046     PreProcess pre,
3047     PostProcess post,
3048     const char* profilingTitle) {
3049   // avoidRecordStreams_ note:
3050   // send, recv, and irecv should be fine with avoidRecordStreams.
3051   // However, for isend, I don't think the API requires the user
3052   // to wait() on the returned handle, so ProcessGroupNCCL can't know
3053   // when it's safe to release the input back to the allocator,
3054   // and the present call has no way to know it's not an isend.
3055   // Therefore, we warn and fall back to the typical recordStream logic:
3056   if (avoidRecordStreams_) {
3057     TORCH_WARN_ONCE(
3058         "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point "
3059         "collectives.");
3060   }
3061 
3062   auto device = getDevice(tensor);
3063   std::string key;
3064   int p2pRank = 0, p2pTargetRank = 0;
3065   bool isSendRecvSelf = false;
3066   // For batch_isend_irecv, ncclGroupStart() would be called upfront
3067   bool batchP2P = ncclActiveGroupCounter_ > 0;
3068   if (batchP2P) {
3069     // For batch P2P, we need to treat it like a collective when selecting
3070     // communicator, because ranks other than my rank and my peer may also
3071     // participate in this batch
3072     key = getKeyFromDevice(device);
3073     p2pRank = rank_;
3074     p2pTargetRank = peer;
3075   } else {
3076     // For single P2P, preserve the old two-rank behavior (to avoid perf diff)
3077     key = getKeySendRecv(rank_, peer);
3078     p2pRank = rank_ <= peer ? 0 : 1;
3079     isSendRecvSelf = rank_ == peer;
3080     p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
3081 
3082     if (!coalescing_state_) {
3083       // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be
3084       // bumped in `startCoalescing`.
3085       seqP2P_++;
3086     }
3087   }
3088 
3089   // Bump the logical operation counter regardless of whether this op is
3090   // coalesced or individual
3091   op_id_++;
3092 
3093   auto ncclComm = getNCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
3094 
3095   if (coalescing_state_ & CoalActive) {
3096     coalescing_state_ |= CoalP2P;
3097     if (coalescedDevice_.index() < 0) {
3098       coalescedDevice_ = device;
3099     } else {
3100       TORCH_CHECK(
3101           coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
3102     }
3103     if (coalescedComm_ == nullptr) {
3104       coalescedComm_ = ncclComm;
3105     } else {
3106       TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
3107     }
3108   }
3109 
3110   // Used many times below, so we stash the unordered_map lookup
3111   auto ncclStream = ncclStreams_.at(key);
3112   // First let NCCL streams wait for input tensors allocation streams
3113   syncStream(device, ncclEvents_[key], ncclStream);
3114 
3115   // Work itself will create the CUDA events on all GPUs of tensors
3116   c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work;
3117   if (coalescing_state_) {
3118     // When coalescing, we record events per op that lack timing/state
3119     // information because there is no 'work' associated with them, and then
3120     // later in endCoalescing we record a 'coalesced' Work which has
3121     // timing/state updates via watchdog thread, but lacks op metadata such as
3122     // input/output sizes and profilingTitle per-op in the group.
3123     auto trace_id = NCCLTraceBuffer::get()->record(
3124         local_id_,
3125         std::make_tuple(pg_uid_, pg_desc_),
3126         seqCollective_,
3127         seqP2P_,
3128         op_id_,
3129         profilingTitle,
3130         {tensor},
3131         {tensor},
3132         nullptr,
3133         nullptr,
3134         options_->timeout,
3135         pgStatus_,
3136         /*isP2P=*/true);
3137     // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get
3138     // their timings/states updated by proxy when the Work obj representing the
3139     // coalesce group gets its update, we could accumulate these trace_ids
3140     // together and ask FlightRecorder to take the update from one Work and
3141     // apply it to multiple entries
3142     (void)trace_id;
3143   } else {
3144     // Store references to outputs to be used by WorkNCCL::result and
3145     // operator<<. Note that these outputs are only valid for recv(), as send()
3146     // does not modify the inputs but we still create these outputs for use
3147     // cases such as profiling.
3148 
3149     work = initWork(
3150         device, rank_, opType, profilingTitle, {tensor}, {}, /*record=*/false);
3151     // This bypasses something in Work() that crashes if {tensor} is given as
3152     // output, not sure what
3153     work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
3154     work->outputs_->push_back(tensor);
3155     // TODO(whc) because we don't pass output {tensor} to initWork, we tell
3156     // initWork to not record, and then we manually call record passing all the
3157     // information it wants.
3158     work->trace_id_ = NCCLTraceBuffer::get()->record(
3159         local_id_,
3160         std::make_tuple(pg_uid_, pg_desc_),
3161         seqCollective_,
3162         seqP2P_,
3163         op_id_,
3164         profilingTitle,
3165         {tensor},
3166         {tensor},
3167         work->ncclStartEvent_.get(),
3168         work->ncclEndEvent_.get(),
3169         options_->timeout,
3170         pgStatus_,
3171         /*isP2P=*/true);
3172   }
3173 
3174   // Is gpuGuard needed for the if block below, or can we swap them?
3175   at::cuda::OptionalCUDAGuard gpuGuard(device);
3176 
3177   // Only check for NaN for send ops, for recv ops `tensor` can be a random
3178   // placeholder
3179   if (enableNanCheck_ && opType == OpType::SEND) {
3180     checkForNan(tensor, ncclStream);
3181   }
3182 
3183   if (!coalescing_state_) {
3184     // Start event should only be recorded before the ncclGroupStart()
3185     if (work->timingEnabled_) {
3186       work->ncclStartEvent_->record(ncclStream);
3187     }
3188 
3189     pre(ncclStream, work);
3190   }
3191 
3192   // Both send tensor and recv tensor are created on a worker stream and used
3193   // in different ncclStreams.  Hence, both must record the ncclStream to
3194   // prevent being freed before the collective finishes.
3195   //
3196   // See [Sync Streams].
3197   c10::cuda::CUDACachingAllocator::recordStream(
3198       tensor.storage().data_ptr(), ncclStream);
3199 
3200   // This part seems common to both p2p and coalesced-p2p usage?
3201   ncclComm_t comm_ = ncclComm->getNcclComm();
3202 
3203 #ifndef NCCL_HAS_COMM_NONBLOCKING
3204   C10D_NCCL_CHECK(
3205       fn(tensor, comm_, ncclStream, p2pTargetRank),
3206       ncclComm->getNcclCommFailureReason());
3207 #else
3208   C10D_NCCL_CHECK_TIMEOUT(
3209       fn(tensor, comm_, ncclStream, p2pTargetRank),
3210       ncclComm->getNcclComm(),
3211       ncclComm->getNcclCommFailureReason());
3212 #endif
3213 
3214   if (!coalescing_state_) {
3215     post(ncclStream);
3216 
3217     // End event should only be recorded after the ncclGroupEnd()
3218     work->ncclEndEvent_->record(ncclStream);
3219     work->ncclComm_ = ncclComm;
3220     work->blockingWait_ = blockingWait_;
3221     work->store_ = store_;
3222     assignTimeoutToWork(work, options_);
3223     // Record size info for debug. We only record the size on the first device
3224     // as multi-device per process is deprecated
3225     work->numelIn_ = work->numelOut_ = tensor.numel();
3226 
3227     // Future only needs to be created and marked completed with outputs for
3228     // recv(), but still create future for use cases such as profiling even for
3229     // send().
3230     {
3231       c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
3232       std::vector<at::Device> devices{device};
3233       work->future_ = c10::make_intrusive<at::ivalue::Future>(
3234           c10::ListType::create(c10::TensorType::get()), devices);
3235       work->future_->markCompleted(at::IValue(*work->outputs_));
3236     }
3237 
3238     // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
3239     // future blocks the stream this callback runs on behind the corresponding
3240     // ncclEndEvents_, ensuring appropriate synchronization.
3241     if (work->recordFunctionEndCallback_) {
3242       work->future_->addCallback(
3243           [work](at::ivalue::Future& /* unused */) {
3244             work->recordFunctionEndCallback_();
3245           },
3246           // uses_future = false allows us to skip synchronization in
3247           // ivalue::Future, but is only valid as long as the lambda doesn't use
3248           // the "Future" argument.
3249           /*uses_future=*/false);
3250     }
3251   }
3252 
3253   // Enqueue P2P op so that it can be cancelled by NCCL watchdog
3254   c10::cuda::CaptureStatus capture_status =
3255       c10::cuda::currentStreamCaptureStatusMayInitCtx();
3256 
3257   // Notify graphs before we check the capture status preemptively
3258   at::cuda::CUDAGraph::inc_pending_event_queries();
3259 
3260   if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
3261     workEnqueue(work);
3262     return work;
3263   } else {
3264     at::cuda::CUDAGraph::dec_pending_event_queries();
3265     return nullptr;
3266   }
3267 }
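// The p2p `fn` callable is invoked as fn(tensor, comm, stream, p2pTargetRank)
// and returns an ncclResult_t. A sketch of a send-style lambda (hypothetical;
// the actual send()/recv() wrappers live elsewhere in this file):
//
//   auto sendFn = [&](at::Tensor& input,
//                     ncclComm_t comm,
//                     at::cuda::CUDAStream& stream,
//                     int dst) {
//     return ncclSend(
//         input.data_ptr(),
//         input.numel(),
//         getNcclDataType(input.scalar_type()),
//         dst,
//         comm,
//         stream.stream());
//   };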
3268 
3269 template <typename Fn, typename PreProcess, typename PostProcess>
3270 c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
3271     at::Tensor& input,
3272     at::Tensor& output,
3273     Fn fn,
3274     PreProcess pre,
3275     PostProcess post,
3276     OpType opType,
3277     const char* profilingTitle,
3278     bool avoidRecordStreams,
3279     bool nanCheck) {
3280   auto inputs = std::vector<at::Tensor>{input};
3281   auto outputs = std::vector<at::Tensor>{output};
3282   return collective(
3283       inputs,
3284       outputs,
3285       fn,
3286       pre,
3287       post,
3288       opType,
3289       profilingTitle,
3290       avoidRecordStreams,
3291       nanCheck);
3292 }
3293 
3294 template <typename Fn>
3295 c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
3296     at::Tensor& input,
3297     at::Tensor& output,
3298     Fn fn,
3299     OpType opType,
3300     const char* profilingTitle,
3301     bool avoidRecordStreams,
3302     bool nanCheck) {
3303   auto inputs = std::vector<at::Tensor>{input};
3304   auto outputs = std::vector<at::Tensor>{output};
3305   return collective(
3306       inputs,
3307       outputs,
3308       fn,
3309       [](at::cuda::CUDAStream&,
3310          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3311       [](at::cuda::CUDAStream&,
3312          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3313       opType,
3314       profilingTitle,
3315       avoidRecordStreams,
3316       nanCheck);
3317 }
3318 
3319 template <typename Fn>
3320 c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
3321     at::Tensor& tensor,
3322     Fn fn,
3323     int peer,
3324     OpType opType,
3325     const char* profilingTitle) {
3326   return pointToPoint(
3327       tensor,
3328       fn,
3329       peer,
3330       opType,
3331       [](at::cuda::CUDAStream&,
3332          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3333       [](at::cuda::CUDAStream&) {},
3334       profilingTitle);
3335 }
3336 
3337 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
3338     std::vector<at::Tensor>& tensors,
3339     const AllreduceOptions& opts) {
3340   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3341   auto tensor = tensors.back();
3342   TORCH_CHECK(
3343       !isFloat8Type(tensor.scalar_type()),
3344       "Float8 dtypes are not currenlty supported for NCCL reductions");
3345 #ifdef IS_NCCLX
3346   tensor = tensor.coalesce();
3347   at::Tensor outputTensor =
3348       torch::zeros(tensor.sizes(), tensor.options().layout(torch::kStrided));
3349   auto work = collective(
3350       tensor,
3351       outputTensor,
3352       [&](at::Tensor& input,
3353           at::Tensor& output,
3354           ncclComm_t comm,
3355           at::cuda::CUDAStream& stream) {
3356         auto ncclDataType = getNcclDataType(input.scalar_type());
3357         auto ncclReduceOp =
3358             getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3359 
3360         size_t num_elements = output.numel();
3361         auto indices = input.indices();
3362         auto sizes = input.sizes();
3363         int colSize = sizes[1];
3364         auto rows = indices[0];
3365         size_t blockCount = rows.sizes()[0];
3366         auto recvIndices = indices[0] * colSize;
3367 
3368         // prevent output and recvIndices from being freed
3369         c10::cuda::CUDACachingAllocator::recordStream(
3370             output.storage().data_ptr(), stream);
3371         c10::cuda::CUDACachingAllocator::recordStream(
3372             recvIndices.storage().data_ptr(), stream);
3373         auto result = ncclAllReduceSparseBlock(
3374             input._values().data_ptr(), // sendbuff
3375             recvIndices.data_ptr<int64_t>(), // recv_indices
3376             blockCount, // block_count
3377             colSize, // block_length
3378             output.data_ptr(), // recvbuff
3379             output.numel(), // recv_count
3380             ncclDataType,
3381             ncclReduceOp,
3382             comm,
3383             stream.stream());
3384         return result;
3385       },
3386       [](at::cuda::CUDAStream& ncclStream,
3387          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3388       [&](at::cuda::CUDAStream& ncclStream,
3389           c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3390         // Convert output tensors to sparse and back into tensors.
3391         at::cuda::CUDAStreamGuard guard(ncclStream);
3392         if (opts.sparseIndices.has_value()) {
3393           tensor = at::sparse_coo_tensor(
3394               opts.sparseIndices.value(), outputTensor, tensor.sizes());
3395         } else {
3396           tensor = outputTensor.to_sparse();
3397         }
3398       },
3399       OpType::_ALLREDUCE_SPARSE,
3400       "nccl:all_reduce_sparse");
3401   return work;
3402 #else
3403   // If the nccl branch is not "exp" then we just error
3404   C10_THROW_ERROR(
3405       Error,
3406       "NCCL does not support all_reduce with sparse tensors. Please use dense tensors instead.");
3407 #endif
3408 }
3409 
3410 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
3411     at::Tensor& tensor,
3412     const AllreduceOptions& opts) {
3413   return collective(
3414       tensor,
3415       tensor,
3416       [&](at::Tensor& input,
3417           at::Tensor& output,
3418           ncclComm_t comm,
3419           at::cuda::CUDAStream& stream) {
3420         auto ncclDataType = getNcclDataType(input.scalar_type());
3421         auto ncclReduceOp =
3422             getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3423         return ncclAllReduce(
3424             input.data_ptr(),
3425             output.data_ptr(),
3426             input.numel(),
3427             ncclDataType,
3428             ncclReduceOp,
3429             comm,
3430             stream.stream());
3431       },
3432       OpType::ALLREDUCE,
3433       "nccl:all_reduce");
3434 }
3435 
3436 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
3437     std::vector<at::Tensor>& tensors,
3438     const AllreduceOptions& opts) {
3439   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3440   auto tensor = tensors.back();
3441   if (tensor.is_complex()) {
3442     TORCH_CHECK(
3443         complexViewAsRealAllowed(opts.reduceOp),
3444         "all_reduce does not support",
3445         opts.reduceOp,
3446         "on complex tensors");
3447     tensor = at::view_as_real(tensor);
3448   }
3449   check_gpu_single_tensor(tensor);
3450 
3451   if (intraNodeComm_ != nullptr && opts.reduceOp == ReduceOp::SUM) {
3452     using namespace intra_node_comm;
3453     auto algo = intraNodeComm_->selectAllReduceAlgo(tensor);
3454     if (algo != intra_node_comm::AllReduceAlgo::NONE) {
3455       intraNodeComm_->allReduce(tensor, algo);
3456       return c10::make_intrusive<IntraNodeCommWork>();
3457     }
3458   }
3459   TORCH_CHECK(
3460       !isFloat8Type(tensor.scalar_type()),
3461       "Float8 dtypes are not currenlty supported for NCCL reductions");
3462   // @lint-ignore CLANGTIDY
3463   RECORD_PARAM_COMMS_DATA(
3464       static_cast<int>(
3465           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3466       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3467       tensors, // inputTensors
3468       tensors, // outputTensors
3469       rank_, // rank
3470       "allreduce", // collective name
3471       tensor.numel(), // inNelems
3472       tensor.numel(), // outNelems
3473       tensor.scalar_type(), // dType
3474       std::vector<int64_t>(), // inSplitSizes
3475       std::vector<int64_t>(), // outSplitSizes
3476       globalRankStart, // globalRankStart
3477       globalRankStride, // globalRankStride
3478       this->getSize()); // worldSize
3479 
3480   // avoidRecordStreams_ note: collective() will stash tensors.
3481   return allreduce_impl(tensor, opts);
3482 }
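// Hedged usage note: complex inputs are reduced as their real view (see the
// view_as_real call above), so only reductions accepted by
// complexViewAsRealAllowed (e.g. SUM/AVG/PREMUL_SUM) are valid; MIN/MAX would
// be rejected. Sketch with a hypothetical process group `pg`:
//
//   std::vector<at::Tensor> ts{at::randn(
//       {8}, at::dtype(at::kComplexFloat).device(at::kCUDA))};
//   AllreduceOptions opts;
//   opts.reduceOp = ReduceOp::SUM;
//   pg.allreduce(ts, opts)->wait();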
3483 
3484 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
3485     std::vector<at::Tensor>& tensors,
3486     const AllreduceCoalescedOptions& opts) {
3487   auto total_numel = check_gpu_tensors_same_device(tensors);
3488   TORCH_CHECK(
3489       !isFloat8Type(tensors.back().scalar_type()),
3490       "Float8 dtypes are not currenlty supported for NCCL reductions");
3491 
3492   // @lint-ignore CLANGTIDY
3493   RECORD_PARAM_COMMS_DATA(
3494       static_cast<int>(
3495           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3496       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3497       tensors, // inputTensors
3498       tensors, // outputTensors
3499       rank_, // rank
3500       "allreduce_coalesced", // collective name
3501       total_numel, // inNelems
3502       total_numel, // outNelems
3503       tensors[0].scalar_type(), // dType
3504       // I'm not sure what in,outSplitSizes mean here.
3505       std::vector<int64_t>(), // inSplitSizes
3506       std::vector<int64_t>(), // outSplitSizes
3507       globalRankStart, // globalRankStart
3508       globalRankStride, // globalRankStride
3509       this->getSize()); // worldSize
3510 
3511   // avoidRecordStreams_ note: collective() will stash tensors.
3512   return collectiveCoalesced(
3513       tensors,
3514       tensors,
3515       [&](at::Tensor& input,
3516           at::Tensor& output,
3517           ncclComm_t comm,
3518           at::cuda::CUDAStream& stream) {
3519         auto ncclDataType = getNcclDataType(input.scalar_type());
3520         auto ncclReduceOp =
3521             getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3522         return ncclAllReduce(
3523             input.data_ptr(),
3524             output.data_ptr(),
3525             input.numel(),
3526             ncclDataType,
3527             ncclReduceOp,
3528             comm,
3529             stream.stream());
3530       },
3531       OpType::COALESCED,
3532       "nccl:allreduce_coalesced");
3533 }
3534 
3535 c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
3536     std::vector<at::Tensor>& tensors,
3537     const BroadcastOptions& opts) {
3538   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3539   auto tensor = tensors.back();
3540   if (tensor.is_complex()) {
3541     tensor = at::view_as_real(tensor);
3542   }
3543   check_gpu_single_tensor(tensor);
3544 
3545   // @lint-ignore CLANGTIDY
3546   RECORD_PARAM_COMMS_DATA(
3547       static_cast<int>(
3548           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3549       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3550       tensors, // inputTensors
3551       tensors, // outputTensors
3552       opts.rootRank, // root rank
3553       "broadcast", // collective name
3554       tensor.numel(), // inNelems
3555       tensor.numel(), // outNelems
3556       tensor.scalar_type(), // dType
3557       std::vector<int64_t>(), // inSplitSizes
3558       std::vector<int64_t>(), // outSplitSizes
3559       globalRankStart, // globalRankStart
3560       globalRankStride, // globalRankStride
3561       this->getSize()); // worldSize
3562 
3563   // avoidRecordStreams_ note: collective() will stash tensors.
3564   bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
3565 
3566   const auto root = opts.rootRank + opts.rootTensor;
3567   bool nanCheck = (root == rank_);
3568 
3569   return collective(
3570       tensor,
3571       tensor,
3572       [&](at::Tensor& input,
3573           at::Tensor& output,
3574           ncclComm_t comm,
3575           at::cuda::CUDAStream& stream) {
3576         return ncclBcast(
3577             input.data_ptr(),
3578             input.numel(),
3579             getNcclDataType(input.scalar_type()),
3580             root,
3581             comm,
3582             stream.stream());
3583       },
3584       OpType::BROADCAST,
3585       "nccl:broadcast",
3586       avoidRecordStreams,
3587       nanCheck);
3588 }
3589 
3590 // _broadcast_oop adds an out-of-place broadcast in PGNCCL
3591 // Custom collectives may be implemented by coalescing broadcast operations
3592 // One use-case is implementing a vector all_gather (all_gather_v)
3593 // where unevenly sized inputs are gathered among participating ranks
3594 // Since all_gather provides an out-of-place API, an all_gather_v
3595 // semantic implemented inside pg_nccl.all_gather also needs to support
3596 // out-of-place, which requires adding an out-of-place broadcast
3597 c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
3598     at::Tensor& outputTensor,
3599     at::Tensor& inputTensor,
3600     const BroadcastOptions& opts) {
3601   if (outputTensor.numel() != inputTensor.numel()) {
3602     C10_THROW_ERROR(
3603         ValueError,
3604         "Tensor input and output of _broadcast_oop must have the same number of elements ");
3605   }
3606   const auto root = opts.rootRank + opts.rootTensor;
3607   bool nanCheck = (root == rank_);
3608   return collective(
3609       inputTensor,
3610       outputTensor,
3611       [&](at::Tensor& input,
3612           at::Tensor& output,
3613           ncclComm_t comm,
3614           at::cuda::CUDAStream& stream) {
3615         return ncclBroadcast(
3616             input.data_ptr(),
3617             output.data_ptr(),
3618             input.numel(),
3619             getNcclDataType(input.scalar_type()),
3620             root,
3621             comm,
3622             stream.stream());
3623       },
3624       OpType::BROADCAST,
3625       "nccl:_broadcast_oop",
3626       /*avoidRecordStreams=*/false,
3627       nanCheck);
3628 }
3629 
3630 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
3631     std::vector<at::Tensor>& tensors,
3632     const ReduceOptions& opts) {
3633   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3634   // @lint-ignore CLANGTIDY
3635   auto tensor = tensors.back();
3636   if (tensor.is_complex()) {
3637     TORCH_CHECK(
3638         complexViewAsRealAllowed(opts.reduceOp),
3639         "reduce does not support",
3640         opts.reduceOp,
3641         "on complex tensors");
3642     tensor = at::view_as_real(tensor);
3643   }
3644   check_gpu_single_tensor(tensor);
3645   RECORD_PARAM_COMMS_DATA(
3646       static_cast<int>(
3647           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3648       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3649       tensors, // inputTensors
3650       tensors, // outputTensors
3651       opts.rootRank, // root rank
3652       "reduce", // collective name
3653       tensor.numel(), // inNelems
3654       tensor.numel(), // outNelems
3655       tensor.scalar_type(), // dType
3656       std::vector<int64_t>(), // inSplitSizes
3657       std::vector<int64_t>(), // outSplitSizes
3658       globalRankStart, // globalRankStart
3659       globalRankStride, // globalRankStride
3660       this->getSize()); // worldSize
3661 
3662   // avoidRecordStreams_ note: collective() will stash tensors.
3663   return collective(
3664       tensor,
3665       tensor,
3666       [&](at::Tensor& input,
3667           at::Tensor& output,
3668           ncclComm_t comm,
3669           at::cuda::CUDAStream& stream) {
3670         const auto root = opts.rootRank + opts.rootTensor;
3671         auto ncclDataType = getNcclDataType(input.scalar_type());
3672         auto ncclReduceOp =
3673             getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3674         return ncclReduce(
3675             input.data_ptr(),
3676             output.data_ptr(),
3677             input.numel(),
3678             ncclDataType,
3679             ncclReduceOp,
3680             root,
3681             comm,
3682             stream.stream());
3683       },
3684       OpType::REDUCE,
3685       "nccl:reduce");
3686 }
3687 
3688 // _reduce_oop exposes an out-of-place reduce from PGNCCL
3689 // Custom collectives may be implemented by coalescing reduce operations
3690 // One use-case is implementing a vector reduce_scatter (reduce_scatter_v)
3691 // where inputs are reduced and scattered unevenly among participating ranks
3692 // Since reduce_scatter provides an out-of-place API, a reduce_scatter_v
3693 // semantic implemented inside pg_nccl.reduce_scatter also needs to support
3694 // out-of-place, for which an out-of-place reduce is required to be added
3695 c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
3696     at::Tensor& outputTensor,
3697     at::Tensor& inputTensor,
3698     const ReduceOptions& opts) {
3699   if (outputTensor.numel() != inputTensor.numel()) {
3700     C10_THROW_ERROR(
3701         ValueError,
3702         "Tensor input and output of _reduce_oop must have the same number of elements ");
3703   }
3704   return collective(
3705       inputTensor,
3706       outputTensor,
3707       [&](at::Tensor& input,
3708           at::Tensor& output,
3709           ncclComm_t comm,
3710           at::cuda::CUDAStream& stream) {
3711         const auto root = opts.rootRank + opts.rootTensor;
3712         const auto ncclDataType = getNcclDataType(input.scalar_type());
3713         const auto ncclReduceOp =
3714             getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3715         return ncclReduce(
3716             input.data_ptr(),
3717             output.data_ptr(),
3718             input.numel(),
3719             ncclDataType,
3720             ncclReduceOp,
3721             (int)root,
3722             comm,
3723             stream.stream());
3724       },
3725       OpType::REDUCE,
3726       "nccl:_reduce_oop");
3727 }
3728 
3729 c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
3730     std::vector<std::vector<at::Tensor>>& outputTensors,
3731     std::vector<at::Tensor>& inputTensors,
3732     const AllgatherOptions& opts) {
3733   TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3734   // @lint-ignore CLANGTIDY
3735   auto inputTensor = inputTensors.back();
3736   check_gpu_single_tensor(inputTensor);
3737   // @lint-ignore CLANGTIDY
3738   auto outputTensors_ = outputTensors.back();
3739 
3740   RECORD_PARAM_COMMS_DATA(
3741       static_cast<int>(
3742           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3743       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3744       inputTensors, // inputTensors
3745       outputTensors, // outputTensors
3746       rank_, // rank
3747       "all_gather", // collective name
3748       inputTensor.numel(), // inNelems
3749       inputTensor.numel() * // outNelems
3750           this->getSize(),
3751       inputTensor.scalar_type(), // dType
3752       std::vector<int64_t>(), // inSplitSizes
3753       std::vector<int64_t>(), // outSplitSizes
3754       globalRankStart, // globalRankStart
3755       globalRankStride, // globalRankStride
3756       this->getSize()); // worldSize
3757 
3758   bool same_size = check_same_size(outputTensors_);
3759   if (same_size) {
3760     // Flatten a vector of tensors into a single, stacked tensor.
3761     at::Tensor outputFlattened = newLikeFlat(outputTensors_);
3762 
3763     return collective(
3764         inputTensor,
3765         outputFlattened,
3766         [&](at::Tensor& input,
3767             at::Tensor& output,
3768             ncclComm_t comm,
3769             at::cuda::CUDAStream& stream) {
3770           if (!avoidRecordStreams_) {
3771             c10::cuda::CUDACachingAllocator::recordStream(
3772                 output.storage().data_ptr(), stream);
3773           }
3774           return ncclAllGather(
3775               input.data_ptr(),
3776               output.data_ptr(),
3777               input.numel(),
3778               getNcclDataType(input.scalar_type()),
3779               comm,
3780               stream.stream());
3781         },
3782         [](at::cuda::CUDAStream& ncclStream,
3783            c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3784           // avoidRecordStreams_ note: We actually don't need to stash anything
3785           // here.
3786           //  - inputTensors is stashed onto work->stashed_for_allocator_safety_
3787           //    in collective().
3788           //  - outputFlattened is stashed onto work->outputs_ in collective().
3789           //  - User-facing outputTensors should be held by the user until after
3790           //    waiting on work_, or the call makes no sense.
3791           // So all participating tensors are accounted for, and won't be
3792           // released back to their allocation streams until after work_ is
3793           // waited on.
3794         },
3795         [&](at::cuda::CUDAStream& ncclStream,
3796             c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3797           // Copy the flattened output tensors to the outputs.
3798           at::cuda::CUDAStreamGuard guard(ncclStream);
3799           for (const auto j : c10::irange(outputTensors_.size())) {
3800             // See [Sync Streams].
3801             if (!avoidRecordStreams_) {
3802               c10::cuda::CUDACachingAllocator::recordStream(
3803                   outputTensors_[j].storage().data_ptr(), ncclStream);
3804             }
3805             outputTensors_[j].copy_(outputFlattened[j], true);
3806           }
3807         },
3808         OpType::ALLGATHER,
3809         "nccl:all_gather");
3810   } else {
3811     const auto num_reduces = outputTensors_.size();
3812     startCoalescing();
3813     for (const int i : c10::irange(num_reduces)) {
3814       auto& output = outputTensors_[i];
3815       auto& input = (i == rank_) ? inputTensor : output;
3816       auto broadcastOpts = BroadcastOptions{
3817           static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
3818       _broadcast_oop(output, input, broadcastOpts);
3819     }
3820     auto work = endCoalescing(OpType::ALLGATHER);
3821     return work;
3822   }
3823 }
3824 
3825 c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_coalesced(
3826     std::vector<std::vector<at::Tensor>>& /* unused */,
3827     std::vector<at::Tensor>& /* unused */,
3828     const AllgatherOptions& /* unused */) {
3829   C10_THROW_ERROR(
3830       NotImplementedError,
3831       "ProcessGroupNCCL does not support allgather_coalesced");
3832 }
3833 
3834 c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
3835     std::vector<at::Tensor>& outputs,
3836     std::vector<at::Tensor>& inputs,
3837     const AllgatherOptions& opts) {
3838   return collectiveCoalesced(
3839       inputs,
3840       outputs,
3841       [&](at::Tensor& input,
3842           at::Tensor& output,
3843           ncclComm_t comm,
3844           at::cuda::CUDAStream& stream) {
3845         return ncclAllGather(
3846             input.data_ptr(),
3847             output.data_ptr(),
3848             input.numel(),
3849             getNcclDataType(input.scalar_type()),
3850             comm,
3851             stream.stream());
3852       },
3853       OpType::COALESCED,
3854       "nccl:all_gather_into_tensor_coalesced");
3855 }
3856 
3857 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
3858     std::vector<at::Tensor>& outputTensors,
3859     std::vector<std::vector<at::Tensor>>& inputTensors,
3860     const ReduceScatterOptions& opts) {
3861   TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3862   // @lint-ignore CLANGTIDY
3863   auto outputTensor = outputTensors.back();
3864   check_gpu_single_tensor(outputTensor);
3865   // @lint-ignore CLANGTIDY
3866   auto inputTensors_ = inputTensors.back();
3867   TORCH_CHECK(
3868       !isFloat8Type(outputTensor.scalar_type()),
3869       "Float8 dtypes are not currenlty supported for NCCL reductions");
3870 
3871   RECORD_PARAM_COMMS_DATA(
3872       static_cast<int>(
3873           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3874       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3875       inputTensors, // inputTensors
3876       outputTensors, // outputTensors
3877       rank_, // rank
3878       "reduce_scatter", // collective name
3879       outputTensor.numel() * this->getSize(), // inNelems
3880       outputTensor.numel(), // outNelems
3881       outputTensor.scalar_type(), // dType
3882       std::vector<int64_t>(), // inSplitSizes
3883       std::vector<int64_t>(), // outSplitSizes
3884       globalRankStart, // globalRankStart
3885       globalRankStride, // globalRankStride
3886       this->getSize()); // worldSize
3887 
3888   bool same_size = check_same_size(inputTensors_);
3889   if (same_size) {
3890     // Flatten a vector of tensors into a single, stacked tensor.
3891     at::Tensor inputFlattened = newLikeFlat(inputTensors_);
3892 
3893     return collective(
3894         inputFlattened,
3895         outputTensor,
3896         [&](at::Tensor& input,
3897             at::Tensor& output,
3898             ncclComm_t comm,
3899             at::cuda::CUDAStream& stream) {
3900           if (!avoidRecordStreams_) {
3901             c10::cuda::CUDACachingAllocator::recordStream(
3902                 output.storage().data_ptr(), stream);
3903           }
3904           const auto ncclDataType = getNcclDataType(input.scalar_type());
3905           const auto ncclReduceOp =
3906               getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3907           return ncclReduceScatter(
3908               input.data_ptr(),
3909               output.data_ptr(),
3910               output.numel(),
3911               ncclDataType,
3912               ncclReduceOp,
3913               comm,
3914               stream.stream());
3915         },
3916         [&](at::cuda::CUDAStream& ncclStream,
3917             c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3918           if (avoidRecordStreams_) {
3919             // We only need to stash inputTensors.
3920             //  - inputFlattened is stashed onto
3921             //  work->stashed_for_allocator_safety_
3922             //    in collective().
3923             //  - User-facing outputTensors is stashed onto work->outputs_ in
3924             //  collective(),
3925             //    and should also be held by the user until after waiting on
3926             //    work_.
3927             auto& v = work->stashed_for_allocator_safety_;
3928             v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
3929           }
3930 
3931           // Copy the input tensors to the flattened inputs.
3932           at::cuda::CUDAStreamGuard guard(ncclStream);
3933           for (const auto j : c10::irange(inputTensors_.size())) {
3934             // See [Sync Streams].
3935             if (!avoidRecordStreams_) {
3936               c10::cuda::CUDACachingAllocator::recordStream(
3937                   inputTensors_[j].storage().data_ptr(), ncclStream);
3938             }
3939             inputFlattened[j].copy_(inputTensors_[j], true);
3940           }
3941         },
3942         [&](at::cuda::CUDAStream&,
3943             c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3944         OpType::REDUCE_SCATTER,
3945         "nccl:reduce_scatter");
3946   } else {
3947     const auto num_reduces = inputTensors_.size();
3948     startCoalescing();
3949     for (const int i : c10::irange(num_reduces)) {
3950       auto& input = inputTensors_[i];
3951       auto& output = (i == rank_) ? outputTensor : input;
3952       auto reduceOpts = ReduceOptions{
3953           opts.reduceOp,
3954           static_cast<int64_t>(i),
3955           static_cast<int64_t>(0),
3956           opts.timeout};
3957       _reduce_oop(output, input, reduceOpts);
3958     }
3959     auto work = endCoalescing(OpType::REDUCE_SCATTER);
3960     return work;
3961   }
3962 }
3963 
3964 c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
3965     at::Tensor& outputTensor,
3966     at::Tensor& inputTensor,
3967     const ReduceScatterOptions& opts) {
3968   if (inputTensor.dtype() != outputTensor.dtype()) {
3969     C10_THROW_ERROR(
3970         TypeError, "input tensor must be the same type as the output tensor.");
3971   }
3972 
3973   if (inputTensor.numel() != outputTensor.numel() * size_) {
3974     C10_THROW_ERROR(
3975         ValueError,
3976         "input tensor must be the same size as output size times world size");
3977   }
3978 
3979   // @lint-ignore CLANGTIDY
3980   const auto& tensor = outputTensor;
3981   TORCH_CHECK(
3982       !isFloat8Type(tensor.scalar_type()),
3983       "Float8 dtypes are not currenlty supported for NCCL reductions");
3984   RECORD_PARAM_COMMS_DATA(
3985       static_cast<int>(
3986           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3987       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3988       inputTensor, // inputTensor
3989       outputTensor, // outputTensor
3990       rank_, // rank
3991       "_reduce_scatter_base", // collective name
3992       inputTensor.numel(), // inNelems
3993       tensor.numel(), // outNelems
3994       tensor.scalar_type(), // dtype
3995       std::vector<int64_t>(), // inSplitSizes
3996       std::vector<int64_t>(), // outSplitSizes
3997       globalRankStart, // globalRankStart
3998       globalRankStride, // globalRankStride
3999       this->getSize()); // worldSize
4000 
4001   // avoidRecordStreams_ note: collective() will stash inputs and outputs.
4002   // Note 2: for asyncOp = false, we don't want to record streams because we
4003   // know that the NCCL stream will join back to the "current" stream right
4004   // after this op. So we might just as well keep the stream ownership of the
4005   // input/output tensors unchanged. The benefit would be that the
4006   // allocation/free of the tensors would look deterministic to the "current"
4007   // stream so that the caching allocator can reuse memory pool for this stream
4008   // in a clever way. This setting is added for libraries like FSDP which uses
4009   // `reduce_scatter_tensor`.
4010   bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
4011 
4012   return collective(
4013       inputTensor,
4014       outputTensor,
4015       [&](at::Tensor& input,
4016           at::Tensor& output,
4017           ncclComm_t comm,
4018           at::cuda::CUDAStream& stream) {
4019         if (!avoidRecordStreams) {
4020           c10::cuda::CUDACachingAllocator::recordStream(
4021               output.storage().data_ptr(), stream);
4022         }
4023         auto ncclDataType = getNcclDataType(input.scalar_type());
4024         auto ncclReduceOp =
4025             getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
4026         return ncclReduceScatter(
4027             input.data_ptr(),
4028             output.data_ptr(),
4029             output.numel(),
4030             ncclDataType,
4031             ncclReduceOp,
4032             comm,
4033             stream.stream());
4034       },
4035       OpType::_REDUCE_SCATTER_BASE,
4036       "nccl:_reduce_scatter_base",
4037       avoidRecordStreams);
4038 }
4039 
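// reduce_scatter_tensor_coalesced below issues one ncclReduceScatter per
// input/output pair through collectiveCoalesced(), so the whole batch is
// flushed as a single coalesced NCCL launch and tracked by a single Work.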
4040 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
4041     std::vector<at::Tensor>& outputs,
4042     std::vector<at::Tensor>& inputs,
4043     const ReduceScatterOptions& opts) {
4044   TORCH_CHECK(
4045       !isFloat8Type(inputs.back().scalar_type()),
4046       "Float8 dtypes are not currenlty supported for NCCL reductions");
4047   return collectiveCoalesced(
4048       inputs,
4049       outputs,
4050       [&](at::Tensor& input,
4051           at::Tensor& output,
4052           ncclComm_t comm,
4053           at::cuda::CUDAStream& stream) {
4054         if (!avoidRecordStreams_) {
4055           c10::cuda::CUDACachingAllocator::recordStream(
4056               output.storage().data_ptr(), stream);
4057         }
4058         auto ncclDataType = getNcclDataType(input.scalar_type());
4059         auto ncclReduceOp =
4060             getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
4061         return ncclReduceScatter(
4062             input.data_ptr(),
4063             output.data_ptr(),
4064             output.numel(),
4065             ncclDataType,
4066             ncclReduceOp,
4067             comm,
4068             stream.stream());
4069       },
4070       OpType::COALESCED,
4071       "nccl:reduce_scatter_tensor_coalesced");
4072 }
4073 
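// Usage sketch (assumption, not from the upstream file): barrier() is normally
// reached via the Python frontend, but a direct C++ call that pins the barrier
// to a specific device could look like this, with `pg` an existing
// ProcessGroupNCCL instance:
//
//   BarrierOptions barrierOpts;
//   barrierOpts.device_ids = {0};      // optional: force the barrier onto GPU 0
//   pg->barrier(barrierOpts)->wait();  // blocks until all ranks reach the barrier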
4074 c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
4075   RECORD_PARAM_COMMS(
4076       static_cast<int>(
4077           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4078       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4079       rank_, // rank
4080       "barrier", // collective name
4081       0, // inNelems
4082       0, // outNelems
4083       at::kByte, // dType
4084       std::vector<int64_t>(), // inSplitSizes
4085       std::vector<int64_t>(), // outSplitSizes
4086       globalRankStart, // globalRankStart
4087       globalRankStride, // globalRankStride
4088       this->getSize()); // worldSize
4089 
4090   // Device to use for barrier
4091   int barDevIdx = -1;
4092 
4093   // Select device to use for barrier
4094   // 1st choice: Use user defined GPU device ids if provided
4095   if (!opts.device_ids.empty()) {
4096     // Use the first device id because PG NCCL is single-device now
4097     barDevIdx = opts.device_ids[0];
4098   } else if (getBoundDeviceId()) {
4099     // 2nd choice: Use the bound GPU device id if available.
4100     // The bound device id can be passed to `init_process_group`.
4101     barDevIdx = (*getBoundDeviceId()).index();
4102   } else if (!usedDeviceIdxs_.empty()) {
4103     // 3rd choice: infer the device id from the used device ids.
4104     barDevIdx = *usedDeviceIdxs_.begin();
4105   } else {
4106     // This means no NCCL collective has been called on this group yet.
4107     // Here we have to make a best guess and use a single GPU to call
4108     // allreduce to achieve the barrier.
4109     // In case multiple processes fall on the same node, we use the rank to
4110     // ensure that each process uses a different GPU.
4111     // Note: it is better to use the global rank because the group-local rank
4112     // can be offset w.r.t. the device id if intra-node GPUs are sharded into
4113     // multiple dimensions.
4114     barDevIdx = static_cast<int16_t>(globalRank() % localDeviceCount_);
4115     LOG(WARNING)
4116         << logPrefix()
4117         << c10::str(
4118                " using GPU ",
4119                barDevIdx,
4120                " to perform barrier as devices used by this process are currently unknown. ",
4121                "This can potentially cause a hang if this rank to GPU mapping is incorrect.",
4122                "Specify device_ids in barrier() to force use of a particular device,",
4123                "or call init_process_group() with a device_id.");
4124   }
4125 
4126   TORCH_CHECK_WITH(
4127       ValueError,
4128       barDevIdx >= 0,
4129       "Failed to infer a GPU device id to perform barrier. ");
4130   auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx);
4131 
4132   // Create a dummy tensor on the device
4133   // Note: we use zeros() instead of empty() to prevent the barrier from
4134   // triggering an alarm when the NaN checker is enabled.
4135   at::Tensor barrierTensor =
4136       at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
4137 
4138   // All reduce to achieve the barrier
4139   auto work = allreduce_impl(barrierTensor);
4140 
4141   // The work object takes ownership of barrierTensor
4142   auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
4143   TORCH_CHECK(ncclWork);
4144   ncclWork->barrierTensor_ = std::move(barrierTensor);
4145   return work;
4146 }
4147 
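// Illustrative sketch (assumption): for the unequal-split path of
// alltoall_base below, inputSplitSizes/outputSplitSizes hold per-rank element
// counts and must sum to the matching tensor's numel(). With a 2-rank group,
// a ProcessGroupNCCL instance `pg`, and CUDA tensors `in` (8 elems) and
// `out` (7 elems):
//
//   std::vector<int64_t> inSplits{3, 5};   // send 3 elems to rank 0, 5 to rank 1
//   std::vector<int64_t> outSplits{3, 4};  // recv 3 elems from rank 0, 4 from rank 1
//   pg->alltoall_base(out, in, outSplits, inSplits, AllToAllOptions{});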
4148 c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
4149     at::Tensor& outputTensor,
4150     at::Tensor& inputTensor,
4151     std::vector<int64_t>& outputSplitSizes,
4152     std::vector<int64_t>& inputSplitSizes,
4153     const AllToAllOptions& /* unused */) {
4154   check_gpu_single_tensor(outputTensor, true);
4155   check_gpu_single_tensor(inputTensor, true);
4156   if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) {
4157     RECORD_PARAM_COMMS_DATA(
4158         static_cast<int>(
4159             this->getSequenceNumberForGroup() +
4160             1), // seq + 1 to match collective
4161         std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4162         inputTensor, // inputTensor
4163         outputTensor, // outputTensor
4164         rank_, // rank
4165         "all_to_all", // collective name
4166         inputTensor.numel(), // inNelems
4167         outputTensor.numel(), // outNelems
4168         inputTensor.scalar_type(), // dType
4169         std::vector<int64_t>(), // inSplitSizes
4170         std::vector<int64_t>(), // outSplitSizes
4171         globalRankStart, // globalRankStart
4172         globalRankStride, // globalRankStride
4173         this->getSize()); // worldSize
4174 
4175     // avoidRecordStreams_ note: collective() will stash inputTensors and
4176     // outputTensors.
4177     return collective(
4178         inputTensor,
4179         outputTensor,
4180         [&](at::Tensor& input,
4181             at::Tensor& output,
4182             ncclComm_t comm,
4183             at::cuda::CUDAStream& stream) {
4184           // See [Sync Streams].
4185           if (!avoidRecordStreams_) {
4186             c10::cuda::CUDACachingAllocator::recordStream(
4187                 output.storage().data_ptr(), stream);
4188           }
4189           torch::cuda::nccl::all2all_single_equal_split(
4190               input, output, this->getSize(), comm, stream);
4191           return ncclSuccess;
4192         },
4193         OpType::ALLTOALL_BASE,
4194         "nccl:all_to_all");
4195   } else {
4196     c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
4197     c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
4198 
4199     RECORD_PARAM_COMMS_DATA(
4200         static_cast<int>(
4201             this->getSequenceNumberForGroup() +
4202             1), // seq + 1 to match collective
4203         std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4204         inputTensor, // inputTensor
4205         outputTensor, // outputTensor
4206         rank_, // rank
4207         "all_to_allv", // collective name
4208         inputTensor.numel(), // inNelems
4209         outputTensor.numel(), // outNelems
4210         inputTensor.scalar_type(), // dType
4211         inputSplitSizes, // inSplitSizes
4212         outputSplitSizes, // outSplitSizes
4213         globalRankStart, // globalRankStart
4214         globalRankStride, // globalRankStride
4215         this->getSize()); // worldSize
4216 
4217     // avoidRecordStreams_ note: collective() will stash inputTensors and
4218     // outputTensors.
4219     return collective(
4220         inputTensor,
4221         outputTensor,
4222         [&](at::Tensor& input,
4223             at::Tensor& output,
4224             ncclComm_t comm,
4225             at::cuda::CUDAStream& stream) {
4226           std::vector<size_t> send_lengths(size_);
4227           std::vector<size_t> recv_lengths(size_);
4228           std::vector<size_t> send_offsets(size_);
4229           std::vector<size_t> recv_offsets(size_);
4230           c10d::computeLengthsAndOffsets(
4231               inputSplitSizes, input, &send_lengths, &send_offsets);
4232           c10d::computeLengthsAndOffsets(
4233               outputSplitSizes, output, &recv_lengths, &recv_offsets);
4234           // See [Sync Streams].
4235           if (!avoidRecordStreams_) {
4236             c10::cuda::CUDACachingAllocator::recordStream(
4237                 output.storage().data_ptr(), stream);
4238           }
4239           torch::cuda::nccl::all2all_single_unequal_split(
4240               input.data_ptr(),
4241               send_lengths.data(),
4242               send_offsets.data(),
4243               output.data_ptr(),
4244               recv_lengths.data(),
4245               recv_offsets.data(),
4246               input.element_size(),
4247               input.scalar_type(),
4248               comm,
4249               stream);
4250           return ncclSuccess;
4251         },
4252         OpType::ALLTOALL_BASE,
4253         "nccl:all_to_all");
4254   }
4255 }
4256 
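// alltoall (list variant) below expects exactly one input tensor and one
// output tensor per peer rank, all on the same device; the per-rank numels are
// what get recorded as the in/out split sizes in the profiling data.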
4257 c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
4258     std::vector<at::Tensor>& outputTensors,
4259     std::vector<at::Tensor>& inputTensors,
4260     const AllToAllOptions& /* unused */) {
4261   std::vector<int64_t> inSplitSizes;
4262   std::vector<int64_t> outSplitSizes;
4263   int64_t total_numel = 0;
4264 
4265   auto device = outputTensors[0].device();
4266   for (const auto r : c10::irange(outputTensors.size())) {
4267     check_gpu_single_tensor(outputTensors[r], true);
4268     check_gpu_single_tensor(inputTensors[r], true);
4269     TORCH_CHECK(
4270         device == outputTensors[r].device() &&
4271             device == inputTensors[r].device(),
4272         "Tensors must be on the same device")
4273     inSplitSizes.push_back(inputTensors[r].numel());
4274     outSplitSizes.push_back(outputTensors[r].numel());
4275     total_numel += inputTensors[r].numel();
4276   }
4277 
4278   RECORD_PARAM_COMMS_DATA(
4279       static_cast<int>(
4280           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4281       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4282       inputTensors, // inputTensors
4283       outputTensors, // outputTensors
4284       rank_, // rank
4285       "all_to_all", // collective name
4286       total_numel, // inNelems
4287       total_numel, // outNelems
4288       inputTensors.front().scalar_type(), // dType
4289       inSplitSizes, // inSplitSizes
4290       outSplitSizes, // outSplitSizes
4291       globalRankStart, // globalRankStart
4292       globalRankStride, // globalRankStride
4293       this->getSize()); // worldSize
4294 
4295   return collective(
4296       inputTensors,
4297       outputTensors,
4298       [&](at::Tensor& /* unused */,
4299           at::Tensor& /* unused */,
4300           ncclComm_t comm,
4301           at::cuda::CUDAStream& stream) {
4302         torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream);
4303         return ncclSuccess;
4304       },
4305       [&](at::cuda::CUDAStream&,
4306           c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
4307         if (avoidRecordStreams_) {
4308           // inputTensor0 and outputTensor0 are stashed redundantly by
4309           // collective(), but that's ok.
4310           auto& v = work->stashed_for_allocator_safety_;
4311           v->insert(v->end(), inputTensors.begin(), inputTensors.end());
4312           v->insert(v->end(), outputTensors.begin(), outputTensors.end());
4313         }
4314       },
4315       [](at::cuda::CUDAStream&,
4316          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4317       OpType::ALLTOALL,
4318       "nccl:all_to_all");
4319 }
4320 
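// Illustrative sketch (assumption): send()/recv() below are single-tensor
// point-to-point ops. A matched pair between two ranks, with `pg` a
// ProcessGroupNCCL instance, `peer` the remote rank, and `t` a CUDA tensor:
//
//   std::vector<at::Tensor> buf{t};  // exactly one tensor per call
//   if (pg->getRank() < peer) {
//     pg->send(buf, /*dstRank=*/peer, /*tag=*/0)->wait();
//   } else {
//     pg->recv(buf, /*srcRank=*/peer, /*tag=*/0)->wait();
//   }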
4321 c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
4322     std::vector<at::Tensor>& tensors,
4323     int dstRank,
4324     int /* unused */) {
4325   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
4326   // @lint-ignore CLANGTIDY
4327   auto tensor = tensors.back();
4328   check_gpu_single_tensor(tensor, true);
4329 
4330   RECORD_PARAM_COMMS_DATA(
4331       static_cast<int>(
4332           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4333       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4334       tensors, // inputTensors
4335       tensors, // outputTensors
4336       dstRank, // dst rank
4337       "send", // collective name
4338       tensor.numel(), // inNelems
4339       tensor.numel(), // outNelems
4340       tensor.scalar_type(), // dType
4341       std::vector<int64_t>(), // inSplitSizes
4342       std::vector<int64_t>(), // outSplitSizes
4343       globalRankStart, // globalRankStart
4344       globalRankStride, // globalRankStride
4345       this->getSize()); // worldSize
4346 
4347   auto ret = pointToPoint(
4348       tensor,
4349       [&](at::Tensor& input,
4350           ncclComm_t comm,
4351           at::cuda::CUDAStream& stream,
4352           int dst) {
4353         torch::cuda::nccl::send(input, comm, stream, dst);
4354         return ncclSuccess;
4355       },
4356       dstRank,
4357       OpType::SEND,
4358       c10::str("nccl:send ", rank_, "->", dstRank).c_str());
4359   return ret;
4360 }
4361 
4362 c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
4363     std::vector<at::Tensor>& tensors,
4364     int srcRank,
4365     int /* unused */) {
4366   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
4367   // @lint-ignore CLANGTIDY
4368   auto tensor = tensors.back();
4369   check_gpu_single_tensor(tensor, true);
4370 
4371   RECORD_PARAM_COMMS_DATA(
4372       static_cast<int>(
4373           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4374       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4375       tensors, // inputTensors
4376       tensors, // outputTensors
4377       srcRank, // src rank
4378       "recv", // collective name
4379       tensor.numel(), // inNelems
4380       tensor.numel(), // outNelems
4381       tensor.scalar_type(), // dType
4382       std::vector<int64_t>(), // inSplitSizes
4383       std::vector<int64_t>(), // outSplitSizes
4384       globalRankStart, // globalRankStart
4385       globalRankStride, // globalRankStride
4386       this->getSize()); // worldSize
4387 
4388   auto ret = pointToPoint(
4389       tensor,
4390       [&](at::Tensor& output,
4391           ncclComm_t comm,
4392           at::cuda::CUDAStream& stream,
4393           int src) {
4394         torch::cuda::nccl::recv(output, comm, stream, src);
4395         return ncclSuccess;
4396       },
4397       srcRank,
4398       OpType::RECV,
4399       c10::str("nccl:recv ", rank_, "<-", srcRank).c_str());
4400   return ret;
4401 }
4402 
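// groupStart()/groupEnd() below are thin wrappers over ncclGroupStart()/
// ncclGroupEnd() that also track nesting depth in ncclActiveGroupCounter_;
// the frontend can use them to batch several P2P ops into one NCCL group.
// groupEndNonblocking() uses the timeout-checking variant of group end when
// nonblocking communicators are enabled, so ending the group can time out
// instead of blocking indefinitely.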
4403 void ProcessGroupNCCL::groupStart() {
4404   C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt);
4405   ++ncclActiveGroupCounter_;
4406 }
4407 
4408 void ProcessGroupNCCL::groupEnd() {
4409   C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
4410   --ncclActiveGroupCounter_;
4411 }
4412 
4413 void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr<NCCLComm> comm) {
4414 #ifndef NCCL_HAS_COMM_NONBLOCKING
4415   C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
4416 #else
4417   if (!nccl_use_nonblocking()) {
4418     C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
4419   } else {
4420     C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt);
4421   }
4422 #endif
4423   --ncclActiveGroupCounter_;
4424 }
4425 
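// Layout note for gather() below (summary, not upstream commentary): only the
// root rank passes a non-empty nested list, and it must contain exactly one
// inner list with getSize() tensors matching the flat tensor's dtype and
// shape; non-root ranks must pass an empty outer list. scatter() further below
// applies the mirrored rule to its nested input list.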
4426 c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
4427     std::vector<std::vector<at::Tensor>>& outputTensors,
4428     std::vector<at::Tensor>& inputTensors,
4429     const GatherOptions& opts) {
4430   static auto invalidArgument = [](const std::string& msg) {
4431     C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::gather: " + msg);
4432   };
4433 
4434   assertRootRank(invalidArgument, opts.rootRank, size_);
4435 
4436   TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
4437   // @lint-ignore CLANGTIDY
4438   auto inputTensor = inputTensors.back();
4439 
4440   std::vector<at::Tensor> outputs;
4441 
4442   if (getRank() == opts.rootRank) {
4443     if (outputTensors.size() != 1) {
4444       std::stringstream ss;
4445       ss << "requires a single-element output list containing a list with "
4446          << getSize() << " tensors.";
4447       invalidArgument(ss.str());
4448     } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
4449       std::stringstream ss;
4450       ss << "Incorrect output list size " << outputTensors[0].size()
4451          << ". Output list size should be " << getSize()
4452          << ", same as size of the process group.";
4453       invalidArgument(ss.str());
4454     }
4455 
4456     const auto& options = inputTensor.options();
4457     const auto& sizes = inputTensor.sizes();
4458     assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes);
4459     outputs = outputTensors[0];
4460   } else {
4461     // if not the root rank, initialize outputs as an empty list
4462     if (outputTensors.size() != 0) {
4463       invalidArgument("requires empty output on non-root");
4464     }
4465     outputs = {};
4466     // append an empty tensor to the list; we don't use it, but the
4467     // `collective` template function requires it in order to invoke its functor
4468     outputs.emplace_back();
4469   }
4470 
4471   RECORD_PARAM_COMMS_DATA(
4472       static_cast<int>(
4473           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4474       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4475       inputTensors, // inputTensors
4476       outputTensors, // outputTensors
4477       opts.rootRank, // root rank
4478       "gather", // collective name
4479       inputTensor.numel(), // inNelems
4480       inputTensor.numel() * this->getSize(), // outNelems
4481       inputTensor.scalar_type(), // dType
4482       std::vector<int64_t>(), // inSplitSizes
4483       std::vector<int64_t>(), // outSplitSizes
4484       globalRankStart, // globalRankStart
4485       globalRankStride, // globalRankStride
4486       this->getSize()); // worldSize
4487 
4488   // avoidRecordStreams_ note: collective() will stash inputTensors and
4489   // outputs, which == outputTensors[0] on the root rank where it matters.
4490 
4491   auto inputs = std::vector<at::Tensor>{inputTensor};
4492   return collective(
4493       inputs,
4494       outputs, // just to fit the collective interface
4495       [&](at::Tensor& /* unused */,
4496           at::Tensor& /* unused */,
4497           ncclComm_t comm,
4498           at::cuda::CUDAStream& stream) {
4499         const auto root = opts.rootRank;
4500         if (getRank() == root) {
4501           if (!avoidRecordStreams_) {
4502             for (auto output : outputs) {
4503               c10::cuda::CUDACachingAllocator::recordStream(
4504                   output.storage().data_ptr(), stream);
4505             }
4506           }
4507         }
4508         torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root);
4509         return ncclSuccess;
4510       },
4511       [](at::cuda::CUDAStream&,
4512          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4513       [](at::cuda::CUDAStream&,
4514          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4515       OpType::GATHER,
4516       "nccl:gather");
4517 }
4518 
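// Illustrative sketch (assumption): a scatter() call where only the root
// prepares per-rank chunks; `pg` is a ProcessGroupNCCL instance and
// `chunkTemplate` a CUDA tensor shaped like one chunk:
//
//   std::vector<at::Tensor> out{at::empty_like(chunkTemplate)};
//   std::vector<std::vector<at::Tensor>> in;  // stays empty on non-root ranks
//   if (pg->getRank() == 0) {
//     std::vector<at::Tensor> rootChunks;
//     for (int r = 0; r < pg->getSize(); ++r)
//       rootChunks.push_back(at::empty_like(chunkTemplate));  // fill with real data
//     in.push_back(std::move(rootChunks));
//   }
//   ScatterOptions sopts;
//   sopts.rootRank = 0;
//   pg->scatter(out, in, sopts)->wait();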
4519 c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
4520     std::vector<at::Tensor>& outputTensors,
4521     std::vector<std::vector<at::Tensor>>& inputTensors,
4522     const ScatterOptions& opts) {
4523   static auto invalidArgument = [](const std::string& msg) {
4524     C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::scatter: " + msg);
4525   };
4526 
4527   assertRootRank(invalidArgument, opts.rootRank, size_);
4528 
4529   TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
4530   auto outputTensor = outputTensors.back();
4531 
4532   std::vector<at::Tensor> inputs;
4533 
4534   if (getRank() == opts.rootRank) {
4535     if (inputTensors.size() != 1) {
4536       std::stringstream ss;
4537       ss << "requires a single-element input list containing a list with "
4538          << getSize() << " tensors.";
4539       invalidArgument(ss.str());
4540     } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
4541       std::stringstream ss;
4542       ss << "Incorrect input list size " << inputTensors[0].size()
4543          << ". Input list size should be " << getSize()
4544          << ", same as size of the process group.";
4545       invalidArgument(ss.str());
4546     }
4547 
4548     const auto& options = outputTensor.options();
4549     const auto& sizes = outputTensor.sizes();
4550     assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes);
4551     inputs = inputTensors[0];
4552   } else {
4553     // if not the root rank, initialize inputTensors as an empty placeholder
4554     // with an empty list
4555     if (inputTensors.size() != 0) {
4556       invalidArgument("requires empty input on non-root");
4557     }
4558     inputs = {};
4559     // append an empty tensor to the list; we don't use it, but the
4560     // `collective` template function requires it in order to invoke its functor
4561     inputs.emplace_back();
4562   }
4563 
4564   RECORD_PARAM_COMMS_DATA(
4565       static_cast<int>(
4566           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4567       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4568       inputTensors, // inputTensors
4569       outputTensors, // outputTensors
4570       opts.rootRank, // root rank
4571       "scatter", // collective name
4572       outputTensor.numel() * this->getSize(), // inNelems
4573       outputTensor.numel(), // outNelems
4574       outputTensor.scalar_type(), // dType
4575       std::vector<int64_t>(), // inSplitSizes
4576       std::vector<int64_t>(), // outSplitSizes
4577       globalRankStart, // globalRankStart
4578       globalRankStride, // globalRankStride
4579       this->getSize()); // worldSize
4580 
4581   // avoidRecordStreams_ note: collective() will stash outputTensors and
4582   // inputs, which == inputTensors[0] on the root rank where it matters.
4583   bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
4584 
4585   const auto root = opts.rootRank;
4586   bool nanCheck = (rank_ == root);
4587 
4588   auto outputs = std::vector<at::Tensor>{outputTensor};
4589   return collective(
4590       outputs,
4591       inputs, // just to fit the collective interface
4592       [&](at::Tensor& /* unused */,
4593           at::Tensor& /* unused */,
4594           ncclComm_t comm,
4595           at::cuda::CUDAStream& stream) {
4596         if (getRank() == root) {
4597           if (!avoidRecordStreams) {
4598             for (auto input : inputs) {
4599               c10::cuda::CUDACachingAllocator::recordStream(
4600                   input.storage().data_ptr(), stream);
4601             }
4602           }
4603         }
4604         torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root);
4605         return ncclSuccess;
4606       },
4607       [](at::cuda::CUDAStream&,
4608          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4609       [](at::cuda::CUDAStream&,
4610          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4611       OpType::SCATTER,
4612       "nccl:scatter",
4613       avoidRecordStreams,
4614       nanCheck);
4615 }
4616 
4617 c10::intrusive_ptr<Work> ProcessGroupNCCL::recvAnysource(
4618     std::vector<at::Tensor>& /* unused */,
4619     int /* unused */) {
4620   C10_THROW_ERROR(
4621       NotImplementedError, "ProcessGroupNCCL does not support recvAnysource");
4622 }
4623 
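// _allgather_base below mirrors _reduce_scatter_base above: every rank
// contributes input_tensor and receives the world_size-wide concatenation in
// output_tensor, so output numel must equal input numel times world size
// (enforced by the checks that follow).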
4624 c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
4625     at::Tensor& output_tensor,
4626     at::Tensor& input_tensor,
4627     const AllgatherOptions& opts) {
4628   check_gpu_single_tensor(input_tensor);
4629   check_gpu_single_tensor(output_tensor);
4630 
4631   if (input_tensor.dtype() != output_tensor.dtype()) {
4632     C10_THROW_ERROR(
4633         TypeError, "output tensor must have the same type as input tensor");
4634   }
4635 
4636   if (input_tensor.numel() * size_ != output_tensor.numel()) {
4637     C10_THROW_ERROR(
4638         ValueError,
4639         "output tensor size must be equal to world_size times input tensor size");
4640   }
4641 
4642   RECORD_PARAM_COMMS_DATA(
4643       static_cast<int>(
4644           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4645       std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4646       input_tensor, // inputTensors
4647       output_tensor, // outputTensors
4648       rank_, // rank
4649       "_allgather_base", // collective name
4650       input_tensor.numel(), // inNelems
4651       output_tensor.numel(), // outNelems
4652       output_tensor.scalar_type(), // dType
4653       std::vector<int64_t>(), // inSplitSizes
4654       std::vector<int64_t>(), // outSplitSizes
4655       globalRankStart, // globalRankStart
4656       globalRankStride, // globalRankStride
4657       this->getSize()); // worldSize
4658 
4659   // avoidRecordStreams_ note: collective() will stash inputs and outputs.
4660   // Note 2: for asyncOp = false, we don't want to record streams because we
4661   // know that the NCCL stream will join back to the "current" stream right
4662   // after this op. So we might just as well keep the stream ownership of the
4663   // input/output tensors unchanged. The benefit would be that the
4664   // allocation/free of the tensors would look deterministic to the "current"
4665   // stream so that the caching allocator can reuse memory pool for this stream
4666   // in a clever way. This setting is added for libraries like FSDP, which
4667   // use `all_gather_into_tensor`.
4668   bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
4669 
4670   return collective(
4671       input_tensor,
4672       output_tensor,
4673       [&](at::Tensor& input,
4674           at::Tensor& output,
4675           ncclComm_t comm,
4676           at::cuda::CUDAStream& stream) {
4677         if (!avoidRecordStreams) {
4678           c10::cuda::CUDACachingAllocator::recordStream(
4679               output.storage().data_ptr(), stream);
4680         }
4681         return ncclAllGather(
4682             input.data_ptr(),
4683             output.data_ptr(),
4684             input.numel(),
4685             getNcclDataType(input.scalar_type()),
4686             comm,
4687             stream.stream());
4688       },
4689       OpType::_ALLGATHER_BASE,
4690       "nccl:_all_gather_base",
4691       avoidRecordStreams);
4692 }
4693 
4694 } // namespace c10d
4695 
4696 #endif // USE_C10D_NCCL
4697