1 #ifdef USE_C10D_NCCL
2
3 #include <exception>
4 #include <fstream>
5 #include <map>
6 #include <mutex>
7 #include <sstream>
8 #include <stdexcept>
9 #include <tuple>
10 #include <unordered_set>
11 #include <utility>
12
13 #include <ATen/cuda/CUDAContext.h>
14 #include <ATen/cuda/CUDAGraph.h>
15 #include <c10/core/DeviceType.h>
16 #include <c10/cuda/CUDAAllocatorConfig.h>
17 #include <c10/cuda/CUDAGraphsC10Utils.h>
18 #include <c10/cuda/CUDAGuard.h>
19 #include <c10/util/CallOnce.h>
20 #include <c10/util/Exception.h>
21 #include <c10/util/Logging.h>
22 #include <c10/util/WaitCounter.h>
23 #include <c10/util/irange.h>
24 #include <c10/util/thread_name.h>
25 #include <torch/csrc/cuda/nccl.h>
26 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
27 #include <torch/csrc/distributed/c10d/NanCheck.hpp>
28 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
29 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
30 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
31 #include <torch/csrc/distributed/c10d/TraceUtils.h>
32 #include <torch/csrc/distributed/c10d/Utils.hpp>
33 #include <torch/csrc/distributed/c10d/logger.hpp>
34 #include <torch/torch.h>
35 #include <optional>
36
37 namespace c10d {
38
39 constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM";
40
41 namespace {
42
43 #if defined(NCCL_MAJOR) && \
44 ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
45 #define NCCL_HAS_AVG 1
46 #endif
47
48 // NCCL op mapping
49 const std::map<ReduceOp::RedOpType, ncclRedOp_t> ncclOp = {
50 {ReduceOp::MIN, ncclMin},
51 {ReduceOp::MAX, ncclMax},
52 {ReduceOp::SUM, ncclSum},
53 {ReduceOp::PRODUCT, ncclProd},
54 #ifdef NCCL_HAS_AVG
55 {ReduceOp::AVG, ncclAvg},
56 #endif
57 };
58
59 // NCCL type typing
60 std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
61 {at::kChar, ncclInt8},
62 {at::kByte, ncclUint8},
63 {at::kFloat, ncclFloat},
64 {at::kDouble, ncclDouble},
65 {at::kInt, ncclInt32},
66 {at::kLong, ncclInt64},
67 {at::kHalf, ncclHalf},
68 {at::kBool, ncclUint8},
69 {at::kFloat8_e5m2, ncclUint8},
70 {at::kFloat8_e4m3fn, ncclUint8},
71 {at::kFloat8_e4m3fnuz, ncclUint8},
72 {at::kFloat8_e5m2fnuz, ncclUint8},
73 #if HAS_NCCL_BF16_DATATYPE
74 {at::kBFloat16, ncclBfloat16},
75 #endif
76 };
77
78 // Helper function that gets the data type and issues error if not supported
ncclDataType_t getNcclDataType(at::ScalarType type) {
80 auto it = ncclDataType.find(type);
81 TORCH_CHECK_WITH(
82 TypeError,
83 it != ncclDataType.end(),
84 "Input tensor data type is not supported for NCCL process group: ",
85 type);
86 return it->second;
87 }
88
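// Checks whether a complex tensor may be viewed as a real tensor (with an
// extra trailing dimension of size 2) for the given reduction. Only
// element-wise, linearity-preserving ops (SUM, AVG, PREMUL_SUM) commute with
// that view; MIN/MAX/PRODUCT over interleaved real/imag parts would be
// meaningless.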
bool complexViewAsRealAllowed(const ReduceOp reduceOp) {
90 switch (reduceOp) {
91 case ReduceOp::SUM:
92 return true;
93 case ReduceOp::AVG:
94 return true;
95 case ReduceOp::PREMUL_SUM:
96 return true;
97 case ReduceOp::UNUSED:
98 return true;
99 default:
100 return false;
101 }
102 return false;
103 }
104
105 #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT
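// Builds a custom "pre-multiplied sum" reduction op via
// ncclRedOpCreatePreMulSum, using either a host-side scalar factor or a
// device tensor factor from the op's supplement. The result is wrapped in
// ncclRedOpRAII so the created op can be cleaned up automatically once the
// returned wrapper goes out of scope.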
106 template <typename T, ncclDataType_t dataType>
ncclRedOpRAII unpackPreMulSum(
108 const ReduceOp& reduceOp,
109 const ncclComm_t& comm) {
110 const auto* preMulSupplement =
111 reinterpret_cast<NCCLPreMulSumSupplement*>(reduceOp.supplement_.get());
112 ncclRedOp_t preMulSum;
113 bool has_tensor = preMulSupplement->tensor_factor.defined();
114 auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate;
115 const T* ptr_factor = has_tensor
116 ? preMulSupplement->tensor_factor.const_data_ptr<T>()
117 : nullptr;
118 T scalar_factor = T(preMulSupplement->double_factor);
119 ncclRedOpCreatePreMulSum(
120 &preMulSum,
121 // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/ops.html#ncclredopcreatepremulsum
122 // tells us that the scalar input is strictly a multiplier.
123 /*scalar=*/has_tensor ? const_cast<T*>(ptr_factor) : &scalar_factor,
124 dataType,
125 residence,
126 comm);
127 return ncclRedOpRAII(preMulSum, comm);
128 }
129 #endif
130
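// Maps a c10d ReduceOp to an ncclRedOp_t, handling the special cases below:
// bool SUM is lowered to MAX (bitwise-or semantics without uint8 overflow),
// AVG requires NCCL >= 2.10, and PREMUL_SUM builds a custom op when the NCCL
// version supports it.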
ncclRedOpRAII getNcclReduceOp(
132 const ReduceOp& reduceOp,
133 at::Tensor& input,
134 const ncclDataType_t& dataType,
135 const ncclComm_t& comm) {
136 try {
137 if (input.scalar_type() == at::kBool) {
138 if (reduceOp == ReduceOp::SUM) {
139 // For bool tensors, map sum to max, which both represent a bitwise or.
140 // This is to prevent overflow issues with sum, since we use uint8 to
141 // represent a bool (see ncclDataType mapping).
142 return ncclMax;
143 }
144 #ifdef NCCL_HAS_AVG
145 if (reduceOp == ReduceOp::AVG) {
146 C10_THROW_ERROR(
147 TypeError, "Cannot use ReduceOp.AVG with boolean inputs");
148 }
149 #endif
150 }
151 if (reduceOp == ReduceOp::PREMUL_SUM) {
152 #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT
153 switch (dataType) {
154 case ncclHalf:
155 return unpackPreMulSum<at::Half, ncclHalf>(reduceOp, comm);
156 case ncclFloat:
157 return unpackPreMulSum<float, ncclFloat>(reduceOp, comm);
158 case ncclDouble:
159 return unpackPreMulSum<double, ncclDouble>(reduceOp, comm);
160 default:
161 C10_THROW_ERROR(
162 TypeError, "PreMulSum Data type must be half, float, or double");
163 ncclRedOp_t unused;
164 return unused;
165 }
166 #else
167 C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1");
168 #endif
169 }
170 return ncclOp.at(reduceOp);
171 } catch (const std::out_of_range&) {
172 switch (reduceOp) {
173 case ReduceOp::AVG:
174 C10_THROW_ERROR(
175 ValueError,
176 c10::str(
177 "AVG requires NCCL 2.10+. The current version is ",
178 NCCL_MAJOR,
179 ".",
180 NCCL_MINOR));
181 break;
182 case ReduceOp::BAND:
183 C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL");
184 break;
185 case ReduceOp::BOR:
186 C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL");
187 break;
188 case ReduceOp::BXOR:
189 C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL");
190 break;
191 default:
192 C10_THROW_ERROR(ValueError, "Unhandled ReduceOp");
193 break;
194 }
195 }
196 }
197
198 // Get a key string from device
inline std::string getKeyFromDevice(at::Device& device) {
200 return std::to_string(device.index());
201 }
202
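// Parse a device index back out of a device key. Returns -1 (invalid) for
// keys that are not plain device indices, e.g. the "low:high" rank-pair keys
// used for P2P communicators.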
inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) {
204 // initialize the device index to -1, which is an invalid value.
205 int index = -1;
206 try {
207 index = std::stoi(deviceKey);
208 } catch (const std::invalid_argument& e) {
209 LOG(ERROR) << c10::str(
210 "Invalid deviceKey: ", deviceKey, ",", e.what(), ".");
211 } catch (const std::out_of_range& e) {
212 LOG(ERROR) << "Out of range: " << e.what();
213 }
214 return static_cast<at::DeviceIndex>(index);
215 }
216
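// Get a key string for a P2P communicator between two ranks. The key is
// order-independent, e.g. getKeySendRecv(3, 1) and getKeySendRecv(1, 3) both
// return "1:3".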
std::string getKeySendRecv(int myRank, int peer) {
218 int lowRank = myRank < peer ? myRank : peer;
219 int highRank = myRank < peer ? peer : myRank;
220 std::string sendRecvPair =
221 std::to_string(lowRank) + ":" + std::to_string(highRank);
222 return sendRecvPair;
223 }
224
225 // Get device from tensor
inline at::Device getDevice(at::Tensor& tensor) {
227 return tensor.device();
228 }
229
// [Sync Streams] Helper that lets the input ncclStreams wait for the current
// stream. NCCL communications run on ncclStreams, but input tensors are
// allocated on different streams (i.e., current streams). Communications on
// ncclStreams cannot start before pending input tensor ops on current streams
// finish. Otherwise, ops on the two streams might read/write the same tensors
// concurrently.
236 //
237 // The synchronization above alone is not enough. We also need to make sure
238 // input tensors are not freed before their usages on ncclStreams finish. This
239 // can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream,
240 // which remembers the usage stream (ncclStream), creates an event on the usage
241 // stream when GC attempts to free the input tensor, and delays GC until that
242 // event is done.
void syncStream(
244 at::Device& device,
245 at::cuda::CUDAEvent& ncclEvent,
246 at::cuda::CUDAStream& ncclStream) {
247 ncclEvent.record(at::cuda::getCurrentCUDAStream(device.index()));
248 ncclEvent.block(ncclStream);
249 }
250
251 // Given a ncclUniqueId, convert it to a string representation that can be put
252 // in the store.
std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) {
254 const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclID);
255 std::ostringstream oss;
256 for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) {
257 oss << std::hex << static_cast<int>(bytes[i]);
258 }
259 return oss.str();
260 }
261
std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) {
263 return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr;
264 }
265
266 // Returns exception's what() given an exception_ptr instance.
std::string getExceptionMsgFromExceptionPtr(
268 const std::exception_ptr& exceptionPtr) {
269 TORCH_CHECK(exceptionPtr != nullptr);
270 try {
271 std::rethrow_exception(exceptionPtr);
272 } catch (const std::exception& e) {
273 return e.what();
274 } catch (...) {
275 return "Unknown exception type";
276 }
277 }
278
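// Errors out if a CUDA graph capture is in progress while the NCCL version is
// older than 2.9.6, which cannot be captured safely. The version is encoded as
// (major << 32) + (minor << 16) + patch to match torch::cuda::nccl::version().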
inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
280 // parentheses avoid some compiler warnings
281 static const uint64_t min_version =
282 (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6);
283 static const uint64_t cur_version = torch::cuda::nccl::version();
284 if (cur_version < min_version) {
285 TORCH_CHECK_WITH(
286 NotImplementedError,
287 status == c10::cuda::CaptureStatus::None,
288 "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6");
289 }
290 }
291
292 } // namespace
293
294 // Map from each communicator to its device index.
295 // This map is used when register/deregister cache segments from cache
296 // allocator. See design notes below:
297 // - Each segment should be registered only to the communicator on the
298 // same device.
// - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be
//   a rank pair rather than a device index in the point-to-point case.
// - This map also has to be maintained as a global variable since the register
//   hooks are called outside the scope of any PG; thus we need to traverse
//   communicators in all PGs.
304 static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
305 static std::mutex ncclCommDevIdxMapMutex;
306 static bool allocatorHooksAttached = false;
307
308 std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
309
void cacheAllocatorRegisterHook(
311 const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
312 // Register after SEGMENT_ALLOC
313 if (te.action_ !=
314 c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) {
315 return;
316 }
317
318 std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
319 for (auto& it : ncclCommDevIdxMap) {
320 auto& ncclComm = it.first;
321 auto& devIdx = it.second;
322 if (te.device_ == devIdx) {
323 ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
324 }
325 }
326 }
327
void cacheAllocatorDeregisterHook(
329 const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
330 // deregister before SEGMENT_FREE
331 if (te.action_ !=
332 c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) {
333 return;
334 }
335
336 std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
337 for (auto& it : ncclCommDevIdxMap) {
338 auto& ncclComm = it.first;
339 auto& devIdx = it.second;
340 if (te.device_ == devIdx) {
341 ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
342 }
343 }
344 }
345
346 std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
getNCCLCommDumpMap() {
348 #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
349 std::unordered_map<
350 std::string /* ncclUniqueID */,
351 std::unordered_map<std::string, std::string> /* dump from this comm */>
352 ncclDumpMap;
353 // dump_nccl_trace is only called from the default PG (local_id_=0), but we
354 // want to dump from all comms so we need to iterate over ncclCommDevIdxMap,
355 // which is static
356 std::vector<std::shared_ptr<NCCLComm>> allNCCLComms;
  // Copy the comms under the lock, but do not dump while holding it, as the
  // dump itself might hang.
359 ncclCommDevIdxMapMutex.lock();
360 for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
361 allNCCLComms.push_back(ncclComm);
362 }
363 ncclCommDevIdxMapMutex.unlock();
364 for (auto& ncclComm : allNCCLComms) {
365 std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
366 ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
367 }
368 return ncclDumpMap;
369 #else
370 return std::unordered_map<
371 std::string,
372 std::unordered_map<std::string, std::string>>();
373 #endif
374 }
375
std::string dump_nccl_trace(
377 bool includeCollectives,
378 bool includeStackTraces,
379 bool onlyActive) {
380 auto ncclDumpMap = getNCCLCommDumpMap();
381 return NCCLTraceBuffer::get()->dump(
382 ncclDumpMap, includeCollectives, includeStackTraces, onlyActive);
383 }
384
std::string dump_nccl_trace_json(bool includeCollectives, bool onlyActive) {
386 auto ncclDumpMap = getNCCLCommDumpMap();
387 return NCCLTraceBuffer::get()->dump_json(
388 ncclDumpMap, includeCollectives, onlyActive);
389 }
390
391 std::optional<std::function<void(std::function<void(const std::string&)>)>>&
get_cpp_trace_dumper() {
393 static std::optional<
394 std::function<void(std::function<void(const std::string&)>)>>
395 dumper(std::nullopt);
396 return dumper;
397 }
398
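// Returns the process-wide GIL checker callback; it stays null until a checker
// is registered (typically by the Python bindings).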
gil_checker_t& get_gil_checker() {
400 static gil_checker_t gil_checker = nullptr;
401 return gil_checker;
402 }
403
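// Launches a detached helper thread that runs the registered GIL checker and
// reports the result (or any exception) through the returned future, so the
// caller can bound the wait with wait_for instead of blocking indefinitely.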
std::future<bool> launchAsyncGilCheck() {
405 std::promise<bool> resultPromise;
406 std::future<bool> resultFuture = resultPromise.get_future();
407 TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker");
408 std::thread workerThread([promise = std::move(resultPromise)]() mutable {
409 c10::setThreadName("pt_nccl_gil_chk");
410
411 try {
412 auto& gil_checker = get_gil_checker();
413 promise.set_value((*gil_checker)());
414 } catch (...) {
415 promise.set_exception(std::current_exception());
416 }
417 });
418
419 // Detach the thread to allow it to run independently
420 workerThread.detach();
421
422 return resultFuture;
423 }
424
425 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100;
426 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
427 thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;
428
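// Pretty-prints a WorkNCCL for logging: sequence number, op type, input/output
// numel, and the configured timeout.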
std::ostream& operator<<(
430 std::ostream& output,
431 const ProcessGroupNCCL::WorkNCCL& workNCCL) {
432 std::string workInfo;
433 workInfo = c10::str(
434 "WorkNCCL(",
435 "SeqNum=",
436 workNCCL.seq_,
437 ", OpType=",
438 opTypeToString(workNCCL.opType_),
439 ", NumelIn=",
440 workNCCL.numelIn_,
441 ", NumelOut=",
442 workNCCL.numelOut_,
443 ", Timeout(ms)=",
444 workNCCL.opTimeout_.count(),
445 ")");
446 return output << workInfo;
447 }
448
ProcessGroupNCCL::WorkNCCL::WorkNCCL(
450 const std::string& pgUID,
451 const std::string& pgDesc,
452 at::Device& device,
453 int rank,
454 OpType opType,
455 uint64_t seq,
456 const char* profilingTitle,
457 const std::optional<std::vector<at::Tensor>>& inputs,
458 bool desyncDebug,
459 bool enableTiming,
460 bool cudaEventCacheEnabled,
461 DebugLevel distDebugLevel)
462 : Work(rank, opType, profilingTitle, inputs),
463 pgUID_(pgUID),
464 pgDesc_(pgDesc),
465 device_(device),
466 workStartTime_(std::chrono::steady_clock::now()),
467 seq_(seq),
468 timingEnabled_(enableTiming),
469 distDebugLevel_(distDebugLevel) {
470 // Creates the CUDA event wrappers
471 // Note: The actual events are lazily created when first recorded to with
472 // DEFAULT_FLAGS = cudaEventDisableTiming.
473 if (cudaEventCacheEnabled) {
474 ncclStartEvent_ = enableTiming
475 ? ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming)
476 : nullptr;
477 ncclEndEvent_ =
478 ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming);
479 } else {
480 ncclStartEvent_ = enableTiming
481 ? std::make_shared<at::cuda::CUDAEvent>(cudaEventDefault)
482 : nullptr;
483 ncclEndEvent_ = std::make_shared<at::cuda::CUDAEvent>(
484 enableTiming ? cudaEventDefault : cudaEventDisableTiming);
485 }
486 }
487
ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
489 : Work(w.rank_, w.opType_),
490 std::enable_shared_from_this<WorkNCCL>(w),
491 pgUID_(w.pgUID_),
492 pgDesc_(w.pgDesc_),
493 device_(w.device_),
494 ncclStartEvent_(w.ncclStartEvent_),
495 ncclEndEvent_(w.ncclEndEvent_),
496 ncclComm_(w.ncclComm_),
497 blockingWait_(w.blockingWait_),
498 opTimeout_(w.opTimeout_),
499 ownedEphermeralTimeout_(w.ownedEphermeralTimeout_),
500 workStartTime_(w.workStartTime_),
501 seq_(w.seq_),
502 startTraceUpdated_(w.startTraceUpdated_),
503 numelIn_(w.numelIn_),
504 numelOut_(w.numelOut_),
505 store_(w.store_),
506 timingEnabled_(w.timingEnabled_),
507 trace_id_(w.trace_id_),
508 distDebugLevel_(w.distDebugLevel_) {
509 exception_ = w.exception_;
510 }
511
512 ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default;
513
bool ProcessGroupNCCL::WorkNCCL::isCompleted() {
515 if (!ncclComm_->isAborted()) {
516 checkAndSetException();
517 }
518 return exception() || finishedGPUExecutionInternal();
519 }
520
bool ProcessGroupNCCL::WorkNCCL::isStarted() {
522 if (!ncclComm_->isAborted()) {
523 checkAndSetException();
524 }
525 return exception() || startedGPUExecutionInternal();
526 }
527
bool ProcessGroupNCCL::WorkNCCL::isSuccess() const {
529 C10_THROW_ERROR(NotImplementedError, "WorkNCCL::isSuccess() is deprecated");
530 }
531
void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
533 if (exception()) {
534 // We already have an exception.
535 return;
536 }
537
538 auto exception_ptr = checkForNCCLErrors();
539 std::unique_lock<std::mutex> lock(mutex_);
540 exception_ = exception_ptr;
541 if (exception_) {
542 LOG(ERROR) << logPrefix() << "Collective " << *this
543 << " raised the following async exception: "
544 << getExceptionMsgFromExceptionPtr(exception_);
545 }
546 }
547
const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {
549 static std::string prefix = c10::str("[Rank ", rank_, "] ");
550 return prefix;
551 }
552
void ProcessGroupNCCL::WorkNCCL::setException(
554 std::exception_ptr exception_ptr) {
555 std::unique_lock<std::mutex> lock(mutex_);
556 exception_ = exception_ptr;
557 }
558
559 // Helper that checks if the NCCL kernels are completed on the GPUs
bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() {
561 checkAndSetException();
562 return finishedGPUExecutionInternal();
563 }
564
bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const {
566 // if timing is disabled we won't have allocated start events
567 if (!timingEnabled_) {
568 return false;
569 }
570 // Checking the work's corresponding CUDA event's status
571 if (!ncclStartEvent_->query()) {
572 return false;
573 }
574 return true;
575 }
576
bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
  // Checking the work's corresponding CUDA event's status.
  // It calls `cudaEventQuery` eventually. Although this seems to be a
  // non-blocking call, we did notice hangs in the past. It can hang if
  // another thread is holding the CUDA global context lock, for example
  // when doing a `cudaDeviceSynchronize` or even `cudaStreamSynchronize`.
584 if (!ncclEndEvent_->query()) {
585 return false;
586 }
587 return true;
588 }
589
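// Checks whether this work has exceeded its timeout (the optional argument, or
// the work's own opTimeout_ if none is given). Returns true on timeout; the
// first time it also records a DistBackendError exception on the work, unless
// an exception is already set.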
bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
591 std::optional<std::chrono::milliseconds> timeout) {
592 STATIC_SCOPED_WAIT_COUNTER(
593 pytorch.wait_counter.ProcessGroupNCCL__checkTimeout);
594 auto currentTimepoint = std::chrono::steady_clock::now();
595 auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
596 currentTimepoint - workStartTime_);
597 auto workTimeout = timeout ? *timeout : opTimeout_;
598
599 if (timeElapsed < workTimeout)
600 return false;
601
602 // Timed out
603
604 // There is already an error, we don't override it
605 if (exception())
606 return true;
607
608 std::string exceptionMsg = c10::str(
609 logPrefix(),
610 "Watchdog caught collective operation timeout: ",
611 *this,
612 " ran for ",
613 timeElapsed.count(),
614 " milliseconds before timing out.");
615
616 LOG(ERROR) << exceptionMsg;
617 std::exception_ptr exception_ptr =
618 std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exceptionMsg));
619 setException(exception_ptr);
620 return true;
621 }
622
void ProcessGroupNCCL::WorkNCCL::handleException(
624 ErrorHandlingMode errorHandling) {
625 if (exception_) {
626 auto exceptionMsg = c10::str(
627 "Some NCCL operations have failed or timed out. Due to the ",
628 "asynchronous nature of CUDA kernels, subsequent GPU operations ",
629 "might run on corrupted/incomplete data.");
630 LOG(ERROR) << logPrefix() << exceptionMsg;
631 C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException");
632
633 if (SHOULD_TEAR_DOWN(errorHandling)) {
634 auto tearDownMsg = c10::str(
635 "To avoid data inconsistency, we are taking the entire process down.");
636 LOG(ERROR) << logPrefix() << tearDownMsg;
637 std::rethrow_exception(exception_);
638 }
639 }
640 }
641
void ProcessGroupNCCL::WorkNCCL::synchronize() {
643 // Call Synchronize without a timeout. We use this method to avoid adding a
644 // timeout argument to the public synchronize API.
645 synchronizeInternal(kNoTimeout);
646 }
647
void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
649 auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
650 // Block the current stream on the NCCL stream
651 ncclEndEvent_->block(currentStream);
652
653 if (avoidRecordStreams_) {
654 stashed_for_allocator_safety_->clear();
655 }
656 }
657
658 // Waiting on the work's corresponding CUDA events
void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
660 std::chrono::milliseconds timeout) {
661 synchronizeStream();
662
663 // In case of blocking, wait for the operation to complete.
664 if (blockingWait_) {
665 while (!isCompleted()) {
666 bool timedOut = checkTimeout(
667 timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout));
      // Explicitly abort ncclComms here before throwing this timed-out
      // exception to users.
      // If we throw the timed-out exception without aborting the NCCL
      // communicators here, it was observed that the CUDA GPU stays at 100%
      // utilization and cannot run new events successfully.
673 if (timedOut) {
674 std::string exceptionMsg = c10::str(
675 logPrefix(),
676 "Work ",
677 (*this),
678 " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1).");
679 LOG(ERROR) << exceptionMsg;
680 break;
681 }
682 // Yield
683 std::this_thread::sleep_for(
684 std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
685 }
686 // exception() includes timeout and error during blocking wait
687 if (exception()) {
688 // Abort NCCL communicators
689 abort();
690 // Throw exception (from main thread here)
691 handleException(TearDown);
692 }
693 }
694
695 // Device synchronize only after we've completed timeout checks.
696 if (barrierTensor_.defined()) {
697 // If we use the work to do barrier, we should block here
698 // `dist.barrier()` only requires all CPU processes to enter this
699 // function, hence we only need to make sure the dummy all-reduce has
700 // completed. So we would only need to sync the **current stream** back to
701 // host, and do not need to synchronize the entire device (which may have
702 // kernels running on other streams).
703 // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can:
704 // - lower chance of hang;
705 // - CurrentCUDAStream is usually the context of the next operation in
706 // Python, thus blocking current stream would already block the next
707 // compute kernel;
708 // - achieve better barrier performance.
709 auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
710 // CUDAStream wrapper will correctly use a DeviceGuard here
711 currentStream.synchronize();
712 }
713 }
714
715 // Same as calling synchronize().
bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
717 RECORD_PARAM_COMMS(
718 static_cast<int>(this->seq_), // seq
719 std::make_tuple(pgUID_, pgDesc_), // PG name tuple
720 rank_, // rank
721 "wait", // collective name
722 0, // inNelems
723 0, // outNelems
724 at::kByte, // dType
725 std::vector<int64_t>(), // inSplitSizes
726 std::vector<int64_t>(), // outSplitSizes
727 -1,
728 -1,
729 static_cast<int>(1)); // number of device?
730 synchronizeInternal(timeout);
731 // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL
732 // upgrade. Once a NCCL version is qualified, this code should not be needed
733 // at runtime.
734 #ifdef PGNCCL_ENABLE_HASH
735 if (distDebugLevel_ >= DebugLevel::Detail) {
736 auto numel = getTensorsNumel(*outputs_);
737 auto hashValue = hashTensors(*outputs_);
738 PRINT_COLLECTIVE_HASH_SIGNATURE(
739 "output", opTypeToString(opType_), numel, hashValue);
740 }
741 #endif
742 // Always return true, because abort API is not implemented.
743 return true;
744 }
745
void ProcessGroupNCCL::WorkNCCL::abort() {
747 // Abort all communicators of this work
748 ncclComm_->ncclCommAbort();
749
750 ncclCommDevIdxMapMutex.lock();
751 ncclCommDevIdxMap.erase(ncclComm_);
752 ncclCommDevIdxMapMutex.unlock();
753 }
754
ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {}
756
// A CUDA event is used to record the start/end of one Work.
// Instead of letting the CUDA event get destroyed, we now reuse it after the
// Work has been erased from workMetaList_.
// This is to avoid a potential deadlock caused by cudaEventDestroy.
std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
762 bool timing) {
763 auto deleter = [this, timing](at::cuda::CUDAEvent* event) {
764 std::lock_guard<std::mutex> lock(this->cacheMutex_);
765 this->eventsArray_[timing ? 1 : 0].push_back(event);
766 };
767 at::cuda::CUDAEvent* event = nullptr;
768 {
769 std::lock_guard<std::mutex> lock(cacheMutex_);
    // Take a reference: copying the vector here would make pop_back() operate
    // on the copy, leaving the event in the cache and allowing it to be handed
    // out twice.
    auto& events = eventsArray_[timing ? 1 : 0];
771 if (!events.empty()) {
772 event = events.back();
773 events.pop_back();
774 }
775 }
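  // Cache miss: lazily create a new event with the requested timing flag.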
776 if (!event) {
777 event = new at::cuda::CUDAEvent(
778 timing ? cudaEventDefault : cudaEventDisableTiming);
779 }
780 return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
781 }
782
ProcessGroupNCCL::CUDAEventCache& ProcessGroupNCCL::CUDAEventCache::get() {
784 static ProcessGroupNCCL::CUDAEventCache cache;
785 return cache;
786 }
787
788 static std::atomic<size_t> process_group_id = 0;
789
790 constexpr const char* MULTI_DEVICE_ERROR_MSG =
791 "Expecting one tensor only but got multiple. You are probably using multiple "
792 "devices under one thread. The support for such usage has been deprecated. "
793 "For details, please refer to "
794 "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. "
795 "ProcessGroupNCCL continues supporting multi-process and multi-thread modes.";
796
ProcessGroupNCCL::ProcessGroupNCCL(
798 const c10::intrusive_ptr<Store>& store,
799 int rank,
800 int size,
801 c10::intrusive_ptr<Options> options)
802 : Backend(rank, size),
803 store_(store),
804 options_(options),
805 ncclCommCounter_(0),
806 traceKeyStart_(getTraceStartKey("NCCL", rank)),
807 traceKeyEnd_(getTraceEndKey("NCCL", rank)),
808 terminateProcessGroup_(false),
809 terminateHeartbeatMonitorThread_(false),
810 collectiveDebugInfoMode_(false),
811 local_id_(process_group_id++),
812 intraNodeComm_(initIntraNodeComm()) {
813 TORCH_CHECK_WITH(
814 ValueError,
815 at::cuda::getNumGPUs() != 0,
816 "ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
817
818 // getNcclVersion needs to get called before launching threads which can
819 // potentially call getenv. getNcclVersion internally calls setenv to set some
820 // environment variables from config file, which can race with getenv from
821 // other threads and cause segfaults.
822 const auto ncclVersion = getNcclVersion();
823 this->setGroupUid(options_->group_name);
824 this->localDeviceCount_ = at::cuda::getNumGPUs();
825 logPrefix_ = createLogPrefix();
826 blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
827 asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
828 getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/));
829 desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
830 (dist_debug_level_ >= DebugLevel::Detail);
831 rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true);
832 // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
833 // or change its name to reflect that dump happens on exception including
834 // both timeout and other errors.
835 dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
836 (dist_debug_level_ >= DebugLevel::Detail);
  // Logging the C++ stack isn't always safe. Introduce a variable to control it.
838 logCppStackOnUncleanShutdown_ =
839 getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true);
840 enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
841 heartbeat_ = 1ULL;
842 monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
843 cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, false));
844 heartbeatTimeoutInSec_ =
845 getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
846 waitTimeoutDumpInMilSec_ =
847 getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/);
848 coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
849 ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
850 enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
  // store_ is usually wrapped with PrefixStore and the prefix differs across
  // ProcessGroupNCCL (PG) instances. We need to get the underlying
  // non-PrefixStore to share global information across different PGs.
855 PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
856 globalStore_ =
857 prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
858 #ifdef ENABLE_NCCL_ERROR_CHECKING
859 enableTiming_.store(
860 getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
861 #endif
862 avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false);
863 #ifdef NCCL_HAS_COMM_REGISTER
864 useTensorRegisterAllocatorHook_ =
865 getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);
866 if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
867 expandable_segments()) {
868 useTensorRegisterAllocatorHook_ = false;
869 LOG(INFO)
870 << logPrefix()
871 << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
872 }
873 #endif
874
875 if (blockingWait_) {
876 if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
877 LOG(INFO)
878 << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and "
879 << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG"
          << " should not both be enabled. "
881 << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process.";
882 asyncErrorHandling_ = NoHandling;
883 desyncDebug_ = false;
884 }
885 } else {
886 if (desyncDebug_ && asyncErrorHandling_ == NoHandling) {
887 LOG(INFO)
888 << logPrefix()
889 << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING "
890 << "must both be enabled. "
891 << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING.";
892 asyncErrorHandling_ = SkipCleanUp;
893 }
894 }
895
896 #ifdef ENABLE_NCCL_ERROR_CHECKING
897 ncclCommWatchdogThread_ =
898 std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
899 #endif
900
901 init();
902 const std::string OFF = "OFF";
903 std::string torch_distributed_debug =
904 getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
905 LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: "
906 << "size: " << size << ", global rank: " << globalRank()
907 << ", TIMEOUT(ms): " << options_->timeout.count()
908 << ", USE_HIGH_PRIORITY_STREAM: "
909 << options_->is_high_priority_stream
910 << ", SPLIT_FROM: " << options_->split_from
911 << ", SPLIT_COLOR: " << options_->split_color
912 << ", PG Name: " << options_->group_name;
913
914 LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: "
915 << "NCCL version: " << ncclVersion
916 << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
917 << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_
918 << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
919 << waitTimeoutDumpInMilSec_
920 << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
921 << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
922 << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
923 << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
924 #ifdef NCCL_HAS_COMM_REGISTER
925 << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
926 << useTensorRegisterAllocatorHook_
927 #endif
928 << ", TORCH_NCCL_ENABLE_MONITORING: "
929 << monitorThreadEnabled_.load()
930 << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
931 << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
932 << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
933 << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
934 << ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_
935 << ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: "
936 << logCppStackOnUncleanShutdown_;
937
938 if (options_->global_ranks_in_group.empty()) {
939 this->globalRankStart = 0;
940 } else {
941 this->globalRankStart = options_->global_ranks_in_group[0];
942 }
943
944 if (options_->global_ranks_in_group.empty()) {
945 this->globalRankStride = 1;
946 } else if (options_->global_ranks_in_group.size() == 1) {
947 this->globalRankStride = 0;
948 } else {
949 bool ranksAreStrided = true;
950 int startRank = options_->global_ranks_in_group[0];
951 int stride =
952 options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0];
953 for (std::vector<uint64_t>::size_type i = 0;
954 i < options_->global_ranks_in_group.size();
955 i++) {
956 if (options_->global_ranks_in_group[i] != startRank + i * stride) {
957 ranksAreStrided = false;
958 break;
959 }
960 }
961
962 if (ranksAreStrided) {
963 this->globalRankStride = options_->global_ranks_in_group[1] -
964 options_->global_ranks_in_group[0];
965 } else {
966 this->globalRankStride = -1;
967 }
968 }
969
970 // Attach hooks to cache allocator to trigger the hooks whenever a traced
971 // action is called. In the following hooks, we register a newly allocated
972 // segment when SEGMENT_ALLOC action occurs, and deregister a segment when
973 // SEGMENT_FREE action occurs.
974 // We attach hooks only once at the first PG creation.
975 // Attaching hooks fails if CUDACachingAllocator is not initialized, so
976 // lazyInitCUDA is called (and is a no-op if CUDA is already initialized).
977 if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) {
978 at::globalContext().lazyInitCUDA();
979 c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
980 &cacheAllocatorRegisterHook);
981 c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
982 &cacheAllocatorDeregisterHook);
983 allocatorHooksAttached = true;
984 }
985 }
986
void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
988 const auto key = getKeyFromDevice(device);
989 LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device "
990 << device;
991 getNCCLComm(key, device, OpType::ALLREDUCE);
992 }
993
void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
995 // If our backend doesn't support splitting, this is a no-op for
996 // ranks not in the new subgroup (and ranks that would be in it will
997 // just use a new communicator rather than split).
998 #ifdef NCCL_HAS_COMM_SPLIT
999 const auto key = getKeyFromDevice(device);
1000 LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
1001 << device << ", key " << key << ", i am " << this;
1002 auto comm = getNCCLComm(key, device, OpType::ALLREDUCE);
1003 NCCLComm::split(
1004 comm.get(),
1005 NCCL_SPLIT_NOCOLOR,
1006 rank_,
1007 options_->config,
1008 options_->global_ranks_in_group);
1009 #endif
1010 }
1011
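// Sets up the optional intra-node communication backend: if it is enabled and
// rendezvous over a dedicated "IntraNodeComm" prefix store succeeds, the comm
// object is returned; otherwise nullptr is returned.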
1012 c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
    initIntraNodeComm() {
1014 using IntraNodeComm = intra_node_comm::IntraNodeComm;
1015 if (!IntraNodeComm::isEnabled()) {
1016 return nullptr;
1017 }
1018 auto prefixStore = c10::make_intrusive<PrefixStore>("IntraNodeComm", store_);
1019 auto comm = c10::make_intrusive<IntraNodeComm>(prefixStore, rank_, size_);
1020 if (comm->rendezvous()) {
1021 return comm;
1022 } else {
1023 return nullptr;
1024 }
1025 }
1026
void ProcessGroupNCCL::setSequenceNumberForGroup() {
1028 } // NCCL just starts sequence numbers at 0.
1029
uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() {
1031 return seqCollective_;
1032 }
1033
void ProcessGroupNCCL::registerOnCompletionHook(
1035 std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
1036 TORCH_CHECK_WITH(
1037 DistBackendError,
1038 onCompletionHook_ == nullptr,
1039 "ProcessGroupNCCL OnCompletion hook already registered");
1040
1041 TORCH_CHECK_WITH(
1042 ValueError,
1043 enableTiming_.load(),
1044 "ProcessGroupNCCL OnCompletion hook requires recording start and end "
1045 "events which require setting TORCH_NCCL_ENABLE_TIMING environment variable. "
1046 "This is only available for NCCL version >= 2.4.");
1047 onCompletionHook_ = std::move(hook);
1048 onCompletionHookThread_ = std::thread(&ProcessGroupNCCL::runHookLoop, this);
1049 }
1050
1051 // must release GIL when calling this method
void ProcessGroupNCCL::waitForPendingWorks() {
1053 // Reasoning about hook completion:
  // 1. waitForPendingWorks should be called after user code has finished
  //    calling all collectives. This means that, when we get here, all of the
  //    collectives are either in workMetaList_ or have been erased from
  //    workMetaList_.
  // 2. The watchdog thread grabs both locks to move a Work object from
  //    workMetaList_ to completedWorkList_, and the hook thread only erases
  //    a Work object after the hook has returned. Therefore, after user code
  //    calls a collective, its Work object is either in workMetaList_ or in
  //    completedWorkList_ before it finishes.
1063 // 3. We have three threads and two locks.
1064 // a. main thread (this function) grabs two locks atomically
1065 // b. watchdog thread (watchdogHandler function) always grabs
1066 // workMetaListMutex_
1067 // first and then grabs completedWorkListMutex_.
1068 // c. hook thread (runHookLoop function) only grabs
1069 // completedWorkListMutex_. Therefore, locks are always acquired in the
1070 // same order and hence no deadlocks.
1071 while (true) {
1072 {
1073 std::lock(workMetaListMutex_, completedWorkListMutex_);
1074 std::lock_guard<std::mutex> lockWork(workMetaListMutex_, std::adopt_lock);
1075 std::lock_guard<std::mutex> lockHook(
1076 completedWorkListMutex_, std::adopt_lock);
1077
1078 if (workMetaList_.empty() && completedWorkList_.empty()) {
1079 return;
1080 }
1081 }
1082
1083 std::this_thread::sleep_for(
1084 std::chrono::milliseconds(kWatchdogThreadSleepMillis));
1085 }
1086 }
1087
void ProcessGroupNCCL::enableCollectivesTiming() {
1089 enableTiming_.store(true);
1090 }
1091
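// Waits on `fut` for up to `timeOutMilSec`. Optionally logs the outcome
// (success, exception, or timeout) to the c10d logger, and rethrows failures
// as DistBackendError when `throwException` is set.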
void ProcessGroupNCCL::waitForFutureOrTimeout(
1093 std::future<bool>& fut,
1094 const std::chrono::milliseconds& timeOutMilSec,
1095 const std::string& futDescription,
1096 bool throwException,
1097 bool log) {
1098 std::string errorMsg;
1099
1100 ::c10d::C10dLoggingData data;
1101 if (log) {
1102 data.integers["pg_id"] = local_id_;
1103 data.integers["rank"] = rank_;
1104 data.integers["global_rank"] = globalRank();
1105 data.strings["flight_recorder_version"] = c10d::version_val_str;
1106 }
1107
1108 TORCH_CHECK(fut.valid(), "Expected a valid future");
1109 std::future_status status = fut.wait_for(timeOutMilSec);
1110 if (status == std::future_status::ready) {
1111 // Calling .get() will re-raise any exception from the future, and we don't
1112 // care about the retval
1113 try {
1114 bool result = fut.get();
1115 if (result) {
1116 LOG(INFO) << logPrefix()
1117 << "future is successfully executed for: " << futDescription;
1118 if (log) {
1119 data.strings["status"] = "SUCCESS";
1120 }
1121 }
1122 } catch (const std::exception& e) {
1123 errorMsg = c10::str(
1124 logPrefix(),
1125 "Exception thrown when waiting for future ",
1126 futDescription,
1127 ": ",
1128 e.what());
1129 if (log) {
1130 data.strings["status"] = "EXCEPTION";
1131 data.strings["exception"] = e.what();
1132 }
1133 LOG(ERROR) << errorMsg;
1134 } catch (...) {
1135 errorMsg = c10::str(
1136 logPrefix(),
1137 "Unknown exception thrown when waiting for future ",
1138 futDescription);
1139 if (log) {
1140 data.strings["status"] = "EXCEPTION";
1141 data.strings["exception"] = "Unknown exception";
1142 }
1143 LOG(ERROR) << errorMsg;
1144 }
1145 } else {
1146 errorMsg = c10::str(
1147 logPrefix(),
1148 "Future for ",
1149 futDescription,
1150 " timed out after ",
1151 timeOutMilSec.count(),
1152 " ms");
1153 data.strings["status"] = "TIMEOUT";
1154 LOG(ERROR) << errorMsg;
1155 }
1156 if (log) {
1157 auto logger = c10d::C10dLogger::getLogger();
1158 if (logger) {
1159 logger->log(data);
1160 }
1161 }
1162 if (throwException && !errorMsg.empty()) {
1163 C10_THROW_ERROR(DistBackendError, errorMsg);
1164 }
1165 }
1166
void ProcessGroupNCCL::abortCommsFromMap(
1168 std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
1169 std::optional<std::string> abortReason) {
  // The process may control multiple devices; loop through the communicators
  // on each device.
1172 for (auto& it : ncclCommsMap) {
1173 auto& devName = it.first;
1174 auto& ncclComm = it.second;
1175 at::cuda::OptionalCUDAGuard gpuGuard;
1176 at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName);
1177 if (deviceIndex >= 0) {
      // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in
      // the map may not be device indices but rank-to-rank pairs. So we
      // indeed need to check that deviceIndex >= 0.
      // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey`
1182 gpuGuard.set_index(deviceIndex);
1183 }
1184 LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
1185 << ncclComm->ncclComm_ << " on CUDA device: " << devName;
1186 ncclComm->ncclCommAbort(abortReason);
1187 // Note that we don't remove the aborted communicators from the
1188 // cache. The reason is that if we do remove the communicator
1189 // from the cache, it is possible that a new collective operation
1190 // calls `ncclCommInitRank` to create a new communicator whereas
1191 // other ranks might have failed/timed out and didn't enter
1192 // `ncclCommInitRank`. As a result, when there is a failure on
    // a communicator, the application receives an exception and it is
    // their responsibility to destroy the process group and recreate
    // it to recover from errors.
1196
    LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed "
              << "communicator on CUDA device: " << devName;
1199 }
1200 }
1201
1202 // Abort all communicators on this rank
bool ProcessGroupNCCL::abort(std::optional<std::string> abortReason) {
  // Remove this PG's records from the global ncclCommDevIdxMap before aborting,
  // so that a new cache segment will not be registered to already aborted
  // communicators. Note that ncclCommDevIdxMap is a global container which may
  // contain other PGs' communicators, thus we only erase the communicators
  // for the current PG.
1209 ncclCommDevIdxMapMutex.lock();
1210 for (auto& it : devNCCLCommMap_) {
1211 auto& ncclComm = it.second;
1212 ncclCommDevIdxMap.erase(ncclComm);
1213 }
1214 ncclCommDevIdxMapMutex.unlock();
1215
1216 std::lock_guard<std::mutex> lock(mutex_);
1217 abortCommsFromMap(devNCCLCommMap_, abortReason);
1218 abortCommsFromMap(inInitializationCommMap_, abortReason);
1219 return true;
1220 }
1221
void ProcessGroupNCCL::shutdown(std::optional<std::string> reason) {
1223 // Don't join threads here since the purpose of this method is to abort all
1224 // communicators and signal the threads to exit. Joining on the threads could
1225 // potentially block and hence avoid it in this method.
1226 terminateProcessGroup_.store(true);
1227 workMetaListCV_.notify_one();
1228
  // Launch abort asynchronously and wait for it to complete or time out.
  LOG(INFO) << logPrefix()
            << "Launching ProcessGroupNCCL abort asynchronously.";
1232 std::future<bool> fut = std::async(
1233 std::launch::async, [this, &reason]() { return this->abort(reason); });
1234
1235 waitForFutureOrTimeout(
1236 fut, options_->timeout, "ProcessGroup abort", true, false);
1237 LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully.";
1238
1239 // We need to wait for abort to finish before we can safely shut down
1240 // heartbeat monitoring thread.
1241 terminateHeartbeatMonitorThread_.store(true);
1242 monitorWakeUpCV_.notify_one();
1243 }
1244
ProcessGroupNCCL::~ProcessGroupNCCL() {
1246 LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
1247
1248 if (!terminateProcessGroup_.load()) {
1249 if (rank_ % localDeviceCount_ == 0) {
1250 TORCH_WARN_ONCE(
1251 "WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. ",
1252 "On normal program exit, the application should call destroy_process_group to ",
1253 "ensure that any pending NCCL operations have finished in this process. "
1254 "In rare cases this process can exit before this point and block the progress of "
1255 "another member of the process group. This constraint has always been present, "
        "but this warning has only been added since PyTorch 2.4");
1257 }
    // If the user hasn't explicitly destroyed/shut down the process group,
    // the destructor needs to do so.
1260 shutdown();
1261 }
1262
1263 // Wait for all threads to finish before returning
1264 #ifdef ENABLE_NCCL_ERROR_CHECKING
1265 if (ncclCommWatchdogThread_.joinable()) {
1266 ncclCommWatchdogThread_.join();
1267 LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
1268 }
1269 if (ncclHeartbeatMonitorThread_.joinable()) {
1270 ncclHeartbeatMonitorThread_.join();
1271 LOG(INFO) << logPrefix()
1272 << "ProcessGroupNCCL heart beat monitor thread joined.";
1273 }
1274 #endif
1275 if (onCompletionHookThread_.joinable()) {
1276 onCompletionHookThread_.join();
1277 LOG(INFO) << logPrefix()
1278 << "ProcessGroupNCCL onCompletionHookThread thread joined.";
1279 }
1280 }
1281
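// Dumps the NCCL flight-recorder trace through the registered DebugInfoWriter
// (local disk by default). Returns true iff the trace buffer is enabled and a
// dump was written. Calls are serialized so concurrent dumps don't corrupt the
// output.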
bool ProcessGroupNCCL::dumpDebuggingInfo() {
1283 // Serialize all calls to this function to avoid corrupting data, but allow
1284 // multiple calls in one runtime. User is responsible for preserving the
1285 // output file from an earlier call before a later call overwrites it.
1286 static std::mutex writeDebugInfoMutex;
1287 std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
1288 LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
1289 if (ncclTraceBufferSize_ > 0) {
1290 // We dump nccl trace into local disk by default and users can register
1291 // their customized writer by inheriting `DebugInfoWriter` via
1292 // `registerDebugInfoWriter`.
1293 auto ncclTrace = dump_nccl_trace(true, true, false);
1294 DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank());
1295 LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to "
1296 << writer.getWriterTarget();
1297 writer.write(ncclTrace);
1298 return true;
1299 }
1300 return false;
1301 }
1302
void ProcessGroupNCCL::terminateProcess(std::string errMsg) {
  // Logging with `FATAL` will, after errMsg is printed, call `std::abort()`
  // to terminate program execution.
1306 LOG(FATAL) << logPrefix() << errMsg;
1307 }
1308
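// Returns the elapsed time in milliseconds between two steady_clock time
// points, e.g. computeDeltaMS(lastTimePollStore, now).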
int computeDeltaMS(
1310 std::chrono::time_point<std::chrono::steady_clock> start,
1311 std::chrono::time_point<std::chrono::steady_clock> end) {
1312 return std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
1313 .count();
1314 }
1315
std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg(
1317 const std::string& extraMsg) {
1318 return c10::str(
1319 logPrefix(),
1320 "Received a dump signal due to a collective timeout from ",
1321 extraMsg,
1322 " and we will try our best to dump the debug info. ",
1323 "Last enqueued NCCL work: ",
1324 pgStatus_->lastEnqueuedSeq,
1325 ", last completed NCCL work: ",
1326 pgStatus_->lastCompletedSeq,
      ". ",
      "This is most likely caused by incorrect usages of collectives, e.g., wrong ",
      "sizes used across ranks, the order of collectives is not the same for all ranks, ",
1330 "or the scheduled collective, for some reason, didn't run. Additionally, ",
1331 "this can be caused by GIL deadlock or other reasons such as network errors or ",
1332 "bugs in the communications library (e.g. NCCL), etc. ");
1333 }
1334
std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg(
1336 const std::string& exitReason) {
1337 return c10::str(
1338 logPrefix(),
1339 "Terminating the process after attempting to dump debug info, due to ",
1340 exitReason,
1341 ".");
1342 }
1343
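// Monitor thread body: watches the watchdog heartbeat and, for the default PG
// (local_id_ == 0), a TCPStore dump flag and a local dump pipe. On a detected
// hang or collective timeout it triggers a flight-recorder dump and may
// terminate the process.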
void ProcessGroupNCCL::heartbeatMonitor() {
1345 c10::setThreadName("pt_nccl_heartbt");
1346
1347 uint64_t heartBeatCounter = 0ULL;
1348 std::string errorMsg;
1349 std::string exitReason;
1350 bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0);
1351 int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_
1352 : heartbeatTimeoutInSec_ * 1000;
1353 auto lastTimePollStore = std::chrono::steady_clock::now();
1354 auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
1355 std::optional<DumpPipe> dumpPipe = std::nullopt;
1356 if (local_id_ == 0) {
    // There is one DumpPipe per trainer process, and it's convenient to name
    // them after 'global' ranks in the system, so we assume process group
    // (uid)==0 is the global PG and has globally unique rank ids across trainers.
1360 dumpPipe.emplace(rank_);
1361 }
1362 while (true) {
    // This won't contend on any lock since monitorMutex_ is only used here.
    // Please be aware that mutex `monitorMutex_` should not be used
    // anywhere else, to avoid deadlock.
1366 std::unique_lock<std::mutex> lock(monitorMutex_);
1367 if (monitorWakeUpCV_.wait_for(
1368 lock, std::chrono::milliseconds(monitorPollInterval), [&] {
1369 return terminateHeartbeatMonitorThread_.load();
1370 })) {
1371 // For the normal complete or user interception, monitorWakeUpCV_
1372 // will get notified, we early return and exit heartbeatMonitor.
1373 return;
1374 }
1375 auto currentTime = std::chrono::steady_clock::now();
1376
1377 // We put extra functionality in the thread for the default PG (aka,
1378 // local_id_=0) because the signal is same across different PGs. We only
1379 // need to run once per process to avoid duplicate things performed in too
1380 // many separate threads. For example, we check a global flag on the
1381 // TCPStore periodically to see if any PG on any rank observed a timeout and
1382 // signaled peers to dump debugging info, and we avoid hammering the
1383 // TCPStore from all PGs on the same rank.
1384 if (checkDumpSignal) {
1385 // There are two scenarios where monitor thread will dump on timeout:
1386 // 1. The current rank is the first to observe a timeout in watchdog.
1387 // (shouldDump_ was set to true by the watchdog thread).
      // 2. Other ranks detected the timeout and signaled the current rank to
      // dump. In addition, the monitor thread will dump if the watchdog thread
      // has no heartbeat or the dumpPipe is not empty.
1391 if (shouldDump_.load()) {
1392 errorMsg = getNCCLWatchdogTimeoutErrorMsg("this local rank");
1393 exitReason = "collective timeout or exception";
1394 break;
1395 }
      // We poll the store to see if some ranks have flagged a timeout when
      // we haven't polled for `heartbeat_timeout` seconds and there hasn't
      // been any work added or removed for `watchdog_timeout` seconds.
1399 if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >=
1400 kWatchdogThreadSleepMillis &&
1401 computeDeltaMS(lastTimePollStore, currentTime) >=
1402 coordCheckIntervalMilSec_) {
1403 lastTimePollStore = currentTime;
1404 // Wrap globalStore_->check() in a try-catch block to avoid crashing if
1405 // the store is not available.
1406 bool checkExceptionDump = false;
1407 try {
1408 checkExceptionDump =
1409 globalStore_->check({std::string(EXCEPTION_DUMP)});
1410 } catch (const std::exception& e) {
1411 LOG(WARNING)
1412 << logPrefix()
1413 << "Failed to check the \"should dump\" flag on TCPStore, "
1414 << "(maybe TCPStore server has shut down too early), with error: "
1415 << e.what();
1416 // We give up for now assuming TCPStore has been torn down.
1417 return;
1418 }
1419
1420 if (checkExceptionDump) {
1421 int timeOutRank = -1;
1422 if (!shouldDump_.load()) {
1423 LOG(ERROR)
1424 << logPrefix()
1425 << "Observed flight recorder dump signal from another rank via TCPStore.";
1426 }
1427 shouldDump_.store(true);
1428 try {
1429 auto vec = globalStore_->get(std::string(EXCEPTION_DUMP));
1430 TORCH_CHECK_WITH(
1431 DistBackendError,
1432 vec.size() == sizeof(int),
1433 "Invalid size for the timeout rank ID");
1434 std::memcpy(&timeOutRank, vec.data(), vec.size());
1435 } catch (const std::exception& e) {
1436 LOG(ERROR) << logPrefix()
1437 << "Failed to get timeout rank ID from TCPStore."
1438 << e.what();
1439 }
1440 errorMsg =
1441 getNCCLWatchdogTimeoutErrorMsg(c10::str(" rank ", timeOutRank));
1442 exitReason = "collective timeout or exception";
1443 break;
1444 }
1445 }
1446 }
1447
1448 if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >=
1449 heartbeatTimeoutInSec_ * 1000) {
1450 // Check the heart beat of watchdog thread.
1451 lastTimeHeartBeatCheck = currentTime;
1452 auto heartbeat = heartbeat_.load();
1453 if (heartbeat != heartBeatCounter) {
1454 heartBeatCounter = heartbeat;
1455 } else {
1456 shouldDump_.store(true);
1457 // Watchdog heartbeat timeout.
1458 errorMsg = c10::str(
1459 logPrefix(),
1460 "ProcessGroupNCCL's watchdog got stuck for ",
1461 heartbeatTimeoutInSec_,
1462 " seconds without making progress in monitoring enqueued collectives. ",
1463 "This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, ",
1464 "and could be triggered by another thread holding the GIL inside a ",
1465             "CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors. ",
1466             "If you suspect the watchdog is not actually stuck and a longer timeout would help, ",
1467             "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value "
1468             "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0). "
1469             "If either of the aforementioned helps, feel free to file an issue to PyTorch about the short timeout "
1470             "or false positive abort; otherwise, please attempt to debug the hang.");
1471 exitReason = "ProcessGroupNCCL watchdog hang";
1472 break;
1473 }
1474 }
1475     // Process a request to dump the trace. Only PG uid 0 will respond to dump
1476     // requests, but this is fine since all PGs feed into the same flight
1477     // recorder and dump. After the dump, training should continue.
1478 if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
1479 // best effort dump, not waiting for the dump here
1480 std::future<bool> fut = std::async(
1481 std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
1482 }
1483 }
1484 LOG(ERROR) << errorMsg;
1485
1486 auto& cpp_dumper = get_cpp_trace_dumper();
1487 if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) {
1488 LOG(INFO) << "Dumping c++ stacktraces:";
1489 cpp_dumper.value()([](const std::string& line) { LOG(INFO) << line; });
1490 }
1491
1492 if (checkDumpSignal && shouldDump_.load()) {
1493 // Store debug info to storage if no other thread does it. (By default to
1494 // local disk)
1495 std::future<bool> asyncDebugDump = std::async(
1496 std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
1497
1498 // wait for the dump until timeout - log data
1499 waitForFutureOrTimeout(
1500 asyncDebugDump,
1501 std::chrono::milliseconds(waitTimeoutDumpInMilSec_),
1502 "Flight recorder dump in heartbeatMonitor",
1503 false,
1504 true);
1505 }
1506
1507 if (get_gil_checker() != nullptr) {
1508 auto fut = launchAsyncGilCheck();
1509 auto kGilCheckTimeout = std::chrono::milliseconds(300);
1510 auto futStatus = fut.wait_for(kGilCheckTimeout);
1511 if (futStatus != std::future_status::ready) {
1512 TORCH_CHECK(
1513 futStatus != std::future_status::deferred,
1514 "Expected the future to have been launched eagerly.");
1515 LOG(ERROR)
1516 << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang";
1517 }
1518 } else {
1519 LOG(INFO)
1520 << "GIL checker was not registered, perhaps this is a no-python build?";
1521 }
1522
1523 // There are two possible cases for the watchdog thread exit:
1524 // Case one: desync report runs quickly, and it follows the step:
1525 // collective timeout -> desync -> exception handling -> destructors
1526 // -> set terminateHeartbeatMonitorThread_ -> notify monitorWakeUpCV_.
1527 // So the code either early returns above or will skip the sleep below.
1528   // Case two: the desync report might be slow or get stuck, or we get stuck
1529   // in destructors; in that case we sleep for some time before calling
1530   // std::abort() to kill the whole process.
1531 if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() ||
1532 shouldDump_.load()) &&
1533 !terminateHeartbeatMonitorThread_.load()) {
1534     // Leave another heartbeat-timeout window for desync report generation or
1535     // process group destroy.
1536 std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_));
1537     LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_
1538               << " seconds waiting for desync report or process group destroy.";
1539 }
1540
1541   // At this point, we either have already slept for another
1542   // `heartbeatTimeoutInSec_` or the thread has finished. Because we don't want
1543   // to block the monitor thread, we mark the dump thread as detached and the
1544   // dump of debug info becomes "best effort". If the process exits normally,
1545   // marking it detached also makes sense because we no longer need the dump.
1546
1547 // We already log completion inside the thread, so it may not be necessary to
1548 // check the return value here. We mainly use a future so we can exit early
1549 // if done.
1550
1551 if (!terminateHeartbeatMonitorThread_.load()) {
1552     // Create an error message reported from the monitor thread, then throw an
1553     // exception so that the whole process gets killed.
1554 // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki
1555 // url here.
1556 if (monitorThreadEnabled_.load()) {
1557 terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason));
1558 } else {
1559 // Ideally we want to merge this one with the above one, but we are going
1560 // to remove the kill switch for monitor thread soon, so we keep this one
1561 // for now.
1562 LOG(ERROR)
1563 << logPrefix()
1564           << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process "
1565 << "after attempting to dump debug info, due to " << exitReason
1566 << ".";
1567 }
1568 }
1569 }
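
// A minimal sketch of the heartbeat handshake implemented above, using
// illustrative names (`progressCounter`, `monitorLoop`) that are not part of
// this file; the real logic lives in watchdogHandler() (producer) and
// heartbeatMonitor() (consumer):
//
//   std::atomic<uint64_t> progressCounter{0};
//
//   // watchdog side: bump the counter on every loop iteration
//   void workerLoop() {
//     while (!done) {
//       processEnqueuedWork();
//       progressCounter++;
//     }
//   }
//
//   // monitor side: if the counter did not move across a full timeout
//   // window, assume the watchdog is stuck and start dumping diagnostics.
//   void monitorLoop(std::chrono::seconds timeout) {
//     uint64_t last = progressCounter.load();
//     while (!terminated) {
//       std::this_thread::sleep_for(timeout);
//       uint64_t now = progressCounter.load();
//       if (now == last) {
//         // dump debug info, then abort the process
//       }
//       last = now;
//     }
//   }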
1570
1571 void ProcessGroupNCCL::ncclCommWatchdog() {
1572 c10::setThreadName("pt_nccl_watchdg");
1573
1574 try {
1575 VLOG(2) << logPrefix() << "Process group watchdog thread started!";
1576 ncclHeartbeatMonitorThread_ =
1577 std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
1578 watchdogHandler();
1579 VLOG(2) << logPrefix()
1580 << "Process group watchdog thread terminated normally";
1581 } catch (std::exception& e) {
1582 if (std::string(e.what()).find("driver shutting down") !=
1583 std::string::npos) {
1584 LOG(INFO)
1585 << logPrefix()
1586 << "main process destroyed cuda before watchdog loop exited, terminating watchdog."
1587           << " (Watchdog caught exception: " << e.what() << ")";
1588
1589 } else {
1590 // Append error message reported from watchdogHandler
1591 const auto exitMsg = c10::str(
1592 logPrefix(),
1593 "Process group watchdog thread terminated with exception: ",
1594 e.what());
1595 LOG(ERROR) << exitMsg;
1596       if (C10_LIKELY(rethrowCUDAErrors_) ||
1597           std::string(e.what()).find("CUDA Error") == std::string::npos) {
1598 // TODO(whc) clean up the rethrow - why is it stored in a class var and
1599 // rethrown?
1600 watchDogException_ =
1601 std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
1602 std::rethrow_exception(watchDogException_);
1603 }
1604 }
1605 } catch (...) {
1606 const auto exitMsg = c10::str(
1607 logPrefix(),
1608 "Process group watchdog thread terminated with exception: unknown");
1609 LOG(ERROR) << exitMsg;
1610 watchDogException_ =
1611 std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
1612 std::rethrow_exception(watchDogException_);
1613 }
1614 }
1615
1616 void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
1617 if (work.startTraceUpdated_)
1618 return;
1619
1620 if (terminateProcessGroup_.load() || storeError_)
1621 return;
1622
1623 work.startTraceUpdated_ = true;
1624 storeError_ = !c10d::traceUpdate(
1625 store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
1626 }
1627
1628 void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
1629 if (terminateProcessGroup_.load() || storeError_)
1630 return;
1631
1632 // In case the start of the work hasn't been logged
1633 if (!work.startTraceUpdated_) {
1634 logWorkStart(work);
1635 }
1636
1637 storeError_ = !c10d::traceUpdate(
1638 store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
1639 }
1640
1641 std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() {
1642 return retrieveDesyncReport(store_, "NCCL", rank_, size_);
1643 }
1644
1645 // We want to have both PG ID and global unique ID (guid) for the logging
1646 // prefix. PG ID records how many ProcessGroupNCCL objects were created on a
1647 // specific rank and is a stable index across ranks, which lets users reason
1648 // about, for example, the second PG we initialized on this rank is for FSDP,
1649 // and corresponds with PG ID = 1 on other ranks as well. Unlike PG ID, guid (or
1650 // group name) is a global unique ID across ranks. The guid is either a hash of
1651 // all the ranks in the group or a counter of how many times
1652 // `_process_group_name` is called, essentially how many PGs users have
1653 // created so far. Before using split_group, even if
1654 // we are creating a new sub-PG, all ranks have to call the API at the same
1655 // time, and this makes `group_name` a unique identifier for a group (PG).
1656 std::string ProcessGroupNCCL::createLogPrefix() const {
1657 if (!pg_desc_.empty() && pg_desc_ != "undefined") {
1658 return c10::str(
1659 "[PG ID ",
1660 local_id_,
1661 " PG GUID ",
1662 pg_uid_,
1663 "(",
1664 pg_desc_,
1665 ") Rank ",
1666 rank_,
1667 "] ");
1668 }
1669 return c10::str(
1670 "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] ");
1671 }
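
// For illustration (values hypothetical): a PG with local_id_ = 1,
// pg_uid_ = "1", pg_desc_ = "fsdp" on rank 3 produces the prefix
// "[PG ID 1 PG GUID 1(fsdp) Rank 3] ", while an empty or undefined description
// falls back to "[PG ID 1 PG GUID 1 Rank 3] ".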
1672
1673 const std::string& ProcessGroupNCCL::logPrefix() const {
1674 return logPrefix_;
1675 }
1676
1677 const int& ProcessGroupNCCL::globalRank() const {
1678 static int globalRank = rank_;
1679 return globalRank;
1680 }
1681
1682 const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
1683 if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
1684 static std::vector<uint64_t> globalRanks(size_);
1685 std::iota(globalRanks.begin(), globalRanks.end(), 0);
1686 return globalRanks;
1687 }
1688 return options_->global_ranks_in_group;
1689 }
1690
1691 void ProcessGroupNCCL::addEphemeralTimeout(
1692 const std::chrono::milliseconds& timeout) {
1693 std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
1694 ephemeralTimeoutActive_ += timeout;
1695 }
1696
1697 bool ProcessGroupNCCL::verifyWorkTimeoutForTest(
1698 const c10::intrusive_ptr<Work> work,
1699 const std::chrono::milliseconds& timeout) {
1700 // Since collective returns a c10d::Work, we need to cast it to WorkNCCL.
1701 if (auto workNCCL = c10::dynamic_intrusive_pointer_cast<WorkNCCL>(work)) {
1702 // workNCCL is now a c10::intrusive_ptr<WorkNCCL>
1703 return workNCCL->opTimeout_ == timeout;
1704 }
1705 C10_THROW_ERROR(
1706 DistBackendError, "Non c10d::WorkNCCL object returned from collective");
1707 }
1708
1709 void ProcessGroupNCCL::watchdogHandler() {
1710 bool done = false;
1711 lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
1712 auto lastStatusUpdateTime = std::chrono::steady_clock::now();
1713 std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
1714
1715 while (!done || !terminateProcessGroup_.load()) {
1716 std::unique_lock<std::mutex> lock(workMetaListMutex_);
1717     // We poll the work list every kWatchdogThreadSleepMillis milliseconds,
1718     // waking up early once the terminate flag is set.
1719 workMetaListCV_.wait_for(
1720 lock,
1721 std::chrono::milliseconds(kWatchdogThreadSleepMillis),
1722 [&]() -> bool { return terminateProcessGroup_.load(); });
1723 // Bump up heart beat by one.
1724 heartbeat_++;
1725
1726 // Some versions of GLOG support less-spammy version of LOG_EVERY_MS
1727 // in which case we don't want to spam the logs.
1728 #ifdef LOG_EVERY_MS
1729 // Log the progress of this PG periodically
1730 C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str(
1731 logPrefix(),
1732 "NCCL Work update periodically: ",
1733 "last enqueued NCCL work: ",
1734 pgStatus_->lastEnqueuedSeq,
1735 ", last completed NCCL work: ",
1736 pgStatus_->lastCompletedSeq,
1737 ".");
1738 #endif
1739 auto logger = ::c10d::C10dLogger::getLogger();
1740 if (logger &&
1741 computeDeltaMS(
1742 lastStatusUpdateTime, std::chrono::steady_clock::now()) >=
1743 kWorkStatusUpdatePeriodMs) {
1744 ::c10d::C10dLoggingData data;
1745 // logging integers
1746 data.integers["pg_id"] = local_id_;
1747 data.integers["rank"] = rank_;
1748 data.integers["global_rank"] = globalRank();
1749 data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq;
1750 data.integers["last_started_work"] = pgStatus_->lastStartedSeq;
1751 data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq;
1752 data.integers["last_enqueued_numel_in"] = pgStatus_->lastEnqueuedNumelIn;
1753 data.integers["last_enqueued_numel_out"] =
1754 pgStatus_->lastEnqueuedNumelOut;
1755 data.integers["last_completed_numel_in"] =
1756 pgStatus_->lastCompletedNumelIn;
1757 data.integers["last_completed_numel_out"] =
1758 pgStatus_->lastCompletedNumelOut;
1759 // logging strings
1760 data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName;
1761 data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName;
1762 data.strings["last_completed_work_name"] =
1763 pgStatus_->lastCompletedWorkName;
1764 data.strings["pg_name"] = pg_uid_;
1765 data.strings["pg_desc"] = pg_desc_;
1766 logger->log(data);
1767 lastStatusUpdateTime = std::chrono::steady_clock::now();
1768 }
1769
1770 for (auto it = workMetaList_.begin(); it != workMetaList_.end();
1771 /* no increment */) {
1772 auto& work = *it;
1773       // When terminateProcessGroup_ is true, communicators have already been
1774       // aborted, so we cannot check for exceptions based on them. But the
1775       // watchdog still needs to finish checking the works that have already
1776       // been enqueued to workMetaList_.
1777 if (!terminateProcessGroup_.load()) {
1778 work.checkAndSetException();
1779 }
1780 bool timedOut = work.checkTimeout();
1781
1782 // If work hits an exception (either an error or timeout)
1783 if (work.exception()) {
1784 // log as soon as exception is detected
1785 LOG(ERROR) << c10::str(
1786 logPrefix(),
1787 "Exception (either an error or timeout) detected by watchdog at work: ",
1788 work.seq_,
1789 ", last enqueued NCCL work: ",
1790 pgStatus_->lastEnqueuedSeq,
1791 ", last completed NCCL work: ",
1792 pgStatus_->lastCompletedSeq,
1793 ".");
1794 // try to notify other ranks via global TCPStore to dump the flight
1795 // recorder when a collective timeout or exception happens. Flight
1796 // recorder behavior is independent of desync Debug.
1797 if (dumpOnTimeoutOrEx_) {
1798 try {
1799 auto rank = globalRank();
1800 auto vec = std::vector<uint8_t>(
1801 reinterpret_cast<uint8_t*>(&rank),
1802 reinterpret_cast<uint8_t*>(&rank) + sizeof(rank));
1803 globalStore_->set(std::string(EXCEPTION_DUMP), vec);
1804 if (!shouldDump_.load()) {
1805 LOG(ERROR)
1806 << logPrefix()
1807 << "Broadcasting flight-recorder dump signal to other processes via TCPStore.";
1808 }
1809 // signal the monitor thread on PG0 to start dumping
1810 shouldDump_.store(true);
1811 // This sleep is used to give time for dumping before throwing
1812 // exception
1813 std::this_thread::sleep_for(
1814 std::chrono::seconds(heartbeatTimeoutInSec_));
1815 LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_
1816                       << " seconds giving time for flight recorder dumps to finish.";
1817 } catch (const std::exception& e) {
1818 LOG(ERROR) << logPrefix()
1819 << "Failed to set dump signal in tcpstore. "
1820 << "Error: " << e.what();
1821 }
1822 }
1823
1824 if (SHOULD_CLEAN_UP(asyncErrorHandling_)) {
1825 // Abort work and corresponding communicators
1826 work.abort();
1827 // PG level abort, which would abort all other communicators on this
1828 // rank
1829 abort();
1830 }
1831
1832 // Report desync state in case of timeout
1833 if (timedOut) {
1834 LOG(ERROR) << c10::str(
1835 logPrefix(),
1836 "Timeout at NCCL work: ",
1837 work.seq_,
1838 ", last enqueued NCCL work: ",
1839 pgStatus_->lastEnqueuedSeq,
1840 ", last completed NCCL work: ",
1841 pgStatus_->lastCompletedSeq,
1842 ".");
1843 if (desyncDebug_) {
1844 try {
1845 collectiveDebugInfoMode_.store(true);
1846 auto desyncMsg = getNCCLWatchdogDebugInfo();
1847 LOG(ERROR) << logPrefix() << desyncMsg;
1848 } catch (const std::exception& e) {
1849 LOG(ERROR)
1850 << logPrefix()
1851 << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
1852 << " Please file an issue. Error: " << e.what();
1853 } catch (...) {
1854 LOG(ERROR)
1855 << logPrefix()
1856                 << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
1857 << " Please file an issue.";
1858 }
1859 }
1860 }
1861 // Throw exception
1862 work.handleException(asyncErrorHandling_);
1863 }
1864
1865 // Work status logging for desync debug
1866 if (desyncDebug_) {
1867 if (work.isStarted()) {
1868 logWorkStart(work);
1869 }
1870 if (work.isCompleted()) {
1871 logWorkEnd(work);
1872 }
1873 }
1874
1875 // a work could be started but not completed, so we should not update
1876 // lastStartedSeq and lastStartedOpName if the work state is checked
1877 // multiple times after the start
1878 if (pgStatus_->lastStartedSeq < static_cast<int64_t>(work.seq_) &&
1879 work.isStarted()) {
1880 pgStatus_->lastStartedSeq = work.seq_;
1881 pgStatus_->lastStartedWorkName = opTypeToString(work.opType_);
1882 }
1883
1884 // Clean up completed work
1885 if (work.isCompleted()) {
1886 {
1887 // Reset the timeout and first work if the work is completed.
1888 std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
1889 if (work.ownedEphermeralTimeout_.count() > 0) {
1890 ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
1891 ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
1892 }
1893 }
1894 pgStatus_->lastCompletedSeq = work.seq_;
1895 pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
1896 pgStatus_->lastCompletedNumelIn = work.numelIn_;
1897 pgStatus_->lastCompletedNumelOut = work.numelOut_;
1898 NCCLTraceBuffer::get()->retire_id(work.trace_id_, true);
1899 if (onCompletionHook_) {
1900 // Move Work object to completedWorkList_ to be consumed by the hook
1901 // thread
1902 {
1903 const std::lock_guard<std::mutex> lock(completedWorkListMutex_);
1904 completedWorkList_.splice(
1905 completedWorkList_.end(), workMetaList_, it++);
1906 }
1907 completedWorkListCV_.notify_one();
1908 } else {
1909 it = workMetaList_.erase(it);
1910 lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
1911 }
1912 at::cuda::CUDAGraph::dec_pending_event_queries();
1913 } else {
1914 // Increment the iterator if the current WorkNCCL object is not
1915 // completed.
1916 ++it;
1917 }
1918 // Increment heartbeat after each work processed,
1919 // in case processing is slowed down (but not hung) by cuda api contention
1920 heartbeat_++;
1921 }
1922 done = workMetaList_.empty();
1923 }
1924 }
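
// A minimal sketch of the TCPStore-based dump signalling used by
// watchdogHandler() (producer) and heartbeatMonitor() (consumer) above,
// assuming a c10d::Store reference and the EXCEPTION_DUMP key; the helper
// names are illustrative only:
//
//   // producer: the rank that first observes a timeout advertises itself
//   void signalDump(c10d::Store& store, int globalRank) {
//     std::vector<uint8_t> payload(
//         reinterpret_cast<uint8_t*>(&globalRank),
//         reinterpret_cast<uint8_t*>(&globalRank) + sizeof(globalRank));
//     store.set(std::string(EXCEPTION_DUMP), payload);
//   }
//
//   // consumer: other ranks poll the key and decode the offending rank
//   std::optional<int> pollDumpSignal(c10d::Store& store) {
//     if (!store.check({std::string(EXCEPTION_DUMP)})) {
//       return std::nullopt;
//     }
//     auto payload = store.get(std::string(EXCEPTION_DUMP));
//     int rank = -1;
//     std::memcpy(&rank, payload.data(), sizeof(rank));
//     return rank;
//   }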
1925
1926 void ProcessGroupNCCL::runHookLoop() {
1927 c10::setThreadName("pt_nccl_runhook");
1928
1929 bool done = false;
1930 while (!done || !terminateProcessGroup_.load()) {
1931 std::unique_lock<std::mutex> lock(completedWorkListMutex_);
1932     // We poll the completed-work list every kWatchdogThreadSleepMillis
1933     // milliseconds, waking up early when work completes or we are terminating.
1934 completedWorkListCV_.wait_for(
1935 lock,
1936 std::chrono::milliseconds(kWatchdogThreadSleepMillis),
1937 [&]() -> bool {
1938 return !completedWorkList_.empty() || terminateProcessGroup_.load();
1939 });
1940
1941 try {
1942 for (auto it = completedWorkList_.begin(); it != completedWorkList_.end();
1943 /* no increment */) {
1944 const WorkNCCL& work = *it;
1945 // Hook might grab GIL, unlock first to prevent deadlock
1946 lock.unlock();
1947
1948 auto timeStarted =
1949 std::chrono::system_clock::now() +
1950 std::chrono::duration_cast<std::chrono::system_clock::duration>(
1951 work.workStartTime_ - std::chrono::steady_clock::now());
1952 onCompletionHook_(std::make_shared<WorkInfo>(
1953 work.retrieveOpType(), // OpType
1954 work.getSequencenumber(), // seq
1955 timeStarted, // timeStarted
1956 std::chrono::system_clock::now(), // timeFinished
1957 std::chrono::duration<float, std::milli>(
1958 work.getDuration()) // activeDuration
1959 ));
1960
1961 lock.lock();
1962 it = completedWorkList_.erase(it);
1963 }
1964 } catch (std::exception& e) {
1965 if (std::string(e.what()).find("driver shutting down") !=
1966 std::string::npos) {
1967 LOG(INFO)
1968 << logPrefix()
1969 << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop."
1970             << " (runHookLoop caught exception: " << e.what() << ")";
1971
1972 } else {
1973 // PythonOnCompletionHook has already extracted Python exception message
1974 // and wrapped it with a cpp one. So we no longer need to acquire GIL
1975 // here.
1976 const auto errorStr = c10::str(
1977 "Caught exception on rank ",
1978 rank_,
1979 " while running onCompletion hook for ProcessGroupNCCL: ",
1980 e.what(),
1981 ". Aborting all communicators.");
1982
1983         // No need to call abort() on WorkNCCL here as that collective has
1984         // already finished successfully at this point. We just need to abort
1985         // all NCCL communicators on this ProcessGroupNCCL instance.
1987 abort(errorStr);
1988 }
1989 }
1990
1991 // Lock is still acquired at this point
1992 done = completedWorkList_.empty();
1993 }
1994 }
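
// A minimal sketch of the producer/consumer hand-off between watchdogHandler()
// (which splices completed work into completedWorkList_) and runHookLoop()
// (which drains it); the names below are illustrative, not part of this file:
//
//   std::mutex m;
//   std::condition_variable cv;
//   std::list<Item> queue;
//
//   void produce(std::list<Item>& source, std::list<Item>::iterator it) {
//     {
//       std::lock_guard<std::mutex> g(m);
//       queue.splice(queue.end(), source, it); // O(1), no copy
//     }
//     cv.notify_one();
//   }
//
//   void consume() {
//     std::unique_lock<std::mutex> lk(m);
//     cv.wait_for(lk, std::chrono::milliseconds(100), [&] {
//       return !queue.empty();
//     });
//     for (auto it = queue.begin(); it != queue.end();) {
//       lk.unlock(); // drop the lock: the hook may grab the GIL
//       runHook(*it);
//       lk.lock();
//       it = queue.erase(it);
//     }
//   }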
1995
1996 std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors() {
1997 return checkForNCCLErrorsInternal(ncclComm_);
1998 }
1999
2000 std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors(
2001 std::shared_ptr<NCCLComm>& ncclComm) {
2002 return checkForNCCLErrorsInternal(ncclComm);
2003 }
2004
2005 std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
2006 std::shared_ptr<NCCLComm>& ncclComm) {
2007 // Prioritize commFailureReason over checkForNcclError() result if
2008 // commFailureReason is set.
2009 auto commFailureReason = ncclComm->getNcclCommFailureReason();
2010 if (commFailureReason != std::nullopt) {
2011 return std::make_exception_ptr(C10_BUILD_ERROR(
2012 DistBackendError,
2013 c10::str(
2014 "NCCL communicator encountered error set by ProcessGroupNCCL: ",
2015 *commFailureReason)));
2016 }
2017 ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
2018 // When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING,
2019 // ncclInProgress could be returned when there are pending NCCL calls.
2020 // In this case, no exception should be thrown
2021 #ifdef NCCL_HAS_COMM_NONBLOCKING
2022 // ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined
2023 if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) {
2024 #else
2025 if (ncclAsyncErr != ncclSuccess) {
2026 #endif
2027 return std::make_exception_ptr(C10_BUILD_ERROR(
2028 DistBackendError,
2029 "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" +
2030 getNcclErrorDetailStr(ncclAsyncErr)));
2031 }
2032
2033 return nullptr;
2034 }
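
// For reference, checkForNcclError() ultimately wraps an asynchronous NCCL
// error query along the lines of the following (sketch only; error-string
// handling and version gating omitted):
//
//   ncclResult_t asyncErr = ncclSuccess;
//   ncclResult_t queryRet = ncclCommGetAsyncError(rawComm, &asyncErr);
//   if (queryRet != ncclSuccess) {
//     // the query itself failed
//   } else if (asyncErr == ncclInProgress) {
//     // nonblocking comm still has pending NCCL calls; not an error
//   } else if (asyncErr != ncclSuccess) {
//     // a real asynchronous failure; surfaced here as DistBackendError
//   }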
2035
2036 void ProcessGroupNCCL::broadcastUniqueNCCLID(
2037 ncclUniqueId* ncclID,
2038 bool isSingleP2POp,
2039 const std::string& p2pKey,
2040 int p2pRank) {
2041 // For collective operations:
2042 // For every NCCL communicator that we create we need to broadcast
2043 // a unique ID from rank 0 to all other ranks. This broadcast is
2044 // done by rank 0 setting a key in the store and all other ranks
2045 // retrieving the contents of that key. A single process group
2046 // may create multiple NCCL communicators, so we use a sequence
2047 // number to differentiate between them.
2048 // For single point-to-point operations:
2049 // The sequence number will only be increased on 2 out of all the
2050 // processes in a Process Group. So all following collective
2051 // operations will see different sequence numbers which will cause
2052 // runtime errors. To avoid that, use the src:target pair instead
2053 // of sequence number for p2p communications.
2054
2055 std::string storeKey;
2056 if (!isSingleP2POp) {
2057 storeKey = std::to_string(ncclCommCounter_++);
2058 } else {
2059 storeKey = p2pKey;
2060 }
2061 if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) {
2062 auto vec = std::vector<uint8_t>(
2063 reinterpret_cast<uint8_t*>(ncclID),
2064 reinterpret_cast<uint8_t*>(ncclID) + NCCL_UNIQUE_ID_BYTES);
2065 store_->set(storeKey, vec);
2066 } else {
2067 try {
2068 auto vec = store_->get(storeKey);
2069 TORCH_CHECK_WITH(
2070 DistBackendError,
2071 vec.size() == NCCL_UNIQUE_ID_BYTES,
2072 "Invalid size for ncclUniqueId");
2073 std::memcpy(ncclID, vec.data(), vec.size());
2074 } catch (const std::exception& e) {
2075 std::string exceptionMsg = c10::str(
2076 "[",
2077 rank_,
2078 "] is setting up NCCL communicator and "
2079 "retrieving ncclUniqueId from [0] via c10d key-value store by key '",
2080 storeKey,
2081 "', but store->get('",
2082 storeKey,
2083 "') got error: ");
2084 C10_THROW_ERROR(
2085 DistBackendError,
2086 exceptionMsg + e.what() +
2087 ". This may indicate a possible application crash on rank 0 or a network set up issue.");
2088 } catch (...) {
2089 C10_THROW_ERROR(
2090 DistBackendError,
2091 c10::str(
2092 "Unknown exception while [",
2093 rank_,
2094 "] is setting up NCCL communicator and "
2095 "retrieving ncclUniqueId from [0] via c10d key-value store by key '",
2096 storeKey,
2097 "'",
2098 ". This may indicate a possible application crash on rank 0 or a network set up issue."));
2099 }
2100 }
2101 }
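
// A minimal sketch of how this store-based exchange is consumed when a new
// communicator is created (see getNCCLComm below); error checking is omitted
// and the surrounding member state is assumed:
//
//   ncclUniqueId id;
//   if (rank_ == 0) {
//     ncclGetUniqueId(&id);                     // only the root generates the ID
//   }
//   broadcastUniqueNCCLID(&id, /*isSingleP2POp=*/false, /*p2pKey=*/"", 0);
//   ncclComm_t rawComm = nullptr;
//   ncclCommInitRank(&rawComm, size_, id, rank_); // collective over all ranks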
2102
2103 void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
2104 std::lock_guard<std::mutex> lock(mutex_);
2105 if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
2106 TORCH_INTERNAL_ASSERT(
2107 false,
2108 "Expected to find key ",
2109 devNCCLCommMapKey,
2110 " in NCCL communicator map.");
2111 }
2112 std::shared_ptr<NCCLComm>& ncclComm = devNCCLCommMap_[devNCCLCommMapKey];
2113 // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being
2114 // destroyed, so using ncclCommAbort here.
2115 ncclComm->ncclCommAbort();
2116 // Remove communicators from the cache.
2117 devNCCLCommMap_.erase(devNCCLCommMapKey);
2118 // Clear used device indices.
2119 usedDeviceIdxs_.clear();
2120
2121 ncclCommDevIdxMapMutex.lock();
2122 ncclCommDevIdxMap.erase(ncclComm);
2123 ncclCommDevIdxMapMutex.unlock();
2124 }
2125
2126 std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
2127 const std::string& deviceKey,
2128 at::Device& device,
2129 OpType opType,
2130 int p2pRank,
2131 bool isSendRecvSelf) {
2132 // Sanity check
2133 if (deviceKey.empty()) {
2134 C10_THROW_ERROR(
2135 DistBackendError,
2136 "Not able to create/get the NCCL Communicator since "
2137 "the GPU devices are not known");
2138 }
2139 if (bound_device_id_) {
2140 if (*bound_device_id_ != device) {
2141 LOG(ERROR) << logPrefix() << "Tensor found on device " << device
2142 << " but backend constrained to " << *bound_device_id_;
2143 C10_THROW_ERROR(
2144 DistBackendError,
2145 "Attempt to perform collective on tensor not on device passed to init_process_group");
2146 }
2147 }
2148
2149 usedDeviceIdxs_.insert(device.index());
2150
2151 {
2152 std::lock_guard<std::mutex> lock(mutex_);
2153 if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) {
2154 // Reuse the cached communicator if there is one.
2155 return devNCCLCommMap_[deviceKey];
2156 }
2157 }
2158
2159 // NCCL communicator not cached, create a new entry
2160 std::shared_ptr<NCCLComm> ncclComm;
2161
2162 // Create the unique NCCL ID and broadcast it
2163 ncclUniqueId ncclID;
2164
2165 // reset log prefix to include group_desc
2166 logPrefix_ = createLogPrefix();
2167
2168 #ifdef NCCL_COMM_DESCRIPTION
2169 // Pass process group name and description to NCCL communicator
2170 std::string commDesc = pg_desc_ + ':' + pg_uid_;
2171 options_->config.commDesc = strdup(commDesc.c_str());
2172 #endif
2173
2174 // For batch_isend_irecv, ncclGroupStart() would be called upfront
2175 bool batchP2P = ncclActiveGroupCounter_ > 0;
2176 bool singleP2POp = isP2POp(opType, batchP2P);
2177
2178 // Get the device index
2179 auto deviceIndex = device.index();
2180 at::cuda::OptionalCUDAGuard gpuGuard(device);
2181
2182 // [Group Start/End Note] This is used to ensure that nccl communicator will
2183 // be created before communication primitives are called. Let's look at this
2184 // example: Using the batch_isend_irecv to send a tensor to a target process.
2185 // On the sender side, the corresponding underlying NCCL calls will look like
2186 // ncclGroupStart() // This is in batch_isend_irecv
2187 // ncclCommInitRank() // Inside NCCLComm::create
2188 // ncclSend()
2189 // ncclGroupEnd() // This is in batch_isend_irecv
2190 // With this pattern, the nccl communicator will be created in the last
2191 // ncclGroupEnd which means when ncclSend is processed, the passed
2192 // communicator argument is NULL which will lead to runtime error. So we need
2193 // to "close" all active nccl groups to ensure nccl communicator is actually
2194 // created before encountering any communication calls. This is why we need
2195 // the following for loop.
2196 for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
2197 (void)i;
2198     // comms have not been initiated yet, so we can only check in a blocking way
2199 C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
2200 }
2201
2202 // GPU world size and GPU rank
2203 int numRanks, rank;
2204
2205 if (!singleP2POp) {
2206 // Collective, all-to-all, or batch P2P
2207 numRanks = getSize();
2208 rank = getRank();
2209 } else if (isSendRecvSelf) {
2210 // Same process send and recv.
2211 numRanks = 1;
2212 rank = 0;
2213 } else {
2214 // For single point-to-point operation, there are only 2 processes
2215 // involved so the GPU rank is either 0 or 1.
2216 numRanks = 2;
2217 rank = p2pRank;
2218 }
2219
2220 #ifdef NCCL_HAS_COMM_SPLIT
2221 if (options_->split_from) {
2222 TORCH_CHECK(
2223 options_->split_color != 0,
2224 "Must specify a non-zero color when splitting");
2225 // Find a valid, healthy communicator to split from if possible.
2226 std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
2227 auto& other_comms = options_->split_from->devNCCLCommMap_;
2228 auto dit = other_comms.find(getKeyFromDevice(device));
2229 if (dit != other_comms.end()) {
2230 auto& parentComm = dit->second;
2231 if (parentComm != nullptr && !parentComm->isAborted()) {
2232 ncclComm = NCCLComm::split(
2233 parentComm.get(),
2234 options_->split_color,
2235 rank,
2236 options_->config,
2237 options_->global_ranks_in_group);
2238 }
2239 }
2240 }
2241 #endif
2242
2243   // To simplify conditional nesting, just create the ncclComms[i]
2244   // entry if it hasn't been created yet, rather than untangling the
2245   // conditions that might have resulted in a split above.
2246 if (!ncclComm) {
2247 if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) {
2248 // For point-to-point communication, lower rank of the two will get unique
2249 // id.
2250 if (rank_ == 0 || (singleP2POp && p2pRank == 0)) {
2251 C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt);
2252 }
2253
2254 // Broadcast so that each process can have a unique NCCL ID
2255 auto timeStarted = std::chrono::steady_clock::now();
2256 broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank);
2257 auto timerDeltaMs =
2258 std::chrono::duration_cast<std::chrono::duration<double>>(
2259 std::chrono::steady_clock::now() - timeStarted)
2260 .count() *
2261 1000;
2262 LOG(INFO) << logPrefix()
2263 << "ProcessGroupNCCL broadcast unique ID through store took "
2264 << timerDeltaMs << " ms";
2265 }
2266
2267 #ifdef NCCL_HAS_COMM_NONBLOCKING
2268 ncclComm = NCCLComm::create(numRanks, rank, ncclID, options_->config);
2269 #else
2270 ncclComm = NCCLComm::create(numRanks, rank, ncclID);
2271 #endif
2272 }
2273
2274 // Creates the NCCL streams
2275 bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
2276 auto streamVal = at::cuda::getStreamFromPool(
2277 options_->is_high_priority_stream || force_high);
2278
2279 {
2280 std::lock_guard<std::mutex> lock(mutex_);
2281 inInitializationCommMap_.emplace(deviceKey, ncclComm);
2282 }
2283
2284 NCCLTraceBuffer::get()->record_pg_ranks(
2285 std::make_tuple(pg_uid_, pg_desc_), groupRanks());
2286
2287 RECORD_PARAM_COMMS(
2288 0, // seq
2289 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
2290 rank, // rank
2291 "init", // collective name
2292 0, // inNelems
2293 0, // outNelems
2294 at::kByte, // dType
2295 std::vector<int64_t>(), // inSplitSizes
2296 std::vector<int64_t>(), // outSplitSizes
2297 globalRankStart, // globalRankStart
2298 globalRankStride, // globalRankStride
2299 size_); // worldSize
2300
2301 LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ "
2302 << ncclComm->ncclComm_ << " on CUDA device: " << deviceIndex;
2303
2304 // At this point NCCL should have been initialized, hence we can accurately
2305 // get the env value even if NCCL sets it by reading from nccl.conf file
2306 LOG(INFO) << logPrefix()
2307 << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A");
2308
2309 // See [Group Start/End Note]
2310 for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
2311 (void)i;
2312 C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt);
2313 }
2314
2315 ncclStreams_.emplace(deviceKey, std::move(streamVal));
2316
2317   // Note: these events are created with the (default) cudaEventDisableTiming
2318   // flag. This flag provides the best performance when used with
2319   // cudaStreamWaitEvent() and cudaEventQuery(). Since we don't measure
2320   // performance using cudaEvent here, this flag should be set.
2321 // TODO(kwen2501): is ncclEvents_ used anywhere else?
2322 ncclEvents_.emplace(deviceKey, at::cuda::CUDAEvent(cudaEventDisableTiming));
2323
2324 // Move the NCCL resource to cache
2325 auto it = inInitializationCommMap_.find(deviceKey);
2326   // A previous thread could've already removed deviceKey from
2327   // inInitializationCommMap_ and added it to devNCCLCommMap_.
2328 if (it != inInitializationCommMap_.end()) {
2329 devNCCLCommMap_.emplace(deviceKey, std::move(it->second));
2330 inInitializationCommMap_.erase(deviceKey);
2331
2332 // Now ncclComms are fully initialized.
2333 // Register all active CUDA memory segments in cache allocator to
2334 // the new NCCL communicators
2335 if (useTensorRegisterAllocatorHook_) {
2336 auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
2337 // Register the segment to a new NCCL communicator if on the same device
2338 for (const auto& segmentInfo : snapshot.segments) {
2339 TORCH_INTERNAL_ASSERT(
2340 segmentInfo.device == device.index(),
2341 "Mismatch between CUDA memory segment device and current device");
2342 ncclComm->registerSegment(
2343 reinterpret_cast<void*>(segmentInfo.address),
2344 segmentInfo.total_size);
2345 }
2346 }
2347 // Record the mapping between ncclComm and device index so that later
2348 // register hook can register a newly allocated segment to communicators
2349 // on the same device.
2350     // NOTE: we need to remove the communicator from this map when it is
2351     // destroyed, otherwise we may register onto an invalid communicator.
2352 ncclCommDevIdxMapMutex.lock();
2353 ncclCommDevIdxMap.emplace(ncclComm, device.index());
2354 ncclCommDevIdxMapMutex.unlock();
2355 }
2356
2357 it = devNCCLCommMap_.find(deviceKey);
2358 TORCH_INTERNAL_ASSERT(
2359 it != devNCCLCommMap_.end(), "Communicators not populated in cache!");
2360
2361 return it->second;
2362 }
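
// A minimal sketch of the [Group Start/End Note] handling performed above: any
// active ncclGroupStart() scopes are temporarily closed so that communicator
// initialization really happens before the first communication call, then
// reopened so the caller's group semantics are preserved (sketch only):
//
//   for (int i = 0; i < activeGroups; ++i) ncclGroupEnd();   // flush groups
//   ncclCommInitRank(&rawComm, nranks, id, rank);            // comm exists now
//   for (int i = 0; i < activeGroups; ++i) ncclGroupStart(); // restore nesting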
2363
2364 uint64_t ProcessGroupNCCL::getCommSplitCounter() const {
2365 uint64_t ret = 0;
2366 for (const auto& i : devNCCLCommMap_) {
2367 auto& ncclComm = i.second;
2368 ret += ncclComm->getCommSplitCounter();
2369 }
2370 return ret;
2371 }
2372
2373 namespace {
2374
2375 // Check validity of tensor
2376 void check_gpu_single_tensor(
2377 const at::Tensor& tensor,
2378 const bool p2p = false // whether operation is a P2P operation
2379 ) {
2380 if (!tensor.is_cuda() || tensor.is_sparse()) {
2381 C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense");
2382 }
2383 // Skip the following requirements for P2P operations
2384 if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
2385 if (p2p) {
2386 TORCH_WARN_ONCE(
2387 "Detected non-contiguous tensor in P2P operations. It is user "
2388 "responsibility to guarantee that source and destination tensors have "
2389 "the same contiguity format.");
2390 } else {
2391 C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
2392 }
2393 }
2394 }
2395
2396 // Checks that all `tensors' have the same type and reside on the same GPU
2397 // (sizes and strides may differ; see the comment inside the loop).
2398 // TODO: test_c10d_nccl.py should consider adding tests for the error conditions
2399 // here, ie, that deliberately pass invalid tensors and check the right
2400 // exception is thrown. The "Expected list of tensors on the same device"
2401 // condition may be a challenge because the test would need to pass tensors on
2402 // different devices in the same process.
2403 int64_t check_gpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
2404 if (tensors.size() == 0) {
2405 C10_THROW_ERROR(ValueError, "Tensor list must be nonempty");
2406 }
2407
2408 const auto& first = tensors.front();
2409
2410 int64_t total_numel = 0;
2411 for (const auto& t : tensors) {
2412 if (!t.is_cuda() || t.is_sparse()) {
2413 C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense");
2414 }
2415 if (t.scalar_type() != first.scalar_type()) {
2416 C10_THROW_ERROR(TypeError, "Tensors must have identical type");
2417 }
2418 if (!t.is_non_overlapping_and_dense()) {
2419 C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense");
2420 }
2421 // If we're in this function, the user called a _coalesced collective
2422 // on a set of tensors with potentially different sizes and strides.
2423 // Therefore, we don't check for matching sizes and strides,
2424 // but we do double-check tensors are on the same device.
2425 TORCH_CHECK_WITH(
2426 ValueError,
2427 t.get_device() == tensors[0].get_device(),
2428 "Expected list of tensors on the same device");
2429 total_numel += t.numel();
2430 }
2431
2432 return total_numel;
2433 }
2434
2435 bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
2436 for (const auto& input_tensor : input_tensors) {
2437 if (!input_tensors[0].is_same_size(input_tensor)) {
2438 return false;
2439 }
2440 }
2441 return true;
2442 }
2443
2444 } // namespace
2445
2446 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
2447 at::Device& device,
2448 int rank,
2449 OpType opType,
2450 const char* profilingTitle,
2451 const std::vector<at::Tensor>& inputs,
2452 const std::vector<at::Tensor>& outputs, // TODO(kwen2501): necessary?
2453 bool record) {
2454 auto r = c10::make_intrusive<ProcessGroupNCCL::WorkNCCL>(
2455 pg_uid_,
2456 pg_desc_,
2457 device,
2458 rank,
2459 opType,
2460 seqCollective_,
2461 profilingTitle,
2462 profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
2463 : std::nullopt,
2464 desyncDebug_,
2465 enableTiming_.load(),
2466 cudaEventCacheEnabled_.load(),
2467 dist_debug_level_);
2468 if (record) {
2469 bool isP2P = isP2POp(opType);
2470 // Ideally record every work that we enqueue, rather than every work we
2471 // create.
2472 // - at the time of this PR we do not currently enqueue every created work
2473 // - but it is unsafe to steal refs to start/end cuda events from Works that
2474 // may go out of scope before flight recorder has retired them,
2475 // so we must ensure that any work that is initialized via initWork will
2476 // be enqueued
2477 // - initially, moved record() into workEnqueue(), but found that makes it
2478 // hard to get access to profilingTitle,
2479 // inputs, and outputs for metadata recording, and we don't want to attach
2480   //   these objects to the Work because it has implications for keeping those
2481 // tensors alive longer and adds overhead when copying Work objects
2482 // between threads
2483 r->trace_id_ = NCCLTraceBuffer::get()->record(
2484 local_id_,
2485 std::make_tuple(pg_uid_, pg_desc_),
2486 seqCollective_,
2487 seqP2P_,
2488 op_id_,
2489 profilingTitle ? profilingTitle : "",
2490 inputs,
2491 outputs,
2492 r->ncclStartEvent_.get(),
2493 r->ncclEndEvent_.get(),
2494 options_->timeout,
2495 pgStatus_,
2496 isP2P);
2497 }
2498 return r;
2499 }
2500
2501 // TODO(kwen2501): deprecate
2502 std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() {
2503 return *outputs_;
2504 }
2505
2506 c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
2507 getFuture() {
2508 return future_;
2509 }
2510
2511 float ProcessGroupNCCL::WorkNCCL::getDuration() const {
2512 TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
2513 TORCH_CHECK(
2514 ncclStartEvent_,
2515 "getDuration only works if ncclStartEvents_ is populated, true if timing enabled");
2516 TORCH_CHECK(
2517 ncclEndEvent_,
2518 "getDuration only works if ncclEndEvents_ is populated, which should always be true");
2519 return ncclStartEvent_->elapsed_time(*ncclEndEvent_);
2520 }
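
// Note: elapsed_time() ultimately relies on cudaEventElapsedTime(), which only
// works for events created without the cudaEventDisableTiming flag; that is
// why getDuration() requires that timing was enabled (enableTiming_ /
// TORCH_NCCL_ENABLE_TIMING) when the work's events were created.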
2521
2522 uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
2523 return seq_;
2524 }
2525
2526 void ProcessGroupNCCL::assignTimeoutToWork(
2527 const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
2528 const c10::intrusive_ptr<ProcessGroupNCCL::Options>& option) {
2529 std::chrono::milliseconds timeout = option->timeout;
2530 std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
2531 if (ephemeralTimeoutActive_.count() > 0) {
2532 timeout += ephemeralTimeoutActive_;
2533 }
2534 work->opTimeout_ = timeout;
2535 work->ownedEphermeralTimeout_ =
2536 ephemeralTimeoutActive_ - ephemeralTimeoutInflight_;
2537 ephemeralTimeoutInflight_ = ephemeralTimeoutActive_;
2538 }
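
// Worked example of the ephemeral-timeout bookkeeping (values hypothetical):
//
//   configured timeout (options_->timeout)   : 600000 ms
//   addEphemeralTimeout(30s)                 : ephemeralTimeoutActive_ = 30000 ms
//   next work enqueued                       : opTimeout_ = 630000 ms,
//                                              ownedEphermeralTimeout_ = 30000 ms,
//                                              ephemeralTimeoutInflight_ = 30000 ms
//   that work completes (see watchdogHandler): active and inflight both shrink
//                                              by 30000 ms, so subsequent works
//                                              fall back to the 600000 ms base.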
2539
2540 void ProcessGroupNCCL::workEnqueue(
2541 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) {
2542 if (!terminateProcessGroup_.load()) {
2543 std::lock_guard<std::mutex> lock(workMetaListMutex_);
2544     // Avoid having view tensors processed in the cleanup thread.
2545     // A view tensor's destruction invokes autograd_meta, which
2546     // needs to be destructed in the user thread; otherwise we
2547     // will deadlock. Here we enqueue work without outputs_.
2548 workMetaList_.emplace_back(*work);
2549 // update the PG status related to the last enqueued work
2550 pgStatus_->lastEnqueuedSeq = work->seq_;
2551 pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_);
2552 pgStatus_->lastEnqueuedNumelIn = work->numelIn_;
2553 pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
2554 lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
2555 }
2556 }
2557
2558 ProcessGroupNCCL::Options::Options(bool is_high_priority_stream)
2559 : Backend::Options(NCCL_BACKEND_NAME, kProcessGroupNCCLDefaultTimeout),
2560 is_high_priority_stream(is_high_priority_stream) {}
2561
2562 static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
2563
2564 void ProcessGroupNCCL::startCoalescing() {
2565 // Other collective ops bump seq_ before creating a work. Thus, if coalesced
2566 // ops bump seq_ only after initing a work they will collide with (reuse) the
2567 // seq_ of the last non-coalesced collective. Previously, seq_ was bumped
2568 // inside endCoalescing, but before initWork. Since we now record individual
2569 // ops from a coalesce group into the flight recorder, we want to have the
2570 // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during
2571   // start, which has one minor downside: we burn a seq_ if someone ever does a
2572   // 'start' and 'end' coalescing region without doing an operation in between.
2573
2574 // Don't bump op_id_ here, because startCoalescing isn't a logical operation.
2575 // Bump it for each logical op inside the coalescing group.
2576 if (coalescing_state_ & CoalP2P) {
2577 seqP2P_++;
2578 } else {
2579 seqCollective_++;
2580 }
2581
2582 coalescedDevice_.set_index(-1);
2583 coalescedComm_ = nullptr;
2584 coalescing_state_ |= CoalActive;
2585 groupStart();
2586 }
2587
2588 // `optype` is for specifying a composite optype, such as ALLGATHER and
2589 // REDUCE_SCATTER
2590 c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
2591 if (coalescedComm_ == nullptr) {
2592 // There is no actual work being coalesced, return here
2593 groupEnd();
2594 coalescing_state_ = 0;
2595 return nullptr;
2596 }
2597 TORCH_CHECK(
2598 coalescedDevice_.index() >= 0,
2599       "Something went wrong. Did you call end_coalescing before start_coalescing?");
2600
2601 // `coalescedComm_` should have same set of comms across collectives
2602 auto comm = coalescedComm_;
2603 // `coalescedDevice_` should have same set of devices across collectives
2604 auto device = coalescedDevice_;
2605
2606 // `getKeyFromDevice` is how we get keys for both collectives and batch P2P
2607 const auto key = getKeyFromDevice(device);
2608 auto ncclStream = ncclStreams_.at(key);
2609
2610 // Create Work object
2611 c10::cuda::CaptureStatus capture_status =
2612 c10::cuda::currentStreamCaptureStatusMayInitCtx();
2613 bool enqueue =
2614 (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None;
2615 auto work =
2616 initWork(device, rank_, optype, "nccl:coalesced", {}, {}, enqueue);
2617 work->ncclComm_ = comm;
2618 work->blockingWait_ = blockingWait_;
2619 work->avoidRecordStreams_ = avoidRecordStreams_;
2620 work->store_ = store_;
2621 assignTimeoutToWork(work, options_);
2622
2623 // Record start before ncclGroupEnd
2624 if (work->timingEnabled_) {
2625 work->ncclStartEvent_->record(ncclStream);
2626 }
2627
2628 if (nccl_use_nonblocking()) {
2629 groupEndNonblocking(comm);
2630 } else {
2631 groupEnd();
2632 }
2633
2634 // Record end after ncclGroupEnd
2635 // TODO(eqy): is this still necessary if avoidRecordStreams_ is set?
2636 work->ncclEndEvent_->record(ncclStream);
2637
2638 if (avoidRecordStreams_) {
2639 // other functions expect an initialized ptr if avoidRecordStreams_ is set
2640 work->stashed_for_allocator_safety_ =
2641 std::make_shared<std::vector<at::Tensor>>();
2642 }
2643
2644 // Notify graphs before we check the capture status preemptively
2645 at::cuda::CUDAGraph::inc_pending_event_queries();
2646
2647 if (enqueue) {
2648 workEnqueue(work);
2649 } else {
2650 at::cuda::CUDAGraph::dec_pending_event_queries();
2651 }
2652
2653 coalescing_state_ = 0;
2654 coalescedComm_ = nullptr;
2655 return work;
2656 }
2657
2658 c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
2659 // Default OpType to COALESCED if not specified
2660 return endCoalescing(OpType::COALESCED);
2661 }
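
// A minimal usage sketch of the coalescing API defined above (sketch only; the
// tensor buckets and the exact collective calls are illustrative; real callers
// are the c10d coalescing manager and batch_isend_irecv):
//
//   pg->startCoalescing();
//   pg->allreduce(bucket0);          // each op gets its own flight-recorder
//   pg->allreduce(bucket1);          // entry but shares one seqCollective_
//   auto work = pg->endCoalescing(); // one Work guarding the whole group
//   work->wait();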
2662
2663 template <typename Fn, typename PreProcess, typename PostProcess>
2664 c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
2665 std::vector<at::Tensor>& inputs,
2666 std::vector<at::Tensor>& outputs,
2667 Fn fn,
2668 PreProcess pre,
2669 PostProcess post,
2670 OpType opType,
2671 const char* profilingTitle,
2672 bool avoidRecordStreams,
2673 bool nanCheck) {
2674 // Environment setting by the user may add onto collective call's option
2675 avoidRecordStreams |= avoidRecordStreams_;
2676 nanCheck &= enableNanCheck_;
2677
2678 c10::cuda::CaptureStatus capture_status =
2679 c10::cuda::currentStreamCaptureStatusMayInitCtx();
2680 errorIfCapturingNonCapturableNCCL(capture_status);
2681
2682 // Bump collective counter
2683 seqCollective_++;
2684 op_id_++;
2685
2686 auto device = getDevice(inputs[0]);
2687 const auto key = getKeyFromDevice(device);
2688 auto ncclComm = getNCCLComm(key, device, opType);
2689
2690 if (coalescing_state_ & CoalActive) {
2691 coalescing_state_ |= CoalColl;
2692 if (coalescedDevice_.index() < 0) {
2693 coalescedDevice_ = device;
2694 } else {
2695 TORCH_CHECK(
2696 coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
2697 }
2698 if (coalescedComm_ == nullptr) {
2699 coalescedComm_ = ncclComm;
2700 } else {
2701 TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
2702 }
2703 }
2704
2705 // Used many times below, so we stash the unordered_map lookup
2706 auto ncclStream = ncclStreams_.at(key);
2707
2708 // First let NCCL streams wait for input tensors allocation streams
2709 syncStream(device, ncclEvents_[key], ncclStream);
2710
2711 bool enqueue =
2712 !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
2713 auto work =
2714 initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue);
2715
2716 // Store references to outputs to be used by WorkNCCL::result and operator<<.
2717 work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
2718
2719 if (avoidRecordStreams) {
2720 work->stashed_for_allocator_safety_ =
2721 std::make_shared<std::vector<at::Tensor>>(inputs);
2722 }
2723
2724 at::cuda::OptionalCUDAGuard gpuGuard(device);
2725
2726 if (nanCheck) {
2727 for (const auto& input : inputs) {
2728 checkForNan(input, ncclStream);
2729 }
2730 }
2731
2732 // Start event should only be recorded before the ncclGroupStart()
2733 if (work->timingEnabled_) {
2734 work->ncclStartEvent_->record(ncclStream);
2735 }
2736
2737 pre(ncclStream, work);
2738
2739 ncclComm_t comm = ncclComm->getNcclComm();
2740
2741 // Both `inputs' and `outputs' are created on a worker stream and used in
2742 // different ncclStreams. Hence, both must record the ncclStream to
2743 // prevent being freed before the collective finishes.
2744 //
2745 // We only record `inputs' here, and leave recording `outputs' to `fn' for
2746 // operations where `inputs' and `outputs' are not the same.
2747 //
2748 // See [Sync Streams].
2749 if (!avoidRecordStreams) {
2750 for (const auto& input : inputs) {
2751 if (!input.is_sparse()) {
2752 c10::cuda::CUDACachingAllocator::recordStream(
2753 input.storage().data_ptr(), ncclStream);
2754 } else {
2755 // for sparse input case record streams on both index and value
2756 // tensors
2757 c10::cuda::CUDACachingAllocator::recordStream(
2758 input.values().storage().data_ptr(), ncclStream);
2759 c10::cuda::CUDACachingAllocator::recordStream(
2760 input.indices().storage().data_ptr(), ncclStream);
2761 }
2762 }
2763 }
2764
2765   // Not all collectives have the same signature, e.g., all-reduce takes a Tensor
2766   // as input and output while all-to-all takes a vector of Tensors as input
2767   // and output. Because we define the signature of fn to take only a single
2768   // tensor as input and output, we need a hack to get the first element of
2769   // the vector and pass it to fn.
2770   // TODO: we should clean this up in the future (by either entirely removing
2771   // lambdas or removing input and output from the lambda signature).
2772 #ifndef NCCL_HAS_COMM_NONBLOCKING
2773 C10D_NCCL_CHECK(
2774 fn(inputs[0], outputs[0], comm, ncclStream),
2775 ncclComm->getNcclCommFailureReason());
2776 #else
2777 C10D_NCCL_CHECK_TIMEOUT(
2778 fn(inputs[0], outputs[0], comm, ncclStream),
2779 comm,
2780 ncclComm->getNcclCommFailureReason());
2781 #endif
2782
2783 post(ncclStream, work);
2784
2785 // End event should only be recorded after the ncclGroupEnd()
2786 if (!coalescing_state_) {
2787 work->ncclEndEvent_->record(ncclStream);
2788 }
2789 work->ncclComm_ = ncclComm;
2790
2791 {
2792 c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
2793 std::vector<at::Device> devices{device};
2794 work->future_ = c10::make_intrusive<at::ivalue::Future>(
2795 c10::ListType::create(c10::TensorType::get()), devices);
2796
2797 // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
2798 // future blocks the stream this callback runs on the corresponding
2799 // ncclEndEvents_ ensuring appropriate synchronization.
2800 if (work->recordFunctionEndCallback_) {
2801 work->future_->addCallback(
2802 [work](at::ivalue::Future& /* unused */) {
2803 work->recordFunctionEndCallback_();
2804 },
2805 // uses_future = false allows us to skip synchronization in
2806 // ivalue::Future, but is only valid as long as the lambda doesn't use
2807 // the "Future" argument.
2808 /*uses_future=*/false);
2809 }
2810 work->future_->markCompleted(at::IValue(*work->outputs_));
2811 }
2812
2813 // Set appropriate work parameters.
2814 work->blockingWait_ = blockingWait_;
2815 work->avoidRecordStreams_ = avoidRecordStreams;
2816 work->store_ = store_;
2817 assignTimeoutToWork(work, options_);
2818 // Record size info for debug. We only record the size on the first device as
2819 // multi-device per process is deprecated
2820 work->numelIn_ = 0;
2821 work->numelOut_ = 0;
2822 for (const auto& input : inputs) {
2823 work->numelIn_ += input.numel();
2824 }
2825 for (const auto& output : outputs) {
2826 work->numelOut_ += output.numel();
2827 }
2828
2829 // Notify graphs before we check the capture status preemptively
2830 at::cuda::CUDAGraph::inc_pending_event_queries();
2831 if (enqueue) {
2832 workEnqueue(work);
2833 } else {
2834 at::cuda::CUDAGraph::dec_pending_event_queries();
2835 }
2836
2837 return work;
2838 }
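
// A minimal sketch of the `fn` functor that callers pass into collective(); an
// allreduce-style caller provides roughly the following (sketch only; the real
// implementation also maps ReduceOp and handles premul-sum):
//
//   auto fn = [&](at::Tensor& input,
//                 at::Tensor& output,
//                 ncclComm_t comm,
//                 at::cuda::CUDAStream& stream) {
//     return ncclAllReduce(
//         input.data_ptr(),
//         output.data_ptr(),
//         input.numel(),
//         getNcclDataType(input.scalar_type()),
//         ncclSum,
//         comm,
//         stream.stream());
//   };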
2839
2840 template <typename Fn>
2841 c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
2842 std::vector<at::Tensor>& inputs,
2843 std::vector<at::Tensor>& outputs,
2844 Fn fn,
2845 OpType opType,
2846 const char* profilingTitle,
2847 bool avoidRecordStreams) {
2848 // Environment setting by the user may add onto collective call's option
2849 avoidRecordStreams |= avoidRecordStreams_;
2850 c10::cuda::CaptureStatus capture_status =
2851 c10::cuda::currentStreamCaptureStatusMayInitCtx();
2852 errorIfCapturingNonCapturableNCCL(capture_status);
2853
2854 // Bump collective counter
2855 seqCollective_++;
2856
2857 // For coalescingManager collectives, there is no individual c++ call per
2858 // collective so there is no flight record and we increment seq*_ and op_id_
2859   // together. Compare this to the startCoalescing/endCoalescing flow where we
2860   // increment seq_ once per group and increment op_id_ once per individual
2861   // operation within the group.
2862 op_id_++;
2863
2864 // Currently, the API permits one scenario where inputs.size() and
2865 // outputs.size() are > 0.
2866 // 1. If the call was a _coalesced call, all inputs must be on the same
2867 // device.
2868 // The group of nccl calls applies the collective separately to each input,
2869 // but the group as a whole should be efficient, and might even execute as
2870 // a single fused kernel.
2871 auto device = getDevice(inputs[0]);
2872 const auto key = getKeyFromDevice(device);
2873 auto ncclComm = getNCCLComm(key, device, opType);
2874
2875 if (coalescing_state_ & CoalActive) {
2876 coalescing_state_ |= CoalColl;
2877 if (coalescedDevice_.index() < 0) {
2878 coalescedDevice_ = device;
2879 } else {
2880 TORCH_CHECK(
2881 coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
2882 }
2883 if (coalescedComm_ == nullptr) {
2884 coalescedComm_ = ncclComm;
2885 } else {
2886 TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
2887 }
2888 }
2889
2890 // Used many times below, so we stash the unordered_map lookup
2891 auto ncclStream = ncclStreams_.at(key);
2892
2893 // First let NCCL streams wait for input tensors allocation streams
2894 syncStream(device, ncclEvents_[key], ncclStream);
2895
2896 auto work = initWork(
2897 device, rank_, opType, profilingTitle, inputs, outputs, /*record=*/true);
2898
2899 // Store references to outputs to be used by WorkNCCL::result and operator<<.
2900 work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
2901
2902 if (avoidRecordStreams) {
2903 work->stashed_for_allocator_safety_ =
2904 std::make_shared<std::vector<at::Tensor>>(inputs);
2905 }
2906
2907 at::cuda::OptionalCUDAGuard gpuGuard(device);
2908
2909 // Start event should only be recorded before the ncclGroupStart() (which
2910 // happens inside AutoNcclGroup guard below)
2911 if (work->timingEnabled_) {
2912 work->ncclStartEvent_->record(ncclStream);
2913 }
2914
2915 ncclComm_t comm = ncclComm->getNcclComm();
2916
2917 // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL
2918 // upgrade. Once a NCCL version is qualified, this code should not be needed at
2919 // runtime.
2920 #ifdef PGNCCL_ENABLE_HASH
2921 if (enableCollecticeHashDebug_.load()) {
2922 auto numel = getTensorsNumel(inputs);
2923 auto hashValue = hashTensors(inputs);
2924 PRINT_COLLECTIVE_HASH_SIGNATURE(
2925 "input", opTypeToString(opType), numel, hashValue);
2926 }
2927 #endif
2928
2929 {
2930 torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
2931 comm, nccl_use_nonblocking());
2932 for (const auto i : c10::irange(inputs.size())) {
2933 // Both `inputs' and `outputs' are created on a worker stream and used in
2934 // different ncclStreams. Hence, both must record the ncclStream to
2935 // prevent being freed before the collective finishes.
2936 //
2937 // We only record `inputs' here, and leave recording `outputs' to `fn' for
2938 // operations where `inputs' and `outputs' are not the same.
2939 //
2940 // See [Sync Streams].
2941 if (!avoidRecordStreams) {
2942 if (!inputs[i].is_sparse()) {
2943 c10::cuda::CUDACachingAllocator::recordStream(
2944 inputs[i].storage().data_ptr(), ncclStream);
2945 } else {
2946 // for sparse input case record streams on both index and value
2947 // tensors
2948 c10::cuda::CUDACachingAllocator::recordStream(
2949 inputs[i].values().storage().data_ptr(), ncclStream);
2950 c10::cuda::CUDACachingAllocator::recordStream(
2951 inputs[i].indices().storage().data_ptr(), ncclStream);
2952 }
2953 }
2954 #ifndef NCCL_HAS_COMM_NONBLOCKING
2955 C10D_NCCL_CHECK(
2956 fn(inputs[i], outputs[i], comm, ncclStream),
2957 ncclComm->getNcclCommFailureReason());
2958 #else
2959 C10D_NCCL_CHECK_TIMEOUT(
2960 fn(inputs[i], outputs[i], comm, ncclStream),
2961 comm,
2962 ncclComm->getNcclCommFailureReason());
2963 #endif
2964 }
2965 }
2966
2967 work->ncclEndEvent_->record(ncclStream);
2968 work->ncclComm_ = ncclComm;
2969
2970 {
2971 c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
2972 std::vector<at::Device> devices{device};
2973 work->future_ = c10::make_intrusive<at::ivalue::Future>(
2974 c10::ListType::create(c10::TensorType::get()), devices);
2975
2976     // Add a callback that runs profiling end callbacks. wrapCallback() in the
2977     // CUDA future makes the stream this callback runs on wait on the
2978     // corresponding ncclEndEvents_, ensuring appropriate synchronization.
2979 if (work->recordFunctionEndCallback_) {
2980 work->future_->addCallback(
2981 [work](at::ivalue::Future& /* unused */) {
2982 work->recordFunctionEndCallback_();
2983 },
2984 // uses_future = false allows us to skip synchronization in
2985 // ivalue::Future, but is only valid as long as the lambda doesn't use
2986 // the "Future" argument.
2987 /*uses_future=*/false);
2988 }
2989 work->future_->markCompleted(at::IValue(*work->outputs_));
2990 }
2991
2992 // Set appropriate work parameters.
2993 work->blockingWait_ = blockingWait_;
2994 work->avoidRecordStreams_ = avoidRecordStreams;
2995 work->store_ = store_;
2996 assignTimeoutToWork(work, options_);
2997 // Record size info for debug. We only record the size on the first device as
2998 // multi-device per process is deprecated
2999 work->numelIn_ = inputs[0].numel();
3000 work->numelOut_ = outputs[0].numel();
3001
3002 /* Note [cuda graph capture and workEnqueue]
3003
3004 Normal behavior of the C10D watchdog is to query cuda events on work objects
3005 periodically, but when cuda graph recording is active these event queries
3006 would crash or mess up the recording.
3007
3008 To ensure we do not enqueue a work object to the watchdog when cuda graph
3009 capture is active, we use a one-way sync. We increment a flag pre-emptively,
3010   indicating our intent to enqueue a work object. Then we check capture_status
3011   to see whether (a) capturing is already in progress (we cannot enqueue in this
3012   case), or (b) capturing hasn't started yet, so we can trust that no capture
3013   will start (since a precondition of starting a capture is checking that the
3014   event query count is 0).
3015
3016 If we are not able to enqueue the work due to capture-in-progress, we finally
3017 decrement the counter.
3018
3019 For this reason we cannot easily move the increment inside workEnqueue unless
3020 we also change the semantic of workEnqueue to 'maybeWorkEnqueue'.
3021
3022 TODO:
3023 - Is our design for flight recorder safe in this context? are we recording
3024 any FR events during cudagraph capture? if so, they won't be safe to poll for
3025 completion status.
3026 */
3027 at::cuda::CUDAGraph::inc_pending_event_queries();
3028 if (capture_status == c10::cuda::CaptureStatus::None) {
3029 workEnqueue(work);
3030 } else {
3031 at::cuda::CUDAGraph::dec_pending_event_queries();
3032 }
3033 // TODO(whc) if the work isn't enqueued, I don't feel great about returning
3034 // it, since interactions with it by usercode won't behave normally - they
3035 // won't observe work completion, for instance. Will this lead to silent
3036 // problems during capture?
3037 return work;
3038 }
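
// Illustrative sketch (not part of the implementation): collectiveCoalesced()
// above is the sink for the *_coalesced frontends further down in this file,
// e.g. allreduce_coalesced(). Assuming a hypothetical, already-initialized
// process-group handle `pg` and CUDA tensors g0..g2 on the same device:
//
//   std::vector<at::Tensor> grads{g0, g1, g2};
//   AllreduceCoalescedOptions opts;
//   auto work = pg->allreduce_coalesced(grads, opts);
//   work->wait();
//
// All entries are issued inside a single AutoNcclGroup, so one Work object and
// one flight-recorder entry cover the whole group, as the comments above note.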
3039
3040 template <typename Fn, typename PreProcess, typename PostProcess>
3041 c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
3042 at::Tensor& tensor,
3043 Fn fn,
3044 int peer,
3045 OpType opType,
3046 PreProcess pre,
3047 PostProcess post,
3048 const char* profilingTitle) {
3049 // avoidRecordStreams_ note:
3050   // send, recv, and irecv should be ok with avoidRecordStreams.
3051   // However, for isend, I don't think the API requires the user
3052 // to wait() on the returned handle, so ProcessGroupNCCL can't know
3053 // when it's safe to release the input back to the allocator,
3054 // and the present call has no way to know it's not an isend.
3055 // Therefore, we warn and fall back to the typical recordStream logic:
3056 if (avoidRecordStreams_) {
3057 TORCH_WARN_ONCE(
3058 "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point "
3059 "collectives.");
3060 }
3061
3062 auto device = getDevice(tensor);
3063 std::string key;
3064 int p2pRank = 0, p2pTargetRank = 0;
3065 bool isSendRecvSelf = false;
3066 // For batch_isend_irecv, ncclGroupStart() would be called upfront
3067 bool batchP2P = ncclActiveGroupCounter_ > 0;
3068 if (batchP2P) {
3069     // For batch P2P, we need to treat it like a collective when selecting the
3070     // communicator, because ranks other than my rank and my peer can
3071     // participate in this batch.
3072 key = getKeyFromDevice(device);
3073 p2pRank = rank_;
3074 p2pTargetRank = peer;
3075 } else {
3076 // For single P2P, preserve the old two-rank behavior (to avoid perf diff)
3077 key = getKeySendRecv(rank_, peer);
3078 p2pRank = rank_ <= peer ? 0 : 1;
3079 isSendRecvSelf = rank_ == peer;
3080 p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
3081
3082 if (!coalescing_state_) {
3083 // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be
3084 // bumped in `startCoalescing`.
3085 seqP2P_++;
3086 }
3087 }
3088
3089 // Bump the logical operation counter regardless of whether this op is
3090 // coalesced or individual
3091 op_id_++;
3092
3093 auto ncclComm = getNCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
3094
3095 if (coalescing_state_ & CoalActive) {
3096 coalescing_state_ |= CoalP2P;
3097 if (coalescedDevice_.index() < 0) {
3098 coalescedDevice_ = device;
3099 } else {
3100 TORCH_CHECK(
3101 coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
3102 }
3103 if (coalescedComm_ == nullptr) {
3104 coalescedComm_ = ncclComm;
3105 } else {
3106 TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
3107 }
3108 }
3109
3110 // Used many times below, so we stash the unordered_map lookup
3111 auto ncclStream = ncclStreams_.at(key);
3112 // First let NCCL streams wait for input tensors allocation streams
3113 syncStream(device, ncclEvents_[key], ncclStream);
3114
3115 // Work itself will create the CUDA events on all GPUs of tensors
3116 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work;
3117 if (coalescing_state_) {
3118 // When coalescing, we record events per op that lack timing/state
3119     // information because there is no 'work' associated with them, and then
3120 // later in endCoalescing we record a 'coalesced' Work which has
3121 // timing/state updates via watchdog thread, but lacks op metadata such as
3122 // input/output sizes and profilingTitle per-op in the group.
3123 auto trace_id = NCCLTraceBuffer::get()->record(
3124 local_id_,
3125 std::make_tuple(pg_uid_, pg_desc_),
3126 seqCollective_,
3127 seqP2P_,
3128 op_id_,
3129 profilingTitle,
3130 {tensor},
3131 {tensor},
3132 nullptr,
3133 nullptr,
3134 options_->timeout,
3135 pgStatus_,
3136 /*isP2P=*/true);
3137 // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get
3138 // their timings/states updated by proxy when the Work obj representing the
3139 // coalesce group gets its update, we could accumulate these trace_ids
3140 // together and ask FlightRecorder to take the update from one Work and
3141 // apply it to multiple entries
3142 (void)trace_id;
3143 } else {
3144 // Store references to outputs to be used by WorkNCCL::result and
3145     // operator<<. Note that these outputs are only valid for recv(), as send()
3146     // does not modify the inputs; we still create these outputs for use
3147     // cases such as profiling.
3148
3149 work = initWork(
3150 device, rank_, opType, profilingTitle, {tensor}, {}, /*record=*/false);
3151 // This bypasses something in Work() that crashes if {tensor} is given as
3152 // output, not sure what
3153 work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
3154 work->outputs_->push_back(tensor);
3155 // TODO(whc) because we don't pass output {tensor} to initWork, we tell
3156 // initWork to not record, and then we manually call record passing all the
3157 // information it wants.
3158 work->trace_id_ = NCCLTraceBuffer::get()->record(
3159 local_id_,
3160 std::make_tuple(pg_uid_, pg_desc_),
3161 seqCollective_,
3162 seqP2P_,
3163 op_id_,
3164 profilingTitle,
3165 {tensor},
3166 {tensor},
3167 work->ncclStartEvent_.get(),
3168 work->ncclEndEvent_.get(),
3169 options_->timeout,
3170 pgStatus_,
3171 /*isP2P=*/true);
3172 }
3173
3174   // Is gpuGuard needed for the if block below, or can I swap them?
3175 at::cuda::OptionalCUDAGuard gpuGuard(device);
3176
3177 // Only check for NaN for send ops, for recv ops `tensor` can be a random
3178 // placeholder
3179 if (enableNanCheck_ && opType == OpType::SEND) {
3180 checkForNan(tensor, ncclStream);
3181 }
3182
3183 if (!coalescing_state_) {
3184 // Start event should only be recorded before the ncclGroupStart()
3185 if (work->timingEnabled_) {
3186 work->ncclStartEvent_->record(ncclStream);
3187 }
3188
3189 pre(ncclStream, work);
3190 }
3191
3192 // Both send tensor and recv tensor are created on a worker stream and used
3193 // in different ncclStreams. Hence, both must record the ncclStream to
3194 // prevent being freed before the collective finishes.
3195 //
3196 // See [Sync Streams].
3197 c10::cuda::CUDACachingAllocator::recordStream(
3198 tensor.storage().data_ptr(), ncclStream);
3199
3200 // This part seems common to both p2p and coalesced-p2p usage?
3201 ncclComm_t comm_ = ncclComm->getNcclComm();
3202
3203 #ifndef NCCL_HAS_COMM_NONBLOCKING
3204 C10D_NCCL_CHECK(
3205 fn(tensor, comm_, ncclStream, p2pTargetRank),
3206 ncclComm->getNcclCommFailureReason());
3207 #else
3208 C10D_NCCL_CHECK_TIMEOUT(
3209 fn(tensor, comm_, ncclStream, p2pTargetRank),
3210 ncclComm->getNcclComm(),
3211 ncclComm->getNcclCommFailureReason());
3212 #endif
3213
3214 if (!coalescing_state_) {
3215 post(ncclStream);
3216
3217 // End event should only be recorded after the ncclGroupEnd()
3218 work->ncclEndEvent_->record(ncclStream);
3219 work->ncclComm_ = ncclComm;
3220 work->blockingWait_ = blockingWait_;
3221 work->store_ = store_;
3222 assignTimeoutToWork(work, options_);
3223 // Record size info for debug. We only record the size on the first device
3224 // as multi-device per process is deprecated
3225 work->numelIn_ = work->numelOut_ = tensor.numel();
3226
3227     // The future only needs to be created and marked completed with outputs
3228     // for recv(), but we still create the future for use cases such as
3229     // profiling, even for send().
3230 {
3231 c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
3232 std::vector<at::Device> devices{device};
3233 work->future_ = c10::make_intrusive<at::ivalue::Future>(
3234 c10::ListType::create(c10::TensorType::get()), devices);
3235 work->future_->markCompleted(at::IValue(*work->outputs_));
3236 }
3237
3238     // Add a callback that runs profiling end callbacks. wrapCallback() in the
3239     // CUDA future makes the stream this callback runs on wait on the
3240     // corresponding ncclEndEvents_, ensuring appropriate synchronization.
3241 if (work->recordFunctionEndCallback_) {
3242 work->future_->addCallback(
3243 [work](at::ivalue::Future& /* unused */) {
3244 work->recordFunctionEndCallback_();
3245 },
3246 // uses_future = false allows us to skip synchronization in
3247 // ivalue::Future, but is only valid as long as the lambda doesn't use
3248 // the "Future" argument.
3249 /*uses_future=*/false);
3250 }
3251 }
3252
3253 // Enqueue P2P op so that it can be cancelled by NCCL watchdog
3254 c10::cuda::CaptureStatus capture_status =
3255 c10::cuda::currentStreamCaptureStatusMayInitCtx();
3256
3257 // Notify graphs before we check the capture status preemptively
3258 at::cuda::CUDAGraph::inc_pending_event_queries();
3259
3260 if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
3261 workEnqueue(work);
3262 return work;
3263 } else {
3264 at::cuda::CUDAGraph::dec_pending_event_queries();
3265 return nullptr;
3266 }
3267 }
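
// Note (illustrative sketch; `pg`, `buf`, `t`, and `peer` are hypothetical):
// which communicator key is chosen above depends on ncclActiveGroupCounter_,
// which groupStart()/groupEnd() further below increment and decrement (e.g. on
// behalf of a batched isend/irecv). A lone point-to-point op takes the
// dedicated two-rank path:
//
//   std::vector<at::Tensor> buf{t};                      // CUDA tensor
//   pg->send(buf, /*dstRank=*/peer, /*tag=*/0)->wait();  // tag is ignored here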
3268
3269 template <typename Fn, typename PreProcess, typename PostProcess>
3270 c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
3271 at::Tensor& input,
3272 at::Tensor& output,
3273 Fn fn,
3274 PreProcess pre,
3275 PostProcess post,
3276 OpType opType,
3277 const char* profilingTitle,
3278 bool avoidRecordStreams,
3279 bool nanCheck) {
3280 auto inputs = std::vector<at::Tensor>{input};
3281 auto outputs = std::vector<at::Tensor>{output};
3282 return collective(
3283 inputs,
3284 outputs,
3285 fn,
3286 pre,
3287 post,
3288 opType,
3289 profilingTitle,
3290 avoidRecordStreams,
3291 nanCheck);
3292 }
3293
3294 template <typename Fn>
3295 c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
3296 at::Tensor& input,
3297 at::Tensor& output,
3298 Fn fn,
3299 OpType opType,
3300 const char* profilingTitle,
3301 bool avoidRecordStreams,
3302 bool nanCheck) {
3303 auto inputs = std::vector<at::Tensor>{input};
3304 auto outputs = std::vector<at::Tensor>{output};
3305 return collective(
3306 inputs,
3307 outputs,
3308 fn,
3309 [](at::cuda::CUDAStream&,
3310 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3311 [](at::cuda::CUDAStream&,
3312 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3313 opType,
3314 profilingTitle,
3315 avoidRecordStreams,
3316 nanCheck);
3317 }
3318
3319 template <typename Fn>
3320 c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
3321 at::Tensor& tensor,
3322 Fn fn,
3323 int peer,
3324 OpType opType,
3325 const char* profilingTitle) {
3326 return pointToPoint(
3327 tensor,
3328 fn,
3329 peer,
3330 opType,
3331 [](at::cuda::CUDAStream&,
3332 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3333 [](at::cuda::CUDAStream&) {},
3334 profilingTitle);
3335 }
3336
3337 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
3338 std::vector<at::Tensor>& tensors,
3339 const AllreduceOptions& opts) {
3340 TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3341 auto tensor = tensors.back();
3342 TORCH_CHECK(
3343 !isFloat8Type(tensor.scalar_type()),
3344       "Float8 dtypes are not currently supported for NCCL reductions");
3345 #ifdef IS_NCCLX
3346 tensor = tensor.coalesce();
3347 at::Tensor outputTensor =
3348 torch::zeros(tensor.sizes(), tensor.options().layout(torch::kStrided));
3349 auto work = collective(
3350 tensor,
3351 outputTensor,
3352 [&](at::Tensor& input,
3353 at::Tensor& output,
3354 ncclComm_t comm,
3355 at::cuda::CUDAStream& stream) {
3356 auto ncclDataType = getNcclDataType(input.scalar_type());
3357 auto ncclReduceOp =
3358 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3359
3360 size_t num_elements = output.numel();
3361 auto indices = input.indices();
3362 auto sizes = input.sizes();
3363 int colSize = sizes[1];
3364 auto rows = indices[0];
3365 size_t blockCount = rows.sizes()[0];
3366 auto recvIndices = indices[0] * colSize;
3367
3368 // prevent output and recvIndices from being freed
3369 c10::cuda::CUDACachingAllocator::recordStream(
3370 output.storage().data_ptr(), stream);
3371 c10::cuda::CUDACachingAllocator::recordStream(
3372 recvIndices.storage().data_ptr(), stream);
3373 auto result = ncclAllReduceSparseBlock(
3374 input._values().data_ptr(), // sendbuff
3375 recvIndices.data_ptr<int64_t>(), // recv_indices
3376 blockCount, // block_count
3377 colSize, // block_length
3378 output.data_ptr(), // recvbuff
3379 output.numel(), // recv_count
3380 ncclDataType,
3381 ncclReduceOp,
3382 comm,
3383 stream.stream());
3384 return result;
3385 },
3386 [](at::cuda::CUDAStream& ncclStream,
3387 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3388 [&](at::cuda::CUDAStream& ncclStream,
3389 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3390 // Convert output tensors to sparse and back into tensors.
3391 at::cuda::CUDAStreamGuard guard(ncclStream);
3392 if (opts.sparseIndices.has_value()) {
3393 tensor = at::sparse_coo_tensor(
3394 opts.sparseIndices.value(), outputTensor, tensor.sizes());
3395 } else {
3396 tensor = outputTensor.to_sparse();
3397 }
3398 },
3399 OpType::_ALLREDUCE_SPARSE,
3400 "nccl:all_reduce_sparse");
3401 return work;
3402 #else
3403 // If the nccl branch is not "exp" then we just error
3404 C10_THROW_ERROR(
3405 Error,
3406 "NCCL does not support all_reduce with sparse tensors. Please use dense tensors instead.");
3407 #endif
3408 }
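
// Illustrative sketch (hypothetical names; only meaningful when built with
// IS_NCCLX, otherwise the branch above throws): a sparse COO gradient might be
// reduced roughly like
//
//   std::vector<at::Tensor> ts{denseGrad.to_sparse()};  // CUDA COO tensor
//   AllreduceOptions opts;                              // opts.sparseIndices is optional
//   pg->allreduce_sparse(ts, opts)->wait();
//
// The lambda above reduces into a dense, zero-initialized buffer via
// ncclAllReduceSparseBlock, and the post-processing hook converts that buffer
// back to a sparse tensor.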
3409
3410 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
3411 at::Tensor& tensor,
3412 const AllreduceOptions& opts) {
3413 return collective(
3414 tensor,
3415 tensor,
3416 [&](at::Tensor& input,
3417 at::Tensor& output,
3418 ncclComm_t comm,
3419 at::cuda::CUDAStream& stream) {
3420 auto ncclDataType = getNcclDataType(input.scalar_type());
3421 auto ncclReduceOp =
3422 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3423 return ncclAllReduce(
3424 input.data_ptr(),
3425 output.data_ptr(),
3426 input.numel(),
3427 ncclDataType,
3428 ncclReduceOp,
3429 comm,
3430 stream.stream());
3431 },
3432 OpType::ALLREDUCE,
3433 "nccl:all_reduce");
3434 }
3435
3436 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
3437 std::vector<at::Tensor>& tensors,
3438 const AllreduceOptions& opts) {
3439 TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3440 auto tensor = tensors.back();
3441 if (tensor.is_complex()) {
3442 TORCH_CHECK(
3443 complexViewAsRealAllowed(opts.reduceOp),
3444         "all_reduce does not support ",
3445         opts.reduceOp,
3446         " on complex tensors");
3447 tensor = at::view_as_real(tensor);
3448 }
3449 check_gpu_single_tensor(tensor);
3450
3451 if (intraNodeComm_ != nullptr && opts.reduceOp == ReduceOp::SUM) {
3452 using namespace intra_node_comm;
3453 auto algo = intraNodeComm_->selectAllReduceAlgo(tensor);
3454 if (algo != intra_node_comm::AllReduceAlgo::NONE) {
3455 intraNodeComm_->allReduce(tensor, algo);
3456 return c10::make_intrusive<IntraNodeCommWork>();
3457 }
3458 }
3459 TORCH_CHECK(
3460 !isFloat8Type(tensor.scalar_type()),
3461       "Float8 dtypes are not currently supported for NCCL reductions");
3462 // @lint-ignore CLANGTIDY
3463 RECORD_PARAM_COMMS_DATA(
3464 static_cast<int>(
3465 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3466 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3467 tensors, // inputTensors
3468 tensors, // outputTensors
3469 rank_, // rank
3470 "allreduce", // collective name
3471 tensor.numel(), // inNelems
3472 tensor.numel(), // outNelems
3473 tensor.scalar_type(), // dType
3474 std::vector<int64_t>(), // inSplitSizes
3475 std::vector<int64_t>(), // outSplitSizes
3476 globalRankStart, // globalRankStart
3477 globalRankStride, // globalRankStride
3478 this->getSize()); // worldSize
3479
3480 // avoidRecordStreams_ note: collective() will stash tensors.
3481 return allreduce_impl(tensor, opts);
3482 }
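
// Illustrative sketch (hypothetical `pg`): this dense path is the common case
// reached from torch.distributed.all_reduce; the reduction is in-place:
//
//   std::vector<at::Tensor> ts{at::ones(
//       {1 << 20}, at::TensorOptions().device(at::kCUDA).dtype(at::kFloat))};
//   AllreduceOptions opts;
//   opts.reduceOp = ReduceOp::SUM;
//   pg->allreduce(ts, opts)->wait();  // ts[0] now holds the sum over all ranks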
3483
3484 c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
3485 std::vector<at::Tensor>& tensors,
3486 const AllreduceCoalescedOptions& opts) {
3487 auto total_numel = check_gpu_tensors_same_device(tensors);
3488 TORCH_CHECK(
3489 !isFloat8Type(tensors.back().scalar_type()),
3490       "Float8 dtypes are not currently supported for NCCL reductions");
3491
3492 // @lint-ignore CLANGTIDY
3493 RECORD_PARAM_COMMS_DATA(
3494 static_cast<int>(
3495 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3496 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3497 tensors, // inputTensors
3498 tensors, // outputTensors
3499 rank_, // rank
3500 "allreduce_coalesced", // collective name
3501 total_numel, // inNelems
3502 total_numel, // outNelems
3503 tensors[0].scalar_type(), // dType
3504 // I'm not sure what in,outSplitSizes mean here.
3505 std::vector<int64_t>(), // inSplitSizes
3506 std::vector<int64_t>(), // outSplitSizes
3507 globalRankStart, // globalRankStart
3508 globalRankStride, // globalRankStride
3509 this->getSize()); // worldSize
3510
3511 // avoidRecordStreams_ note: collective() will stash tensors.
3512 return collectiveCoalesced(
3513 tensors,
3514 tensors,
3515 [&](at::Tensor& input,
3516 at::Tensor& output,
3517 ncclComm_t comm,
3518 at::cuda::CUDAStream& stream) {
3519 auto ncclDataType = getNcclDataType(input.scalar_type());
3520 auto ncclReduceOp =
3521 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3522 return ncclAllReduce(
3523 input.data_ptr(),
3524 output.data_ptr(),
3525 input.numel(),
3526 ncclDataType,
3527 ncclReduceOp,
3528 comm,
3529 stream.stream());
3530 },
3531 OpType::COALESCED,
3532 "nccl:allreduce_coalesced");
3533 }
3534
3535 c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
3536 std::vector<at::Tensor>& tensors,
3537 const BroadcastOptions& opts) {
3538 TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3539 auto tensor = tensors.back();
3540 if (tensor.is_complex()) {
3541 tensor = at::view_as_real(tensor);
3542 }
3543 check_gpu_single_tensor(tensor);
3544
3545 // @lint-ignore CLANGTIDY
3546 RECORD_PARAM_COMMS_DATA(
3547 static_cast<int>(
3548 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3549 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3550 tensors, // inputTensors
3551 tensors, // outputTensors
3552 opts.rootRank, // root rank
3553 "broadcast", // collective name
3554 tensor.numel(), // inNelems
3555 tensor.numel(), // outNelems
3556 tensor.scalar_type(), // dType
3557 std::vector<int64_t>(), // inSplitSizes
3558 std::vector<int64_t>(), // outSplitSizes
3559 globalRankStart, // globalRankStart
3560 globalRankStride, // globalRankStride
3561 this->getSize()); // worldSize
3562
3563 // avoidRecordStreams_ note: collective() will stash tensors.
3564 bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
3565
3566 const auto root = opts.rootRank + opts.rootTensor;
3567 bool nanCheck = (root == rank_);
3568
3569 return collective(
3570 tensor,
3571 tensor,
3572 [&](at::Tensor& input,
3573 at::Tensor& output,
3574 ncclComm_t comm,
3575 at::cuda::CUDAStream& stream) {
3576 return ncclBcast(
3577 input.data_ptr(),
3578 input.numel(),
3579 getNcclDataType(input.scalar_type()),
3580 root,
3581 comm,
3582 stream.stream());
3583 },
3584 OpType::BROADCAST,
3585 "nccl:broadcast",
3586 avoidRecordStreams,
3587 nanCheck);
3588 }
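
// Illustrative sketch (hypothetical `pg` and `param`): broadcast is in-place;
// the source rank is rootRank (rootTensor is added to it above and is
// normally 0):
//
//   std::vector<at::Tensor> ts{param};  // CUDA tensor, same shape on every rank
//   BroadcastOptions opts;
//   opts.rootRank = 0;
//   pg->broadcast(ts, opts)->wait();    // every rank now holds rank 0's data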
3589
3590 // _broadcast_oop adds an out-of-place broadcast in PGNCCL
3591 // Custom collectives may be implemented by coalescing broadcast operations
3592 // One use-case is implementing a vector all_gather (all_gather_v)
3593 // where unevenly sized inputs are gathered among participating ranks
3594 // Since all_gather provides an out-of-place API, an all_gather_v
3595 // semantic implemented inside pg_nccl.all_gather also needs to support
3596 // out-of-place, for which an out-of-place broadcast is required to be added
3597 c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
3598 at::Tensor& outputTensor,
3599 at::Tensor& inputTensor,
3600 const BroadcastOptions& opts) {
3601 if (outputTensor.numel() != inputTensor.numel()) {
3602 C10_THROW_ERROR(
3603 ValueError,
3604 "Tensor input and output of _broadcast_oop must have the same number of elements ");
3605 }
3606 const auto root = opts.rootRank + opts.rootTensor;
3607 bool nanCheck = (root == rank_);
3608 return collective(
3609 inputTensor,
3610 outputTensor,
3611 [&](at::Tensor& input,
3612 at::Tensor& output,
3613 ncclComm_t comm,
3614 at::cuda::CUDAStream& stream) {
3615 return ncclBroadcast(
3616 input.data_ptr(),
3617 output.data_ptr(),
3618 input.numel(),
3619 getNcclDataType(input.scalar_type()),
3620 root,
3621 comm,
3622 stream.stream());
3623 },
3624 OpType::BROADCAST,
3625 "nccl:_broadcast_oop",
3626 /*avoidRecordStreams=*/false,
3627 nanCheck);
3628 }
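
// Illustrative sketch of the all_gather_v composition described above, with
// hypothetical `outputs` (per-rank sized buffers) and `myInput`; this mirrors
// what allgather() below does when check_same_size() fails:
//
//   startCoalescing();
//   for (int r = 0; r < size_; ++r) {
//     auto& dst = outputs[r];
//     auto& src = (r == rank_) ? myInput : dst;
//     BroadcastOptions bopts;
//     bopts.rootRank = r;
//     _broadcast_oop(dst, src, bopts);
//   }
//   auto work = endCoalescing(OpType::ALLGATHER);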
3629
3630 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
3631 std::vector<at::Tensor>& tensors,
3632 const ReduceOptions& opts) {
3633 TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3634 // @lint-ignore CLANGTIDY
3635 auto tensor = tensors.back();
3636 if (tensor.is_complex()) {
3637 TORCH_CHECK(
3638 complexViewAsRealAllowed(opts.reduceOp),
3639         "reduce does not support ",
3640         opts.reduceOp,
3641         " on complex tensors");
3642 tensor = at::view_as_real(tensor);
3643 }
3644 check_gpu_single_tensor(tensor);
3645 RECORD_PARAM_COMMS_DATA(
3646 static_cast<int>(
3647 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3648 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3649 tensors, // inputTensors
3650 tensors, // outputTensors
3651 opts.rootRank, // root rank
3652 "reduce", // collective name
3653 tensor.numel(), // inNelems
3654 tensor.numel(), // outNelems
3655 tensor.scalar_type(), // dType
3656 std::vector<int64_t>(), // inSplitSizes
3657 std::vector<int64_t>(), // outSplitSizes
3658 globalRankStart, // globalRankStart
3659 globalRankStride, // globalRankStride
3660 this->getSize()); // worldSize
3661
3662 // avoidRecordStreams_ note: collective() will stash tensors.
3663 return collective(
3664 tensor,
3665 tensor,
3666 [&](at::Tensor& input,
3667 at::Tensor& output,
3668 ncclComm_t comm,
3669 at::cuda::CUDAStream& stream) {
3670 const auto root = opts.rootRank + opts.rootTensor;
3671 auto ncclDataType = getNcclDataType(input.scalar_type());
3672 auto ncclReduceOp =
3673 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3674 return ncclReduce(
3675 input.data_ptr(),
3676 output.data_ptr(),
3677 input.numel(),
3678 ncclDataType,
3679 ncclReduceOp,
3680 root,
3681 comm,
3682 stream.stream());
3683 },
3684 OpType::REDUCE,
3685 "nccl:reduce");
3686 }
3687
3688 // _reduce_oop exposes an out-of-place reduce from PGNCCL
3689 // Custom collectives may be implemented by coalescing reduce operations
3690 // One use-case is implementing a vector reduce_scatter (reduce_scatter_v)
3691 // where inputs are reduced and scattered unevenly among participating ranks
3692 // Since reduce_scatter provides an out-of-place API, a reduce_scatter_v
3693 // semantic implemented inside pg_nccl.reduce_scatter also needs to support
3694 // out-of-place, for which an out-of-place reduce is required to be added
3695 c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
3696 at::Tensor& outputTensor,
3697 at::Tensor& inputTensor,
3698 const ReduceOptions& opts) {
3699 if (outputTensor.numel() != inputTensor.numel()) {
3700 C10_THROW_ERROR(
3701 ValueError,
3702 "Tensor input and output of _reduce_oop must have the same number of elements ");
3703 }
3704 return collective(
3705 inputTensor,
3706 outputTensor,
3707 [&](at::Tensor& input,
3708 at::Tensor& output,
3709 ncclComm_t comm,
3710 at::cuda::CUDAStream& stream) {
3711 const auto root = opts.rootRank + opts.rootTensor;
3712 const auto ncclDataType = getNcclDataType(input.scalar_type());
3713 const auto ncclReduceOp =
3714 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3715 return ncclReduce(
3716 input.data_ptr(),
3717 output.data_ptr(),
3718 input.numel(),
3719 ncclDataType,
3720 ncclReduceOp,
3721 (int)root,
3722 comm,
3723 stream.stream());
3724 },
3725 OpType::REDUCE,
3726 "nccl:_reduce_oop");
3727 }
3728
3729 c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
3730 std::vector<std::vector<at::Tensor>>& outputTensors,
3731 std::vector<at::Tensor>& inputTensors,
3732 const AllgatherOptions& opts) {
3733 TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3734 // @lint-ignore CLANGTIDY
3735 auto inputTensor = inputTensors.back();
3736 check_gpu_single_tensor(inputTensor);
3737 // @lint-ignore CLANGTIDY
3738 auto outputTensors_ = outputTensors.back();
3739
3740 RECORD_PARAM_COMMS_DATA(
3741 static_cast<int>(
3742 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3743 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3744 inputTensors, // inputTensors
3745 outputTensors, // outputTensors
3746 rank_, // rank
3747 "all_gather", // collective name
3748 inputTensor.numel(), // inNelems
3749 inputTensor.numel() * // outNelems
3750 this->getSize(),
3751 inputTensor.scalar_type(), // dType
3752 std::vector<int64_t>(), // inSplitSizes
3753       std::vector<int64_t>(), // outSplitSizes
3754 globalRankStart, // globalRankStart
3755 globalRankStride, // globalRankStride
3756 this->getSize()); // worldSize
3757
3758 bool same_size = check_same_size(outputTensors_);
3759 if (same_size) {
3760 // Flatten a vector of tensors into a single, stacked tensor.
3761 at::Tensor outputFlattened = newLikeFlat(outputTensors_);
3762
3763 return collective(
3764 inputTensor,
3765 outputFlattened,
3766 [&](at::Tensor& input,
3767 at::Tensor& output,
3768 ncclComm_t comm,
3769 at::cuda::CUDAStream& stream) {
3770 if (!avoidRecordStreams_) {
3771 c10::cuda::CUDACachingAllocator::recordStream(
3772 output.storage().data_ptr(), stream);
3773 }
3774 return ncclAllGather(
3775 input.data_ptr(),
3776 output.data_ptr(),
3777 input.numel(),
3778 getNcclDataType(input.scalar_type()),
3779 comm,
3780 stream.stream());
3781 },
3782 [](at::cuda::CUDAStream& ncclStream,
3783 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3784 // avoidRecordStreams_ note: We actually don't need to stash anything
3785 // here.
3786 // - inputTensors is stashed onto work->stashed_for_allocator_safety_
3787 // in collective().
3788 // - outputFlattened is stashed onto work->outputs_ in collective().
3789 // - User-facing outputTensors should be held by the user until after
3790 // waiting on work_, or the call makes no sense.
3791 // So all participating tensors are accounted for, and won't be
3792 // released back to their allocation streams until after work_ is
3793 // waited on.
3794 },
3795 [&](at::cuda::CUDAStream& ncclStream,
3796 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3797 // Copy the flattened output tensors to the outputs.
3798 at::cuda::CUDAStreamGuard guard(ncclStream);
3799 for (const auto j : c10::irange(outputTensors_.size())) {
3800 // See [Sync Streams].
3801 if (!avoidRecordStreams_) {
3802 c10::cuda::CUDACachingAllocator::recordStream(
3803 outputTensors_[j].storage().data_ptr(), ncclStream);
3804 }
3805 outputTensors_[j].copy_(outputFlattened[j], true);
3806 }
3807 },
3808 OpType::ALLGATHER,
3809 "nccl:all_gather");
3810 } else {
3811 const auto num_reduces = outputTensors_.size();
3812 startCoalescing();
3813 for (const int i : c10::irange(num_reduces)) {
3814 auto& output = outputTensors_[i];
3815 auto& input = (i == rank_) ? inputTensor : output;
3816 auto broadcastOpts = BroadcastOptions{
3817 static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
3818 _broadcast_oop(output, input, broadcastOpts);
3819 }
3820 auto work = endCoalescing(OpType::ALLGATHER);
3821 return work;
3822 }
3823 }
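
// Illustrative sketch (hypothetical names): the caller supplies one input
// tensor and a single output list with world-size entries; when all entries
// match the input's size, the flattened fast path above is taken, otherwise
// the coalesced _broadcast_oop fallback runs.
//
//   std::vector<at::Tensor> in{myShard};  // CUDA tensor
//   std::vector<std::vector<at::Tensor>> out(1);
//   for (int r = 0; r < pg->getSize(); ++r)
//     out[0].push_back(at::empty_like(myShard));
//   pg->allgather(out, in, AllgatherOptions())->wait();  // out[0][r]: rank r's data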
3824
3825 c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_coalesced(
3826 std::vector<std::vector<at::Tensor>>& /* unused */,
3827 std::vector<at::Tensor>& /* unused */,
3828 const AllgatherOptions& /* unused */) {
3829 C10_THROW_ERROR(
3830 NotImplementedError,
3831 "ProcessGroupNCCL does not support allgather_coalesced");
3832 }
3833
3834 c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
3835 std::vector<at::Tensor>& outputs,
3836 std::vector<at::Tensor>& inputs,
3837 const AllgatherOptions& opts) {
3838 return collectiveCoalesced(
3839 inputs,
3840 outputs,
3841 [&](at::Tensor& input,
3842 at::Tensor& output,
3843 ncclComm_t comm,
3844 at::cuda::CUDAStream& stream) {
3845 return ncclAllGather(
3846 input.data_ptr(),
3847 output.data_ptr(),
3848 input.numel(),
3849 getNcclDataType(input.scalar_type()),
3850 comm,
3851 stream.stream());
3852 },
3853 OpType::COALESCED,
3854 "nccl:all_gather_into_tensor_coalesced");
3855 }
3856
3857 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
3858 std::vector<at::Tensor>& outputTensors,
3859 std::vector<std::vector<at::Tensor>>& inputTensors,
3860 const ReduceScatterOptions& opts) {
3861 TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
3862 // @lint-ignore CLANGTIDY
3863 auto outputTensor = outputTensors.back();
3864 check_gpu_single_tensor(outputTensor);
3865 // @lint-ignore CLANGTIDY
3866 auto inputTensors_ = inputTensors.back();
3867 TORCH_CHECK(
3868 !isFloat8Type(outputTensor.scalar_type()),
3869       "Float8 dtypes are not currently supported for NCCL reductions");
3870
3871 RECORD_PARAM_COMMS_DATA(
3872 static_cast<int>(
3873 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3874 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3875 inputTensors, // inputTensors
3876 outputTensors, // outputTensors
3877 rank_, // rank
3878 "reduce_scatter", // collective name
3879 outputTensor.numel() * this->getSize(), // inNelems
3880 outputTensor.numel(), // outNelems
3881 outputTensor.scalar_type(), // dType
3882 std::vector<int64_t>(), // inSplitSizes
3883 std::vector<int64_t>(), // outSplitSizes
3884 globalRankStart, // globalRankStart
3885 globalRankStride, // globalRankStride
3886 this->getSize()); // worldSize
3887
3888 bool same_size = check_same_size(inputTensors_);
3889 if (same_size) {
3890 // Flatten a vector of tensors into a single, stacked tensor.
3891 at::Tensor inputFlattened = newLikeFlat(inputTensors_);
3892
3893 return collective(
3894 inputFlattened,
3895 outputTensor,
3896 [&](at::Tensor& input,
3897 at::Tensor& output,
3898 ncclComm_t comm,
3899 at::cuda::CUDAStream& stream) {
3900 if (!avoidRecordStreams_) {
3901 c10::cuda::CUDACachingAllocator::recordStream(
3902 output.storage().data_ptr(), stream);
3903 }
3904 const auto ncclDataType = getNcclDataType(input.scalar_type());
3905 const auto ncclReduceOp =
3906 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
3907 return ncclReduceScatter(
3908 input.data_ptr(),
3909 output.data_ptr(),
3910 output.numel(),
3911 ncclDataType,
3912 ncclReduceOp,
3913 comm,
3914 stream.stream());
3915 },
3916 [&](at::cuda::CUDAStream& ncclStream,
3917 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
3918 if (avoidRecordStreams_) {
3919 // We only need to stash inputTensors.
3920 // - inputFlattened is stashed onto
3921 // work->stashed_for_allocator_safety_
3922 // in collective().
3923 // - User-facing outputTensors is stashed onto work->outputs_ in
3924 // collective(),
3925 // and should also be held by the user until after waiting on
3926 // work_.
3927 auto& v = work->stashed_for_allocator_safety_;
3928 v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
3929 }
3930
3931 // Copy the input tensors to the flattened inputs.
3932 at::cuda::CUDAStreamGuard guard(ncclStream);
3933 for (const auto j : c10::irange(inputTensors_.size())) {
3934 // See [Sync Streams].
3935 if (!avoidRecordStreams_) {
3936 c10::cuda::CUDACachingAllocator::recordStream(
3937 inputTensors_[j].storage().data_ptr(), ncclStream);
3938 }
3939 inputFlattened[j].copy_(inputTensors_[j], true);
3940 }
3941 },
3942 [&](at::cuda::CUDAStream&,
3943 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
3944 OpType::REDUCE_SCATTER,
3945 "nccl:reduce_scatter");
3946 } else {
3947 const auto num_reduces = inputTensors_.size();
3948 startCoalescing();
3949 for (const int i : c10::irange(num_reduces)) {
3950 auto& input = inputTensors_[i];
3951 auto& output = (i == rank_) ? outputTensor : input;
3952 auto reduceOpts = ReduceOptions{
3953 opts.reduceOp,
3954 static_cast<int64_t>(i),
3955 static_cast<int64_t>(0),
3956 opts.timeout};
3957 _reduce_oop(output, input, reduceOpts);
3958 }
3959 auto work = endCoalescing(OpType::REDUCE_SCATTER);
3960 return work;
3961 }
3962 }
3963
3964 c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
3965 at::Tensor& outputTensor,
3966 at::Tensor& inputTensor,
3967 const ReduceScatterOptions& opts) {
3968 if (inputTensor.dtype() != outputTensor.dtype()) {
3969 C10_THROW_ERROR(
3970 TypeError, "input tensor must be the same type as the output tensor.");
3971 }
3972
3973 if (inputTensor.numel() != outputTensor.numel() * size_) {
3974 C10_THROW_ERROR(
3975 ValueError,
3976 "input tensor must be the same size as output size times world size");
3977 }
3978
3979 // @lint-ignore CLANGTIDY
3980 const auto& tensor = outputTensor;
3981 TORCH_CHECK(
3982 !isFloat8Type(tensor.scalar_type()),
3983       "Float8 dtypes are not currently supported for NCCL reductions");
3984 RECORD_PARAM_COMMS_DATA(
3985 static_cast<int>(
3986 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
3987 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
3988 inputTensor, // inputTensor
3989 outputTensor, // outputTensor
3990 rank_, // rank
3991 "_reduce_scatter_base", // collective name
3992 inputTensor.numel(), // inNelems
3993 tensor.numel(), // outNelems
3994 tensor.scalar_type(), // dtype
3995 std::vector<int64_t>(), // inSplitSizes
3996 std::vector<int64_t>(), // outSplitSizes
3997 globalRankStart, // globalRankStart
3998 globalRankStride, // globalRankStride
3999 this->getSize()); // worldSize
4000
4001 // avoidRecordStreams_ note: collective() will stash inputs and outputs.
4002 // Note 2: for asyncOp = false, we don't want to record streams because we
4003 // know that the NCCL stream will join back to the "current" stream right
4004 // after this op. So we might just as well keep the stream ownership of the
4005 // input/output tensors unchanged. The benefit would be that the
4006 // allocation/free of the tensors would look deterministic to the "current"
4007 // stream so that the caching allocator can reuse memory pool for this stream
4008   // in a clever way. This setting is added for libraries like FSDP, which use
4009   // `reduce_scatter_tensor`.
4010 bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
4011
4012 return collective(
4013 inputTensor,
4014 outputTensor,
4015 [&](at::Tensor& input,
4016 at::Tensor& output,
4017 ncclComm_t comm,
4018 at::cuda::CUDAStream& stream) {
4019 if (!avoidRecordStreams) {
4020 c10::cuda::CUDACachingAllocator::recordStream(
4021 output.storage().data_ptr(), stream);
4022 }
4023 auto ncclDataType = getNcclDataType(input.scalar_type());
4024 auto ncclReduceOp =
4025 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
4026 return ncclReduceScatter(
4027 input.data_ptr(),
4028 output.data_ptr(),
4029 output.numel(),
4030 ncclDataType,
4031 ncclReduceOp,
4032 comm,
4033 stream.stream());
4034 },
4035 OpType::_REDUCE_SCATTER_BASE,
4036 "nccl:_reduce_scatter_base",
4037 avoidRecordStreams);
4038 }
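
// Illustrative sketch (hypothetical names): the input must hold world-size
// shards laid out contiguously, i.e. input.numel() == output.numel() * size,
// matching the check above; with asyncOp = false the avoidRecordStreams note
// above applies.
//
//   auto opt = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
//   int64_t shard = 1 << 20;
//   at::Tensor in = at::randn({shard * pg->getSize()}, opt);
//   at::Tensor out = at::empty({shard}, opt);
//   ReduceScatterOptions opts;
//   opts.reduceOp = ReduceOp::SUM;
//   pg->_reduce_scatter_base(out, in, opts)->wait();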
4039
4040 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
4041 std::vector<at::Tensor>& outputs,
4042 std::vector<at::Tensor>& inputs,
4043 const ReduceScatterOptions& opts) {
4044 TORCH_CHECK(
4045 !isFloat8Type(inputs.back().scalar_type()),
4046       "Float8 dtypes are not currently supported for NCCL reductions");
4047 return collectiveCoalesced(
4048 inputs,
4049 outputs,
4050 [&](at::Tensor& input,
4051 at::Tensor& output,
4052 ncclComm_t comm,
4053 at::cuda::CUDAStream& stream) {
4054 if (!avoidRecordStreams_) {
4055 c10::cuda::CUDACachingAllocator::recordStream(
4056 output.storage().data_ptr(), stream);
4057 }
4058 auto ncclDataType = getNcclDataType(input.scalar_type());
4059 auto ncclReduceOp =
4060 getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
4061 return ncclReduceScatter(
4062 input.data_ptr(),
4063 output.data_ptr(),
4064 output.numel(),
4065 ncclDataType,
4066 ncclReduceOp,
4067 comm,
4068 stream.stream());
4069 },
4070 OpType::COALESCED,
4071 "nccl:reduce_scatter_tensor_coalesced");
4072 }
4073
4074 c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
4075 RECORD_PARAM_COMMS(
4076 static_cast<int>(
4077 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4078 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4079 rank_, // rank
4080 "barrier", // collective name
4081 0, // inNelems
4082 0, // outNelems
4083 at::kByte, // dType
4084 std::vector<int64_t>(), // inSplitSizes
4085 std::vector<int64_t>(), // outSplitSizes
4086 globalRankStart, // globalRankStart
4087 globalRankStride, // globalRankStride
4088 this->getSize()); // worldSize
4089
4090 // Device to use for barrier
4091 int barDevIdx = -1;
4092
4093 // Select device to use for barrier
4094 // 1st choice: Use user defined GPU device ids if provided
4095 if (!opts.device_ids.empty()) {
4096 // Use the first device id because PG NCCL is single-device now
4097 barDevIdx = opts.device_ids[0];
4098 } else if (getBoundDeviceId()) {
4099 // 2nd choice: Use the bound GPU device id if available.
4100     // A bound device id can be passed to `init_process_group`.
4101 barDevIdx = (*getBoundDeviceId()).index();
4102 } else if (!usedDeviceIdxs_.empty()) {
4103 // 3rd choice: infer the device id from the used device ids.
4104 barDevIdx = *usedDeviceIdxs_.begin();
4105 } else {
4106     // This means no NCCL collective has been called yet. We have to make a
4107     // best guess and use a single GPU to call allreduce, which achieves the
4108     // barrier.
4109     // In case multiple processes fall on the same node, we use the rank to
4110     // ensure that each process is on a different GPU.
4111 // Note: it is better to use global rank because the group-local rank can be
4112 // offset wrt the device id if intra-node GPUs are sharded into multiple
4113 // dimensions.
4114 barDevIdx = static_cast<int16_t>(globalRank() % localDeviceCount_);
4115 LOG(WARNING)
4116 << logPrefix()
4117 << c10::str(
4118 " using GPU ",
4119 barDevIdx,
4120 " to perform barrier as devices used by this process are currently unknown. ",
4121                " This can potentially cause a hang if this rank-to-GPU mapping is incorrect. ",
4122                "Specify device_ids in barrier() to force use of a particular device, ",
4123                "or call init_process_group() with a device_id.");
4124 }
4125
4126 TORCH_CHECK_WITH(
4127 ValueError,
4128 barDevIdx >= 0,
4129 "Failed to infer a GPU device id to perform barrier. ");
4130 auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx);
4131
4132 // Create a dummy tensor on the device
4133 // Note: we use zeros() instead of empty() to prevent barrier from triggering
4134 // alarm when NaN checker is enabled.
4135 at::Tensor barrierTensor =
4136 at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
4137
4138 // All reduce to achieve the barrier
4139 auto work = allreduce_impl(barrierTensor);
4140
4141 // Work will take over barrierTensors
4142 auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
4143 TORCH_CHECK(ncclWork);
4144 ncclWork->barrierTensor_ = std::move(barrierTensor);
4145 return work;
4146 }
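
// Illustrative sketch (hypothetical `pg` and `localDeviceIndex`): passing
// device_ids skips the device-guessing fallback above and its warning.
//
//   BarrierOptions opts;
//   opts.device_ids = {static_cast<int64_t>(localDeviceIndex)};
//   pg->barrier(opts)->wait();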
4147
4148 c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
4149 at::Tensor& outputTensor,
4150 at::Tensor& inputTensor,
4151 std::vector<int64_t>& outputSplitSizes,
4152 std::vector<int64_t>& inputSplitSizes,
4153 const AllToAllOptions& /* unused */) {
4154 check_gpu_single_tensor(outputTensor, true);
4155 check_gpu_single_tensor(inputTensor, true);
4156 if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) {
4157 RECORD_PARAM_COMMS_DATA(
4158 static_cast<int>(
4159 this->getSequenceNumberForGroup() +
4160 1), // seq + 1 to match collective
4161 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4162 inputTensor, // inputTensor
4163 outputTensor, // outputTensor
4164 rank_, // rank
4165 "all_to_all", // collective name
4166 inputTensor.numel(), // inNelems
4167 outputTensor.numel(), // outNelems
4168 inputTensor.scalar_type(), // dType
4169 std::vector<int64_t>(), // inSplitSizes
4170 std::vector<int64_t>(), // outSplitSizes
4171 globalRankStart, // globalRankStart
4172 globalRankStride, // globalRankStride
4173 this->getSize()); // worldSize
4174
4175 // avoidRecordStreams_ note: collective() will stash inputTensors and
4176 // outputTensors.
4177 return collective(
4178 inputTensor,
4179 outputTensor,
4180 [&](at::Tensor& input,
4181 at::Tensor& output,
4182 ncclComm_t comm,
4183 at::cuda::CUDAStream& stream) {
4184 // See [Sync Streams].
4185 if (!avoidRecordStreams_) {
4186 c10::cuda::CUDACachingAllocator::recordStream(
4187 output.storage().data_ptr(), stream);
4188 }
4189 torch::cuda::nccl::all2all_single_equal_split(
4190 input, output, this->getSize(), comm, stream);
4191 return ncclSuccess;
4192 },
4193 OpType::ALLTOALL_BASE,
4194 "nccl:all_to_all");
4195 } else {
4196 c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
4197 c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
4198
4199 RECORD_PARAM_COMMS_DATA(
4200 static_cast<int>(
4201 this->getSequenceNumberForGroup() +
4202 1), // seq + 1 to match collective
4203 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4204 inputTensor, // inputTensor
4205 outputTensor, // outputTensor
4206 rank_, // rank
4207 "all_to_allv", // collective name
4208 inputTensor.numel(), // inNelems
4209 outputTensor.numel(), // outNelems
4210 inputTensor.scalar_type(), // dType
4211 inputSplitSizes, // inSplitSizes
4212 outputSplitSizes, // outSplitSizes
4213 globalRankStart, // globalRankStart
4214 globalRankStride, // globalRankStride
4215 this->getSize()); // worldSize
4216
4217 // avoidRecordStreams_ note: collective() will stash inputTensors and
4218 // outputTensors.
4219 return collective(
4220 inputTensor,
4221 outputTensor,
4222 [&](at::Tensor& input,
4223 at::Tensor& output,
4224 ncclComm_t comm,
4225 at::cuda::CUDAStream& stream) {
4226 std::vector<size_t> send_lengths(size_);
4227 std::vector<size_t> recv_lengths(size_);
4228 std::vector<size_t> send_offsets(size_);
4229 std::vector<size_t> recv_offsets(size_);
4230 c10d::computeLengthsAndOffsets(
4231 inputSplitSizes, input, &send_lengths, &send_offsets);
4232 c10d::computeLengthsAndOffsets(
4233 outputSplitSizes, output, &recv_lengths, &recv_offsets);
4234 // See [Sync Streams].
4235 if (!avoidRecordStreams_) {
4236 c10::cuda::CUDACachingAllocator::recordStream(
4237 output.storage().data_ptr(), stream);
4238 }
4239 torch::cuda::nccl::all2all_single_unequal_split(
4240 input.data_ptr(),
4241 send_lengths.data(),
4242 send_offsets.data(),
4243 output.data_ptr(),
4244 recv_lengths.data(),
4245 recv_offsets.data(),
4246 input.element_size(),
4247 input.scalar_type(),
4248 comm,
4249 stream);
4250 return ncclSuccess;
4251 },
4252 OpType::ALLTOALL_BASE,
4253 "nccl:all_to_all");
4254 }
4255 }
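
// Illustrative sketch (hypothetical tensors `in`, `out`, `in2`, `out2` and
// `pg`): empty split-size vectors select the equal-split kernel above;
// non-empty ones are validated by checkSplitSizes() and select the
// unequal-split path.
//
//   std::vector<int64_t> empty;
//   pg->alltoall_base(out, in, empty, empty, AllToAllOptions())->wait();
//
//   // Uneven split: per-rank element counts along dim 0 of each tensor.
//   std::vector<int64_t> outSplits{1, 2, 3, 4}, inSplits{4, 3, 2, 1};
//   pg->alltoall_base(out2, in2, outSplits, inSplits, AllToAllOptions())->wait();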
4256
4257 c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
4258 std::vector<at::Tensor>& outputTensors,
4259 std::vector<at::Tensor>& inputTensors,
4260 const AllToAllOptions& /* unused */) {
4261 std::vector<int64_t> inSplitSizes;
4262 std::vector<int64_t> outSplitSizes;
4263 int64_t total_numel = 0;
4264
4265 auto device = outputTensors[0].device();
4266 for (const auto r : c10::irange(outputTensors.size())) {
4267 check_gpu_single_tensor(outputTensors[r], true);
4268 check_gpu_single_tensor(inputTensors[r], true);
4269 TORCH_CHECK(
4270 device == outputTensors[r].device() &&
4271 device == inputTensors[r].device(),
4272         "Tensors must be on the same device");
4273 inSplitSizes.push_back(inputTensors[r].numel());
4274 outSplitSizes.push_back(outputTensors[r].numel());
4275 total_numel += inputTensors[r].numel();
4276 }
4277
4278 RECORD_PARAM_COMMS_DATA(
4279 static_cast<int>(
4280 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4281 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4282 inputTensors, // inputTensors
4283 outputTensors, // outputTensors
4284 rank_, // rank
4285 "all_to_all", // collective name
4286 total_numel, // inNelems
4287 total_numel, // outNelems
4288 inputTensors.front().scalar_type(), // dType
4289 inSplitSizes, // inSplitSizes
4290 outSplitSizes, // outSplitSizes
4291 globalRankStart, // globalRankStart
4292 globalRankStride, // globalRankStride
4293 this->getSize()); // worldSize
4294
4295 return collective(
4296 inputTensors,
4297 outputTensors,
4298 [&](at::Tensor& /* unused */,
4299 at::Tensor& /* unused */,
4300 ncclComm_t comm,
4301 at::cuda::CUDAStream& stream) {
4302 torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream);
4303 return ncclSuccess;
4304 },
4305 [&](at::cuda::CUDAStream&,
4306 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
4307 if (avoidRecordStreams_) {
4308 // inputTensor0 and outputTensor0 are stashed redundantly by
4309 // collective(), but that's ok.
4310 auto& v = work->stashed_for_allocator_safety_;
4311 v->insert(v->end(), inputTensors.begin(), inputTensors.end());
4312 v->insert(v->end(), outputTensors.begin(), outputTensors.end());
4313 }
4314 },
4315 [](at::cuda::CUDAStream&,
4316 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4317 OpType::ALLTOALL,
4318 "nccl:all_to_all");
4319 }
4320
4321 c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
4322 std::vector<at::Tensor>& tensors,
4323 int dstRank,
4324 int /* unused */) {
4325 TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
4326 // @lint-ignore CLANGTIDY
4327 auto tensor = tensors.back();
4328 check_gpu_single_tensor(tensor, true);
4329
4330 RECORD_PARAM_COMMS_DATA(
4331 static_cast<int>(
4332 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4333 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4334 tensors, // inputTensors
4335 tensors, // outputTensors
4336 dstRank, // dst rank
4337 "send", // collective name
4338 tensor.numel(), // inNelems
4339 tensor.numel(), // outNelems
4340 tensor.scalar_type(), // dType
4341 std::vector<int64_t>(), // inSplitSizes
4342 std::vector<int64_t>(), // outSplitSizes
4343 globalRankStart, // globalRankStart
4344 globalRankStride, // globalRankStride
4345 this->getSize()); // worldSize
4346
4347 auto ret = pointToPoint(
4348 tensor,
4349 [&](at::Tensor& input,
4350 ncclComm_t comm,
4351 at::cuda::CUDAStream& stream,
4352 int dst) {
4353 torch::cuda::nccl::send(input, comm, stream, dst);
4354 return ncclSuccess;
4355 },
4356 dstRank,
4357 OpType::SEND,
4358 c10::str("nccl:send ", rank_, "->", dstRank).c_str());
4359 return ret;
4360 }
4361
4362 c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
4363 std::vector<at::Tensor>& tensors,
4364 int srcRank,
4365 int /* unused */) {
4366 TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
4367 // @lint-ignore CLANGTIDY
4368 auto tensor = tensors.back();
4369 check_gpu_single_tensor(tensor, true);
4370
4371 RECORD_PARAM_COMMS_DATA(
4372 static_cast<int>(
4373 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4374 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4375 tensors, // inputTensors
4376 tensors, // outputTensors
4377 srcRank, // src rank
4378 "recv", // collective name
4379 tensor.numel(), // inNelems
4380 tensor.numel(), // outNelems
4381 tensor.scalar_type(), // dType
4382 std::vector<int64_t>(), // inSplitSizes
4383 std::vector<int64_t>(), // outSplitSizes
4384 globalRankStart, // globalRankStart
4385 globalRankStride, // globalRankStride
4386 this->getSize()); // worldSize
4387
4388 auto ret = pointToPoint(
4389 tensor,
4390 [&](at::Tensor& output,
4391 ncclComm_t comm,
4392 at::cuda::CUDAStream& stream,
4393 int src) {
4394 torch::cuda::nccl::recv(output, comm, stream, src);
4395 return ncclSuccess;
4396 },
4397 srcRank,
4398 OpType::RECV,
4399 c10::str("nccl:recv ", rank_, "<-", srcRank).c_str());
4400 return ret;
4401 }
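
// Illustrative sketch (hypothetical `pg`): send() and recv() must be paired
// across ranks; the third argument is a tag that this backend currently
// ignores (note the /* unused */ parameter above).
//
//   auto opt = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
//   std::vector<at::Tensor> buf{at::empty({1024}, opt)};
//   if (pg->getRank() == 0) {
//     buf[0].fill_(1);
//     pg->send(buf, /*dstRank=*/1, /*tag=*/0)->wait();
//   } else if (pg->getRank() == 1) {
//     pg->recv(buf, /*srcRank=*/0, /*tag=*/0)->wait();  // buf[0] now holds ones
//   }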
4402
4403 void ProcessGroupNCCL::groupStart() {
4404 C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt);
4405 ++ncclActiveGroupCounter_;
4406 }
4407
4408 void ProcessGroupNCCL::groupEnd() {
4409 C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
4410 --ncclActiveGroupCounter_;
4411 }
4412
4413 void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr<NCCLComm> comm) {
4414 #ifndef NCCL_HAS_COMM_NONBLOCKING
4415 C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
4416 #else
4417 if (!nccl_use_nonblocking()) {
4418 C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
4419 } else {
4420 C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt);
4421 }
4422 #endif
4423 --ncclActiveGroupCounter_;
4424 }
4425
4426 c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
4427 std::vector<std::vector<at::Tensor>>& outputTensors,
4428 std::vector<at::Tensor>& inputTensors,
4429 const GatherOptions& opts) {
4430 static auto invalidArgument = [](const std::string& msg) {
4431 C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::gather: " + msg);
4432 };
4433
4434 assertRootRank(invalidArgument, opts.rootRank, size_);
4435
4436 TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
4437 // @lint-ignore CLANGTIDY
4438 auto inputTensor = inputTensors.back();
4439
4440 std::vector<at::Tensor> outputs;
4441
4442 if (getRank() == opts.rootRank) {
4443 if (outputTensors.size() != 1) {
4444 std::stringstream ss;
4445 ss << "requires a single-element output list containing a list with "
4446 << getSize() << " tensors.";
4447 invalidArgument(ss.str());
4448 } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
4449 std::stringstream ss;
4450 ss << "Incorrect output list size " << outputTensors[0].size()
4451 << ". Output list size should be " << getSize()
4452 << ", same as size of the process group.";
4453 invalidArgument(ss.str());
4454 }
4455
4456 const auto& options = inputTensor.options();
4457 const auto& sizes = inputTensor.sizes();
4458 assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes);
4459 outputs = outputTensors[0];
4460 } else {
4461     // if not the root rank, initialize outputs as an empty list
4462 if (outputTensors.size() != 0) {
4463 invalidArgument("requires empty output on non-root");
4464 }
4465 outputs = {};
4466     // append an empty tensor to the list; we don't use it, but the
4467     // `collective` template function requires it in order to invoke its callback
4468 outputs.emplace_back();
4469 }
4470
4471 RECORD_PARAM_COMMS_DATA(
4472 static_cast<int>(
4473 this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
4474 std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
4475 inputTensors, // inputTensors
4476 outputTensors, // outputTensors
4477 opts.rootRank, // root rank
4478 "gather", // collective name
4479 inputTensor.numel(), // inNelems
4480 inputTensor.numel() * this->getSize(), // outNelems
4481 inputTensor.scalar_type(), // dType
4482 std::vector<int64_t>(), // inSplitSizes
4483       std::vector<int64_t>(), // outSplitSizes
4484 globalRankStart, // globalRankStart
4485 globalRankStride, // globalRankStride
4486 this->getSize()); // worldSize
4487
4488 // avoidRecordStreams_ note: collective() will stash inputTensors and
4489 // outputs, which == outputTensors[0] on the root rank where it matters.
4490
4491 auto inputs = std::vector<at::Tensor>{inputTensor};
4492 return collective(
4493 inputs,
4494 outputs, // just to fit the collective interface
4495 [&](at::Tensor& /* unused */,
4496 at::Tensor& /* unused */,
4497 ncclComm_t comm,
4498 at::cuda::CUDAStream& stream) {
4499 const auto root = opts.rootRank;
4500 if (getRank() == root) {
4501 if (!avoidRecordStreams_) {
4502 for (auto output : outputs) {
4503 c10::cuda::CUDACachingAllocator::recordStream(
4504 output.storage().data_ptr(), stream);
4505 }
4506 }
4507 }
4508 torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root);
4509 return ncclSuccess;
4510 },
4511 [](at::cuda::CUDAStream&,
4512 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4513 [](at::cuda::CUDAStream&,
4514 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
4515 OpType::GATHER,
4516 "nccl:gather");
4517 }
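
// Illustrative sketch (hypothetical names): only the root passes a populated
// output list; every other rank passes an empty outer list, matching the
// argument checks above.
//
//   std::vector<at::Tensor> in{myTensor};
//   std::vector<std::vector<at::Tensor>> out;
//   GatherOptions opts;
//   opts.rootRank = 0;
//   if (pg->getRank() == 0) {
//     out.emplace_back();
//     for (int r = 0; r < pg->getSize(); ++r)
//       out[0].push_back(at::empty_like(myTensor));
//   }
//   pg->gather(out, in, opts)->wait();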
4518
c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  static auto invalidArgument = [](const std::string& msg) {
    C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::scatter: " + msg);
  };

  assertRootRank(invalidArgument, opts.rootRank, size_);

  TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
  auto outputTensor = outputTensors.back();

  std::vector<at::Tensor> inputs;

  if (getRank() == opts.rootRank) {
    if (inputTensors.size() != 1) {
      std::stringstream ss;
      ss << "requires a single-element input list containing a list with "
         << getSize() << " tensors.";
      invalidArgument(ss.str());
    } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
      std::stringstream ss;
      ss << "Incorrect input list size " << inputTensors[0].size()
         << ". Input list size should be " << getSize()
         << ", same as size of the process group.";
      invalidArgument(ss.str());
    }

    const auto& options = outputTensor.options();
    const auto& sizes = outputTensor.sizes();
    assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes);
    inputs = inputTensors[0];
  } else {
    // If not on the root rank, initialize `inputs` as an empty placeholder
    // list.
    if (inputTensors.size() != 0) {
      invalidArgument("requires empty input on non-root");
    }
    inputs = {};
    // Append an empty tensor to the list; it is not used, but the
    // `collective` template function requires a non-empty list to invoke
    // its functor.
    inputs.emplace_back();
  }

  RECORD_PARAM_COMMS_DATA(
      static_cast<int>(
          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
      inputTensors, // inputTensors
      outputTensors, // outputTensors
      opts.rootRank, // root rank
      "scatter", // collective name
      outputTensor.numel() * this->getSize(), // inNelems
      outputTensor.numel(), // outNelems
      outputTensor.scalar_type(), // dType
      std::vector<int64_t>(), // inSplitSizes
      std::vector<int64_t>(), // outSplitSizes
      globalRankStart, // globalRankStart
      globalRankStride, // globalRankStride
      this->getSize()); // worldSize

  // avoidRecordStreams_ note: collective() will stash `outputTensors` and
  // `inputs`; on the root rank, where it matters, `inputs` is
  // `inputTensors[0]`. Stream recording is also skipped when asyncOp is
  // false (same reasoning as the note in _allgather_base below).
  bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);

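  // Only the root rank contributes user-provided data to the scatter, so the
  // (optional) NaN check below is enabled on the root only.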
  const auto root = opts.rootRank;
  bool nanCheck = (rank_ == root);

  auto outputs = std::vector<at::Tensor>{outputTensor};
  return collective(
      outputs,
      inputs, // just to fit the collective interface
      [&](at::Tensor& /* unused */,
          at::Tensor& /* unused */,
          ncclComm_t comm,
          at::cuda::CUDAStream& stream) {
        if (getRank() == root) {
          if (!avoidRecordStreams) {
            for (const auto& input : inputs) {
              c10::cuda::CUDACachingAllocator::recordStream(
                  input.storage().data_ptr(), stream);
            }
          }
        }
        torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root);
        return ncclSuccess;
      },
      [](at::cuda::CUDAStream&,
         c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
      [](at::cuda::CUDAStream&,
         c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
      OpType::SCATTER,
      "nccl:scatter",
      avoidRecordStreams,
      nanCheck);
}

c10::intrusive_ptr<Work> ProcessGroupNCCL::recvAnysource(
    std::vector<at::Tensor>& /* unused */,
    int /* unused */) {
  C10_THROW_ERROR(
      NotImplementedError, "ProcessGroupNCCL does not support recvAnysource");
}

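// _allgather_base: "flat" all-gather into a single preallocated output tensor.
// The output must have the same dtype as the input and exactly
// world_size * input.numel() elements; per standard ncclAllGather semantics,
// rank r's contribution occupies the r-th contiguous chunk of the output.
// For example, with world_size == 4 and an input of shape [8, 16], any output
// with 4 * 8 * 16 == 512 elements (e.g. shape [32, 16]) is accepted.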
c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const AllgatherOptions& opts) {
  check_gpu_single_tensor(input_tensor);
  check_gpu_single_tensor(output_tensor);

  if (input_tensor.dtype() != output_tensor.dtype()) {
    C10_THROW_ERROR(
        TypeError, "output tensor must have the same type as input tensor");
  }

  if (input_tensor.numel() * size_ != output_tensor.numel()) {
    C10_THROW_ERROR(
        ValueError,
        "output tensor size must be equal to world_size times input tensor size");
  }

  RECORD_PARAM_COMMS_DATA(
      static_cast<int>(
          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
      input_tensor, // inputTensors
      output_tensor, // outputTensors
      rank_, // rank
      "_allgather_base", // collective name
      input_tensor.numel(), // inNelems
      output_tensor.numel(), // outNelems
      output_tensor.scalar_type(), // dType
      std::vector<int64_t>(), // inSplitSizes
      std::vector<int64_t>(), // outSplitSizes
      globalRankStart, // globalRankStart
      globalRankStride, // globalRankStride
      this->getSize()); // worldSize

  // avoidRecordStreams_ note: collective() will stash inputs and outputs.
  // Note 2: for asyncOp = false, we don't want to record streams because we
  // know that the NCCL stream will join back to the "current" stream right
  // after this op. So we might just as well keep the stream ownership of the
  // input/output tensors unchanged. The benefit is that the allocation/free
  // of the tensors looks deterministic to the "current" stream, so the
  // caching allocator can reuse the memory pool for this stream in a clever
  // way. This setting is added for libraries like FSDP, which use
  // `all_gather_into_tensor`.
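  // Illustrative (non-normative) frontend call that exercises this path,
  // assuming the standard torch.distributed Python API:
  //   dist.all_gather_into_tensor(output_tensor, input_tensor, async_op=False)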
  bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);

  return collective(
      input_tensor,
      output_tensor,
      [&](at::Tensor& input,
          at::Tensor& output,
          ncclComm_t comm,
          at::cuda::CUDAStream& stream) {
        if (!avoidRecordStreams) {
          c10::cuda::CUDACachingAllocator::recordStream(
              output.storage().data_ptr(), stream);
        }
        return ncclAllGather(
            input.data_ptr(),
            output.data_ptr(),
            input.numel(),
            getNcclDataType(input.scalar_type()),
            comm,
            stream.stream());
      },
      OpType::_ALLGATHER_BASE,
      "nccl:_all_gather_base",
      avoidRecordStreams);
}

} // namespace c10d

#endif // USE_C10D_NCCL