#include <ATen/core/functional.h>
#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/cuda/nccl.h>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>

#include <nccl.h>

#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_map>

#if !defined(USE_ROCM) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif

ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
  return reinterpret_cast<ncclComm_t*>(var);
}

ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  return reinterpret_cast<ncclComm_t>(var);
}

ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
  return reinterpret_cast<ncclUniqueId*>(var);
}

ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
  switch (var) {
    case torch::cuda::nccl::ncclResult::Success:
      return ncclResult_t::ncclSuccess;
    case torch::cuda::nccl::ncclResult::UnhandledCudaError:
      return ncclResult_t::ncclUnhandledCudaError;
    case torch::cuda::nccl::ncclResult::SystemError:
      return ncclResult_t::ncclSystemError;
    case torch::cuda::nccl::ncclResult::InternalError:
      return ncclResult_t::ncclInternalError;
    case torch::cuda::nccl::ncclResult::InvalidArgument:
      return ncclResult_t::ncclInvalidArgument;
    case torch::cuda::nccl::ncclResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case torch::cuda::nccl::ncclResult::RemoteError:
      return ncclResult_t::ncclRemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case torch::cuda::nccl::ncclResult::InProgress:
      return ncclResult_t::ncclInProgress;
#endif
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
  switch (var) {
    case ncclSuccess:
      return torch::cuda::nccl::ncclResult::Success;
    case ncclUnhandledCudaError:
      return torch::cuda::nccl::ncclResult::UnhandledCudaError;
    case ncclSystemError:
      return torch::cuda::nccl::ncclResult::SystemError;
    case ncclInternalError:
      return torch::cuda::nccl::ncclResult::InternalError;
    case ncclInvalidArgument:
      return torch::cuda::nccl::ncclResult::InvalidArgument;
    case ncclInvalidUsage:
      return torch::cuda::nccl::ncclResult::InvalidUsage;
    case ncclRemoteError:
      return torch::cuda::nccl::ncclResult::RemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case ncclInProgress:
      return torch::cuda::nccl::ncclResult::InProgress;
#endif
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
  switch (type) {
    case at::kFloat:
      return ncclDataType_t::ncclFloat;
    case at::kHalf:
      return ncclDataType_t::ncclHalf;
    case at::kDouble:
      return ncclDataType_t::ncclDouble;
    case at::kLong:
      return ncclDataType_t::ncclInt64;
    case at::kInt:
      return ncclDataType_t::ncclInt;
    case at::kChar:
      return ncclDataType_t::ncclChar;
    case at::kByte:
      return ncclDataType_t::ncclUint8;
    case at::kBool:
      return ncclDataType_t::ncclUint8;
#if HAS_NCCL_BF16_DATATYPE
    case at::kBFloat16:
      return ncclDataType_t::ncclBfloat16;
#endif
    default:
      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
  }
}

ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
  if (!t.is_cuda()) {
    TORCH_CHECK(
        false,
        "NCCL only supports CUDA tensors, but got a tensor on ",
        t.device());
  }
  return to_nccl_data_type(t.scalar_type());
}

ncclRedOp_t to_nccl_red_op(int var) {
  return (ncclRedOp_t)(var);
}

namespace torch::cuda::nccl {

using namespace at;

namespace detail {

static inline void NCCL_CHECK(ncclResult_t result) {
  NCCL_CHECK(from_nccl_result(result));
}

// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
bool nccl_use_nonblocking() {
  static bool nccl_use_nonblocking_ =
      c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
  if (nccl_use_nonblocking_) {
    TORCH_WARN("Using experimental non-blocking NCCL communicator.");
  }
  return nccl_use_nonblocking_;
}

static int _parse_nccl_nonblocking_timeout() {
  const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
  int timeout = -1;
  if (val) {
    const std::string config(val);
    timeout = std::stoi(config);
    if (!nccl_use_nonblocking() && timeout > 0) {
      TORCH_WARN(
          "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
      timeout = -1;
    }
  }
  return timeout;
}

static int nccl_nonblocking_timeout() {
  static int timeout = _parse_nccl_nonblocking_timeout();
  return timeout;
}
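
// Illustrative usage of the two knobs above (the values below are examples,
// not defaults): setting TORCH_NCCL_USE_COMM_NONBLOCKING=1 opts in to
// non-blocking communicators, and TORCH_NCCL_NONBLOCKING_TIMEOUT=30 caps the
// ncclCommGetAsyncError polling loops below at roughly 30 seconds. Both
// environment variables are read once and cached for the process lifetime.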

static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  ncclResult_t result = to_nccl_result(status);
  auto startTimepoint = std::chrono::steady_clock::now();
  while (result == ncclInProgress) {
    if (nccl_nonblocking_timeout() > 0) {
      auto currentTimepoint = std::chrono::steady_clock::now();
      auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                             currentTimepoint - startTimepoint)
                             .count();
      if (timeElapsed > nccl_nonblocking_timeout()) {
        throw std::runtime_error("NCCL timeout.");
      }
    }
    ncclCommGetAsyncError(to_nccl_comm(comm), &result);
  }
  if (result != ncclSuccess) {
    throw_nccl_error(from_nccl_result(result));
  }
#else
  TORCH_INTERNAL_ASSERT(
      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) {
  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm);
}

static inline void NCCL_CHECK_TIMEOUT(
    ncclResult status,
    std::vector<ncclComm_t>& comms) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  ncclResult_t result = to_nccl_result(status);
  auto startTimepoint = std::chrono::steady_clock::now();
  if (result == ncclInProgress) {
    for (const auto i : c10::irange(comms.size())) {
      do {
        if (nccl_nonblocking_timeout() > 0) {
          auto currentTimepoint = std::chrono::steady_clock::now();
          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                                 currentTimepoint - startTimepoint)
                                 .count();
          if (timeElapsed > nccl_nonblocking_timeout()) {
            throw std::runtime_error("NCCL timeout.");
          }
        }
        ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
      } while (result == ncclInProgress);
      if (result != ncclSuccess) {
        break; /* fall through to failed case */
      }
    }
  }
  if (result != ncclSuccess) {
    throw_nccl_error(from_nccl_result(result));
  }
#else
  TORCH_INTERNAL_ASSERT(
      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(
    ncclResult_t result,
    std::vector<ncclComm_t>& comms) {
  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms);
}

void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
  std::ostringstream err;
  err << "NCCL Error " << static_cast<int>(status) << ": "
      << ncclGetErrorString(to_nccl_result(status));
  throw std::runtime_error(err.str());
}

struct NcclCommList {
  std::unique_ptr<ncclComm_t[]> comms;
  int ndevices;
  NcclCommList(const std::vector<int>& devices)
      : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
    NCCL_CHECK(ncclCommInitAll(
        to_nccl_comm(comms.get()), devices.size(), devices.data()));
  }
  NcclCommList(NcclCommList&& foo) = default;
  ~NcclCommList() {
    if (comms) {
      for (const auto i : c10::irange(ndevices)) {
        int dummy_var;
        if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) {
          /* there are cases when this destructor is called after the
             CUDA driver is already unloaded from the process.
             In these cases, skip ncclCommDestroy */
          return;
        }
        comm_destroy(comms[i]);
      }
    }
  }
  ArrayRef<ncclComm_t> ref() const {
    return ArrayRef<ncclComm_t>(comms.get(), ndevices);
  }
};

using device_list = std::vector<int>;
// accesses to this object have to be guarded by THC's CudaFreeMutex
static std::unordered_map<device_list, NcclCommList, c10::hash<device_list>>
    _communicators;

ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  static auto get_device = [](const at::Tensor& t) -> int {
    return t.get_device();
  };
  device_list devices = fmap(inputs, get_device);
  auto it = _communicators.find(devices);
  if (it == _communicators.end()) {
    it = _communicators.emplace(devices, devices).first;
  }
  return it->second.ref();
}

static inline void check_tensor(
    const at::Tensor& input,
    const at::optional<at::Tensor>& output,
    int input_multiplier,
    int output_multiplier,
    int64_t ref_numel,
    ScalarType ref_dtype) {
  auto check_one = [&](const at::Tensor& tensor) {
    if (!tensor.is_cuda() || tensor.is_sparse()) {
      throw std::runtime_error(
          "input and output elements have to be cuda dense Tensors");
    }

    if (ref_dtype != tensor.scalar_type()) {
      throw std::runtime_error(
          "all inputs and outputs must be of the same Tensor dtype");
    }

    if (!tensor.is_contiguous()) {
      throw std::runtime_error("all inputs and outputs have to be contiguous");
    }
  };

  check_one(input);

  // all inputs must be same size
  if (input.numel() != ref_numel) {
    throw std::runtime_error(
        "all inputs must have the same number of elements");
  }

  if (output) {
    check_one(*output);

    // inputs and outputs must be on same device respectively
    if (input.get_device() != output->get_device()) {
      throw std::runtime_error("input and output must be on the same device");
    }

    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
      throw std::runtime_error(
          "output must be of size input_size * size_multiplier");
    }
  }
}

void check_inputs(
    TensorList inputs,
    TensorList outputs,
    int input_multiplier,
    int output_multiplier) {
  // len(inputs) == len(outputs)
  size_t len = inputs.size();

  if (len <= 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  if (len != outputs.size()) {
    std::stringstream err;
    err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
    throw std::runtime_error(err.str());
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];
    auto output = outputs[i];

    check_tensor(
        input, output, input_multiplier, output_multiplier, numel, dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

void check_inputs(
    TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier) {
  auto len = inputs.size();

  if (len <= 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];

    check_tensor(
        input,
        i == static_cast<std::remove_cv_t<decltype(i)>>(root)
            ? at::optional<at::Tensor>{output}
            : at::nullopt,
        input_multiplier,
        output_multiplier,
        numel,
        dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

} // namespace detail

AutoNcclGroup::AutoNcclGroup() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  // nccl < 2.0 cannot be called concurrently with cudaFree
  (c10::cuda::getFreeMutex())->lock();
#endif
  comm_nonblocking_ = false;
  comm_ = nullptr;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  // nccl < 2.0 cannot be called concurrently with cudaFree
  (c10::cuda::getFreeMutex())->lock();
#endif
  comm_ = comm;
  comm_nonblocking_ = comm_nonblocking;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  if (comm_nonblocking_ && comm_ != nullptr) {
    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comm_);
  } else {
    detail::NCCL_CHECK(ncclGroupEnd());
  }
#endif
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  (c10::cuda::getFreeMutex())->unlock();
#endif
}

bool is_available(TensorList tensors) {
#ifdef USE_NCCL
  device_set devices;
  for (auto& tensor : tensors) {
    if (!tensor.is_cuda() || tensor.is_sparse())
      return false;
    if (!tensor.is_contiguous())
      return false;
    auto device = tensor.get_device();
    if (devices[device])
      return false;
    devices[device] = true;
  }
  return true;
#else
  return false;
#endif
}

std::uint64_t version() {
#if defined(NCCL_MAJOR)
  constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) |
      (((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH);
  return ver;
#elif defined(USE_NCCL)
  // return major version "1"
  return ((uint64_t)1) << 32;
#else
  return 0;
#endif
}
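
// Worked example of the packing above (illustrative only): NCCL 2.14.3 yields
//   (2ULL << 32) | (14ULL << 16) | 3ULL == 0x2000E0003,
// so callers can recover major/minor/patch with shifts and masks.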

const char* version_suffix() {
#if defined(NCCL_SUFFIX)
  return NCCL_SUFFIX;
#else
  return "";
#endif
}

void get_unique_id(ncclUniqueId& id) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  ncclComm_t comm;
  ncclUniqueId id = comm_id;
  NCCL_CHECK(ncclCommInitRank(
      to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));
  return comm;
#else
  return nullptr;
#endif
}

void comm_destroy(ncclComm_t comm) {
  /*
   * TODO(T30279827) Temporarily disable calling ncclCommDestroy
   * Calling ncclCommDestroy while program exiting is undefined
   * according to Nvidia, and lead to segfault in NCCL 2
   * (whether it is called before or after the CUDA runtime destructor).
   * Temporarily disable it in destructor to avoid segfault.
   * Following up with Nvidia for long term solution.
   */
  return;

#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}

namespace {
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// The following struct extracts the type of the second argument of a function
// type T; instantiating it with decltype(ncclBcast) lets us obtain the count
// type statically and programmatically.

template <typename T>
struct GetSecondArgType;

template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  typedef typename std::decay<Arg1>::type type;
};

constexpr auto count_max =
    std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
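
// Note (assumption: the exact limit depends on the NCCL headers compiled
// against): with the NCCL 2.x signature, ncclBcast takes a size_t count, so
// count_max resolves to the maximum size_t; with the older int-based
// signature it resolves to INT_MAX.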

// Since NCCL 2.12.10, NCCL supports sending/receiving 0 bytes:
// https://github.com/NVIDIA/nccl/issues/696. Skipping a zero-byte send/recv is
// risky: a rank that sends and receives 0 bytes would skip the collective
// entirely, leaving the ranks mismatched and prone to deadlock.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR > 13)))
template <typename T>
constexpr bool _nccl_should_send_recv(C10_UNUSED T _unused_) {
  return true;
}
#else
// old NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
template <typename T>
inline bool _nccl_should_send_recv(T value) {
  return value != 0;
}
#endif
} // namespace

size_t get_max_count() {
  return count_max;
}

void broadcast(
    TensorList tensors,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(tensors, tensors, 1, 1);
  auto data_type = to_nccl_data_type(tensors[0]);
  int64_t numel = tensors[0].numel();

  const auto comms = user_comms.empty() ? get_communicators(tensors)
                                        : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
    auto device = tensors[i].get_device();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();
    TORCH_CHECK(
        static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
        "Broadcast tensor has ",
        numel,
        " elements, which exceeds the "
        "maximum NCCL supports (",
        count_max,
        ")");
    ncclComm_t comm = comms[i];
    NCCL_CHECK(ncclBcast(
        tensors[i].data_ptr(),
        numel,
        data_type,
        0,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");

  check_inputs(inputs, output, root, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduce(
        inputs[i].data_ptr(),
        static_cast<std::remove_cv_t<decltype(i)>>(root) == i
            ? output.data_ptr()
            : nullptr,
        count,
        data_type,
        to_nccl_red_op(op),
        root,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
  reduce(inputs, /*output=*/inputs[root], root, op, streams, user_comms);
}

void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(inputs, outputs, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclAllReduce(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, 1, len);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel() / len;
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduceScatter(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, len, 1);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_comm(comm),
        stream));
#else
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        count,
        data_type,
        outputs[i].data_ptr(),
        to_nccl_comm(comm),
        stream));
#endif
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  int numranks;
  auto type = to_nccl_data_type(input);
  size_t count = input.numel() / size;
  size_t rankdiff = input.nbytes() / size;
  const auto* sendbuff = reinterpret_cast<const char*>(input.const_data_ptr());
  auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM)
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    if (_nccl_should_send_recv(count)) {
      NCCL_CHECK(
          ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
      NCCL_CHECK(
          ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType _type,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto type = to_nccl_data_type(_type);
  auto comm = to_nccl_comm(_comm);
  int numranks;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    if (_nccl_should_send_recv(sendcounts[r])) {
      NCCL_CHECK(ncclSend(
          ((char*)sendbuff) + senddispls[r] * size,
          sendcounts[r],
          type,
          r,
          comm,
          stream));
    }
    if (_nccl_should_send_recv(recvcounts[r])) {
      NCCL_CHECK(ncclRecv(
          ((char*)recvbuff) + recvdispls[r] * size,
          recvcounts[r],
          type,
          r,
          comm,
          stream));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);

  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(outputTensors.size())) {
    at::Tensor& input = inputTensors[r];
    at::Tensor& output = outputTensors[r];

    if (_nccl_should_send_recv(input.numel())) {
      NCCL_CHECK(ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          r,
          comm,
          stream.stream()));
    }
    if (_nccl_should_send_recv(output.numel())) {
      NCCL_CHECK(ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          r,
          comm,
          stream.stream()));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclSend(
      input.data_ptr(),
      input.numel(),
      to_nccl_data_type(input),
      dst,
      to_nccl_comm(comm),
      stream.stream()));
#else
  NCCL_CHECK_TIMEOUT(
      ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          dst,
          to_nccl_comm(comm),
          stream.stream()),
      comm);
#endif
#else
  AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclRecv(
      output.data_ptr(),
      output.numel(),
      to_nccl_data_type(output),
      src,
      to_nccl_comm(comm),
      stream.stream()));
#else
  NCCL_CHECK_TIMEOUT(
      ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          src,
          to_nccl_comm(comm),
          stream.stream()),
      comm);
#endif
#else
  AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

  size_t count = inputs.numel();
  auto type = to_nccl_data_type(inputs);
  const auto* sendbuff = reinterpret_cast<const char*>(inputs.const_data_ptr());

  NCCL_CHECK(ncclGroupStart());

  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
        NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
      } else {
        // on its own rank, simply copy from the input
        outputs[r].copy_(inputs);
      }
    }
  } else {
    NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif

#else
  AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
#else
  NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm);
  NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm);
#endif
  NCCL_CHECK(ncclGroupStart());
  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        size_t send_count = inputs[r].numel();
        auto send_type = to_nccl_data_type(inputs[r]);
        const auto* sendbuff =
            reinterpret_cast<const char*>(inputs[r].const_data_ptr());
        NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
      } else {
        // on its own rank, simply copy it to the output
        outputs.copy_(inputs[r]);
      }
    }
  } else {
    size_t recv_count = outputs.numel();
    auto recv_type = to_nccl_data_type(outputs);
    auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
    NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace torch::cuda::nccl