// torch/csrc/cuda/nccl.cpp (AOSP 15 r20 external/pytorch, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/functional.h>
#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/cuda/nccl.h>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>

#include <nccl.h>

#include <chrono>
#include <cstdlib>
#include <limits>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>

#if !defined(USE_ROCM) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif
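
// Note (editorial): the guard above encodes "NCCL >= 2.14, and not ROCm".
// NCCL 2.14 is the release that introduced nonblocking communicators and the
// ncclInProgress result code, which the NCCL_CHECK_TIMEOUT helpers below
// poll for via ncclCommGetAsyncError.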

ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
  return reinterpret_cast<ncclComm_t*>(var);
}

ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  return reinterpret_cast<ncclComm_t>(var);
}

ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
  return reinterpret_cast<ncclUniqueId*>(var);
}

ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
  switch (var) {
    case torch::cuda::nccl::ncclResult::Success:
      return ncclResult_t::ncclSuccess;
    case torch::cuda::nccl::ncclResult::UnhandledCudaError:
      return ncclResult_t::ncclUnhandledCudaError;
    case torch::cuda::nccl::ncclResult::SystemError:
      return ncclResult_t::ncclSystemError;
    case torch::cuda::nccl::ncclResult::InternalError:
      return ncclResult_t::ncclInternalError;
    case torch::cuda::nccl::ncclResult::InvalidArgument:
      return ncclResult_t::ncclInvalidArgument;
    case torch::cuda::nccl::ncclResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case torch::cuda::nccl::ncclResult::RemoteError:
      return ncclResult_t::ncclRemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case torch::cuda::nccl::ncclResult::InProgress:
      return ncclResult_t::ncclInProgress;
#endif
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
  switch (var) {
    case ncclSuccess:
      return torch::cuda::nccl::ncclResult::Success;
    case ncclUnhandledCudaError:
      return torch::cuda::nccl::ncclResult::UnhandledCudaError;
    case ncclSystemError:
      return torch::cuda::nccl::ncclResult::SystemError;
    case ncclInternalError:
      return torch::cuda::nccl::ncclResult::InternalError;
    case ncclInvalidArgument:
      return torch::cuda::nccl::ncclResult::InvalidArgument;
    case ncclInvalidUsage:
      return torch::cuda::nccl::ncclResult::InvalidUsage;
    case ncclRemoteError:
      return torch::cuda::nccl::ncclResult::RemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case ncclInProgress:
      return torch::cuda::nccl::ncclResult::InProgress;
#endif
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
  switch (type) {
    case at::kFloat:
      return ncclDataType_t::ncclFloat;
    case at::kHalf:
      return ncclDataType_t::ncclHalf;
    case at::kDouble:
      return ncclDataType_t::ncclDouble;
    case at::kLong:
      return ncclDataType_t::ncclInt64;
    case at::kInt:
      return ncclDataType_t::ncclInt;
    case at::kChar:
      return ncclDataType_t::ncclChar;
    case at::kByte:
      return ncclDataType_t::ncclUint8;
    case at::kBool:
      return ncclDataType_t::ncclUint8;
#if HAS_NCCL_BF16_DATATYPE
    case at::kBFloat16:
      return ncclDataType_t::ncclBfloat16;
#endif
    default:
      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
  }
}

ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
  if (!t.is_cuda()) {
    TORCH_CHECK(
        false,
        "NCCL only supports CUDA tensors, but got a tensor on ",
        t.device());
  }
  return to_nccl_data_type(t.scalar_type());
}

ncclRedOp_t to_nccl_red_op(int var) {
  return (ncclRedOp_t)(var);
}

namespace torch::cuda::nccl {

using namespace at;

namespace detail {

static inline void NCCL_CHECK(ncclResult_t result) {
  NCCL_CHECK(from_nccl_result(result));
}

// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
bool nccl_use_nonblocking() {
  static bool nccl_use_nonblocking_ =
      c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
  if (nccl_use_nonblocking_) {
    TORCH_WARN("Using experimental non-blocking NCCL communicator.");
  }
  return nccl_use_nonblocking_;
}
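
// Usage sketch (illustrative; assumes the usual c10 semantics where
// check_env("X") == true iff the variable is set to a truthy value):
//
//   TORCH_NCCL_USE_COMM_NONBLOCKING=1 python train.py
//
// The value is read once and cached for the lifetime of the process.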

static int _parse_nccl_nonblocking_timeout() {
  const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
  int timeout = -1;
  if (val) {
    const std::string config(val);
    timeout = std::stoi(config);
    if (!nccl_use_nonblocking() && timeout > 0) {
      TORCH_WARN(
          "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
      timeout = -1;
    }
  }
  return timeout;
}

static int nccl_nonblocking_timeout() {
  static int timeout = _parse_nccl_nonblocking_timeout();
  return timeout;
}
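
// Illustrative example: with TORCH_NCCL_NONBLOCKING_TIMEOUT=30, the polling
// loops in NCCL_CHECK_TIMEOUT below give up after roughly 30 seconds (elapsed
// time is measured in std::chrono::seconds). Unset, or set while nonblocking
// mode is off, the timeout stays -1 and the loops spin until the communicator
// leaves the ncclInProgress state.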

static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  ncclResult_t result = to_nccl_result(status);
  auto startTimepoint = std::chrono::steady_clock::now();
  while (result == ncclInProgress) {
    if (nccl_nonblocking_timeout() > 0) {
      auto currentTimepoint = std::chrono::steady_clock::now();
      auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                             currentTimepoint - startTimepoint)
                             .count();
      if (timeElapsed > nccl_nonblocking_timeout()) {
        throw std::runtime_error("NCCL timeout.");
      }
    }
    ncclCommGetAsyncError(to_nccl_comm(comm), &result);
  }
  if (result != ncclSuccess) {
    throw_nccl_error(from_nccl_result(result));
  }
#else
  TORCH_INTERNAL_ASSERT(
      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) {
  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm);
}

static inline void NCCL_CHECK_TIMEOUT(
    ncclResult status,
    std::vector<ncclComm_t>& comms) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  ncclResult_t result = to_nccl_result(status);
  auto startTimepoint = std::chrono::steady_clock::now();
  if (result == ncclInProgress) {
    for (const auto i : c10::irange(comms.size())) {
      do {
        if (nccl_nonblocking_timeout() > 0) {
          auto currentTimepoint = std::chrono::steady_clock::now();
          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                                 currentTimepoint - startTimepoint)
                                 .count();
          if (timeElapsed > nccl_nonblocking_timeout()) {
            throw std::runtime_error("NCCL timeout.");
          }
        }
        ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
      } while (result == ncclInProgress);
      if (result != ncclSuccess) {
        break; /* fall through to failed case */
      }
    }
  }
  if (result != ncclSuccess) {
    throw_nccl_error(from_nccl_result(result));
  }
#else
  TORCH_INTERNAL_ASSERT(
      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(
    ncclResult_t result,
    std::vector<ncclComm_t>& comms) {
  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms);
}

void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
  std::ostringstream err;
  err << "NCCL Error " << static_cast<int>(status) << ": "
      << ncclGetErrorString(to_nccl_result(status));
  throw std::runtime_error(err.str());
}

struct NcclCommList {
  std::unique_ptr<ncclComm_t[]> comms;
  int ndevices;
  NcclCommList(const std::vector<int>& devices)
      : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
    NCCL_CHECK(ncclCommInitAll(
        to_nccl_comm(comms.get()), devices.size(), devices.data()));
  }
  NcclCommList(NcclCommList&& foo) = default;
  ~NcclCommList() {
    if (comms) {
      for (const auto i : c10::irange(ndevices)) {
        int dummy_var;
        if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) {
          /* there are cases when this destructor is called after the
           CUDA driver is already unloaded from the process.
           In these cases, skip ncclCommDestroy */
          return;
        }
        comm_destroy(comms[i]);
      }
    }
  }
  ArrayRef<ncclComm_t> ref() const {
    return ArrayRef<ncclComm_t>(comms.get(), ndevices);
  }
};

using device_list = std::vector<int>;
// accesses to this object have to be guarded by THC's CudaFreeMutex
static std::unordered_map<device_list, NcclCommList, c10::hash<device_list>>
    _communicators;

ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  static auto get_device = [](const at::Tensor& t) -> int {
    return t.get_device();
  };
  device_list devices = fmap(inputs, get_device);
  auto it = _communicators.find(devices);
  if (it == _communicators.end()) {
    it = _communicators.emplace(devices, devices).first;
  }
  return it->second.ref();
}
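
// Note on caching (a restatement of the code above, not extra behavior):
// one communicator clique is created lazily per unique ordered device list.
// For example, a broadcast over tensors on devices [0, 1] and a later one
// over [0, 1, 2] allocate two independent NcclCommList entries; both live
// until static destruction.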

static inline void check_tensor(
    const at::Tensor& input,
    const at::optional<at::Tensor>& output,
    int input_multiplier,
    int output_multiplier,
    int64_t ref_numel,
    ScalarType ref_dtype) {
  auto check_one = [&](const at::Tensor& tensor) {
    if (!tensor.is_cuda() || tensor.is_sparse()) {
      throw std::runtime_error(
          "input and output elements have to be cuda dense Tensors");
    }

    if (ref_dtype != tensor.scalar_type()) {
      throw std::runtime_error(
          "all inputs and outputs must be of the same Tensor dtype");
    }

    if (!tensor.is_contiguous()) {
      throw std::runtime_error("all inputs and outputs have to be contiguous");
    }
  };

  check_one(input);

  // all inputs must be same size
  if (input.numel() != ref_numel) {
    throw std::runtime_error(
        "all inputs must have the same number of elements");
  }

  if (output) {
    check_one(*output);

    // inputs and outputs must be on same device respectively
    if (input.get_device() != output->get_device()) {
      throw std::runtime_error("input and output must be on the same device");
    }

    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
      throw std::runtime_error(
          "output must be of size input_size * size_multiplier");
    }
  }
}

void check_inputs(
    TensorList inputs,
    TensorList outputs,
    int input_multiplier,
    int output_multiplier) {
  // len(inputs) == len(outputs)
  size_t len = inputs.size();

  if (len <= 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  if (len != outputs.size()) {
    std::stringstream err;
    err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
    throw std::runtime_error(err.str());
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];
    auto output = outputs[i];

    check_tensor(
        input, output, input_multiplier, output_multiplier, numel, dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

void check_inputs(
    TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier) {
  auto len = inputs.size();

  if (len <= 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];

    check_tensor(
        input,
        i == static_cast<std::remove_cv_t<decltype(i)>>(root)
            ? at::optional<at::Tensor>{output}
            : at::nullopt,
        input_multiplier,
        output_multiplier,
        numel,
        dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

} // namespace detail

AutoNcclGroup::AutoNcclGroup() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  // nccl < 2.0 cannot be called concurrently with cudaFree
  (c10::cuda::getFreeMutex())->lock();
#endif
  comm_nonblocking_ = false;
  comm_ = nullptr;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  // nccl < 2.0 cannot be called concurrently with cudaFree
  (c10::cuda::getFreeMutex())->lock();
#endif
  comm_ = comm;
  comm_nonblocking_ = comm_nonblocking;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  if (comm_nonblocking_ && comm_ != nullptr) {
    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comm_);
  } else {
    detail::NCCL_CHECK(ncclGroupEnd());
  }
#endif
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  (c10::cuda::getFreeMutex())->unlock();
#endif
}
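
// RAII sketch of how the guard above is used (hypothetical caller; the
// collectives below follow exactly this shape):
//
//   {
//     AutoNcclGroup nccl_group_guard;  // ncclGroupStart()
//     for (...) { /* ncclBcast / ncclAllReduce / ... */ }
//   }                                  // ~AutoNcclGroup -> ncclGroupEnd()
//
// With a nonblocking communicator the destructor instead polls completion
// through NCCL_CHECK_TIMEOUT, which can throw; hence noexcept(false).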

bool is_available(TensorList tensors) {
#ifdef USE_NCCL
  device_set devices;
  for (auto& tensor : tensors) {
    if (!tensor.is_cuda() || tensor.is_sparse())
      return false;
    if (!tensor.is_contiguous())
      return false;
    auto device = tensor.get_device();
    if (devices[device])
      return false;
    devices[device] = true;
  }
  return true;
#else
  return false;
#endif
}

std::uint64_t version() {
#if defined(NCCL_MAJOR)
  constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) |
      (((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH);
  return ver;
#elif defined(USE_NCCL)
  // return major version "1"
  return ((uint64_t)1) << 32;
#else
  return 0;
#endif
}
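
// Worked example of the encoding above: NCCL 2.14.3 yields
//   (2ull << 32) | (14ull << 16) | 3ull == 0x2000E0003,
// so major/minor/patch can be unpacked as ver >> 32, (ver >> 16) & 0xffff,
// and ver & 0xffff.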

const char* version_suffix() {
#if defined(NCCL_SUFFIX)
  return NCCL_SUFFIX;
#else
  return "";
#endif
}

void get_unique_id(ncclUniqueId& id) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  ncclComm_t comm;
  ncclUniqueId id = comm_id;
  NCCL_CHECK(ncclCommInitRank(
      to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));
  return comm;
#else
  return nullptr;
#endif
}

void comm_destroy(ncclComm_t comm) {
  /*
   * TODO(T30279827) Temporarily disable calling ncclCommDestroy
   * Calling ncclCommDestroy while the program is exiting is undefined
   * according to Nvidia, and leads to a segfault in NCCL 2
   * (whether it is called before or after the CUDA runtime destructor).
   * Temporarily disable it in the destructor to avoid the segfault.
   * Following up with Nvidia for a long term solution.
   */
  return;

#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}

namespace {
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// So we use the following trait, which extracts the type of the second
// argument of a function type; instantiating it with decltype(ncclBcast)
// recovers the count type statically, for whichever NCCL we compile against.

template <typename T>
struct GetSecondArgType;

template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  typedef typename std::decay<Arg1>::type type;
};

constexpr auto count_max =
    std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();

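// Illustrative consequence: against NCCL 2.x, where ncclBcast's count
// parameter is size_t, count_max is SIZE_MAX and the numel guard in
// broadcast() below is effectively unbounded; against the old NCCL 1
// `int count` signature it would be INT_MAX, and large tensors could trip
// that TORCH_CHECK.
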
// Since NCCL 2.12.10, NCCL supports send/recv 0 byte:
// https://github.com/NVIDIA/nccl/issues/696. The issue with skipping
// send/recv is that it can cause a deadlock: a rank that sends and receives
// 0 bytes would skip the collective entirely, causing a mismatch across
// ranks.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR > 13)))
template <typename T>
constexpr bool _nccl_should_send_recv(C10_UNUSED T _unused_) {
  return true;
}
#else
// old NCCL uses 0 byte messages for synchronization.
// Avoid send/recv when the message size is zero.
template <typename T>
inline bool _nccl_should_send_recv(T value) {
  return value != 0;
}
#endif
} // namespace

size_t get_max_count() {
  return count_max;
}

void broadcast(
    TensorList tensors,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(tensors, tensors, 1, 1);
  auto data_type = to_nccl_data_type(tensors[0]);
  int64_t numel = tensors[0].numel();

  const auto comms = user_comms.empty() ? get_communicators(tensors)
                                        : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
    auto device = tensors[i].get_device();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();
    TORCH_CHECK(
        static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
        "Broadcast tensor has ",
        numel,
        " elements, which exceeds the "
        "maximum NCCL supports (",
        count_max,
        ")");
    ncclComm_t comm = comms[i];
    NCCL_CHECK(ncclBcast(
        tensors[i].data_ptr(),
        numel,
        data_type,
        0,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
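
// Hypothetical caller sketch (illustrative only; assumes the declaration in
// torch/csrc/cuda/nccl.h defaults streams and user_comms to empty, as the
// empty() checks above suggest). Broadcasting src from cuda:0 to one
// same-shaped, contiguous buffer per GPU:
//
//   std::vector<at::Tensor> tensors = {src};  // src on cuda:0
//   for (int dev = 1; dev < ngpus; ++dev)
//     tensors.push_back(at::empty_like(
//         src, src.options().device(at::Device(at::kCUDA, dev))));
//   torch::cuda::nccl::broadcast(tensors);  // root is device index 0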

void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");

  check_inputs(inputs, output, root, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduce(
        inputs[i].data_ptr(),
        static_cast<std::remove_cv_t<decltype(i)>>(root) == i
            ? output.data_ptr()
            : nullptr,
        count,
        data_type,
        to_nccl_red_op(op),
        root,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
  reduce(inputs, /*output=*/inputs[root], root, op, streams, user_comms);
}

void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(inputs, outputs, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclAllReduce(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, 1, len);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel() / len;
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduceScatter(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, len, 1);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_comm(comm),
        stream));
#else
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        count,
        data_type,
        outputs[i].data_ptr(),
        to_nccl_comm(comm),
        stream));
#endif
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  int numranks;
  auto type = to_nccl_data_type(input);
  size_t count = input.numel() / size;
  size_t rankdiff = input.nbytes() / size;
  const auto* sendbuff = reinterpret_cast<const char*>(input.const_data_ptr());
  auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM)
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    if (_nccl_should_send_recv(count)) {
      NCCL_CHECK(
          ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
      NCCL_CHECK(
          ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType _type,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto type = to_nccl_data_type(_type);
  auto comm = to_nccl_comm(_comm);
  int numranks;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    if (_nccl_should_send_recv(sendcounts[r])) {
      NCCL_CHECK(ncclSend(
          ((char*)sendbuff) + senddispls[r] * size,
          sendcounts[r],
          type,
          r,
          comm,
          stream));
    }
    if (_nccl_should_send_recv(recvcounts[r])) {
      NCCL_CHECK(ncclRecv(
          ((char*)recvbuff) + recvdispls[r] * size,
          recvcounts[r],
          type,
          r,
          comm,
          stream));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);

  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(outputTensors.size())) {
    at::Tensor& input = inputTensors[r];
    at::Tensor& output = outputTensors[r];

    if (_nccl_should_send_recv(input.numel())) {
      NCCL_CHECK(ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          r,
          comm,
          stream.stream()));
    }
    if (_nccl_should_send_recv(output.numel())) {
      NCCL_CHECK(ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          r,
          comm,
          stream.stream()));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclSend(
      input.data_ptr(),
      input.numel(),
      to_nccl_data_type(input),
      dst,
      to_nccl_comm(comm),
      stream.stream()));
#else
  NCCL_CHECK_TIMEOUT(
      ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          dst,
          to_nccl_comm(comm),
          stream.stream()),
      comm);
#endif
#else
  AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclRecv(
      output.data_ptr(),
      output.numel(),
      to_nccl_data_type(output),
      src,
      to_nccl_comm(comm),
      stream.stream()));
#else
  NCCL_CHECK_TIMEOUT(
      ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          src,
          to_nccl_comm(comm),
          stream.stream()),
      comm);
#endif
#else
  AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

  size_t count = inputs.numel();
  auto type = to_nccl_data_type(inputs);
  const auto* sendbuff = reinterpret_cast<const char*>(inputs.const_data_ptr());

  NCCL_CHECK(ncclGroupStart());

  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
        NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
      } else {
        // on its own rank, simply copy from the input
        outputs[r].copy_(inputs);
      }
    }
  } else {
    NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif

#else
  AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
#else
  NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm);
  NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm);
#endif
  NCCL_CHECK(ncclGroupStart());
  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        size_t send_count = inputs[r].numel();
        auto send_type = to_nccl_data_type(inputs[r]);
        const auto* sendbuff =
            reinterpret_cast<const char*>(inputs[r].const_data_ptr());
        NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
      } else {
        // on its own rank, simply copy it to the output
        outputs.copy_(inputs[r]);
      }
    }
  } else {
    size_t recv_count = outputs.numel();
    auto recv_type = to_nccl_data_type(outputs);
    auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
    NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace torch::cuda::nccl