#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/cuda/device_set.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/env.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>

#include <nccl.h>

#include <chrono>
#include <cstdlib>
#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_map>

#if !defined(USE_ROCM) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif

ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
  return reinterpret_cast<ncclComm_t*>(var);
}

ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  return reinterpret_cast<ncclComm_t>(var);
}

ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
  return reinterpret_cast<ncclUniqueId*>(var);
}

ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
  switch (var) {
    case torch::cuda::nccl::ncclResult::Success:
      return ncclResult_t::ncclSuccess;
    case torch::cuda::nccl::ncclResult::UnhandledCudaError:
      return ncclResult_t::ncclUnhandledCudaError;
    case torch::cuda::nccl::ncclResult::SystemError:
      return ncclResult_t::ncclSystemError;
    case torch::cuda::nccl::ncclResult::InternalError:
      return ncclResult_t::ncclInternalError;
    case torch::cuda::nccl::ncclResult::InvalidArgument:
      return ncclResult_t::ncclInvalidArgument;
    case torch::cuda::nccl::ncclResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case torch::cuda::nccl::ncclResult::RemoteError:
      return ncclResult_t::ncclRemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case torch::cuda::nccl::ncclResult::InProgress:
      return ncclResult_t::ncclInProgress;
#endif
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
  switch (var) {
    case ncclSuccess:
      return torch::cuda::nccl::ncclResult::Success;
    case ncclUnhandledCudaError:
      return torch::cuda::nccl::ncclResult::UnhandledCudaError;
    case ncclSystemError:
      return torch::cuda::nccl::ncclResult::SystemError;
    case ncclInternalError:
      return torch::cuda::nccl::ncclResult::InternalError;
    case ncclInvalidArgument:
      return torch::cuda::nccl::ncclResult::InvalidArgument;
    case ncclInvalidUsage:
      return torch::cuda::nccl::ncclResult::InvalidUsage;
    case ncclRemoteError:
      return torch::cuda::nccl::ncclResult::RemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case ncclInProgress:
      return torch::cuda::nccl::ncclResult::InProgress;
#endif
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
  switch (type) {
    case at::kFloat:
      return ncclDataType_t::ncclFloat;
    case at::kHalf:
      return ncclDataType_t::ncclHalf;
    case at::kDouble:
      return ncclDataType_t::ncclDouble;
    case at::kLong:
      return ncclDataType_t::ncclInt64;
    case at::kInt:
      return ncclDataType_t::ncclInt;
    case at::kChar:
      return ncclDataType_t::ncclChar;
    case at::kByte:
      return ncclDataType_t::ncclUint8;
    case at::kBool:
      return ncclDataType_t::ncclUint8;
#if HAS_NCCL_BF16_DATATYPE
    case at::kBFloat16:
      return ncclDataType_t::ncclBfloat16;
#endif
    default:
      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
  }
}

ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
  if (!t.is_cuda()) {
    TORCH_CHECK(
        false,
        "NCCL only supports CUDA tensors, but got a tensor on ",
        t.device());
  }
  return to_nccl_data_type(t.scalar_type());
}

ncclRedOp_t to_nccl_red_op(int var) {
  return (ncclRedOp_t)(var);
}

namespace torch::cuda::nccl {

using namespace at;

namespace detail {

static inline void NCCL_CHECK(ncclResult_t result) {
  NCCL_CHECK(from_nccl_result(result));
}

// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
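
// The helpers below gate the experimental non-blocking communicator path on
// the same environment variables that the ProcessGroupNCCL code path reads.
// A minimal usage sketch (the script name is only a placeholder):
//
//   TORCH_NCCL_USE_COMM_NONBLOCKING=1 TORCH_NCCL_NONBLOCKING_TIMEOUT=30 \
//       python train.py
//
// With the flag unset, calls are issued in blocking mode and the
// NCCL_CHECK_TIMEOUT polling loop should never observe ncclInProgress.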

bool nccl_use_nonblocking() {
  static bool nccl_use_nonblocking_ =
      c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
  if (nccl_use_nonblocking_) {
    TORCH_WARN("Using experimental non-blocking NCCL communicator.");
  }
  return nccl_use_nonblocking_;
}

static int _parse_nccl_nonblocking_timeout() {
  const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
  int timeout = -1;
  if (val) {
    const std::string config(val);
    timeout = std::stoi(config);
    if (!nccl_use_nonblocking() && timeout > 0) {
      TORCH_WARN(
          "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
      timeout = -1;
    }
  }
  return timeout;
}

static int nccl_nonblocking_timeout() {
  static int timeout = _parse_nccl_nonblocking_timeout();
  return timeout;
}

static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  ncclResult_t result = to_nccl_result(status);
  auto startTimepoint = std::chrono::steady_clock::now();
  while (result == ncclInProgress) {
    if (nccl_nonblocking_timeout() > 0) {
      auto currentTimepoint = std::chrono::steady_clock::now();
      auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                             currentTimepoint - startTimepoint)
                             .count();
      if (timeElapsed > nccl_nonblocking_timeout()) {
        throw std::runtime_error("NCCL timeout.");
      }
    }
    ncclCommGetAsyncError(to_nccl_comm(comm), &result);
  }
  if (result != ncclSuccess) {
    throw_nccl_error(from_nccl_result(result));
  }
#else
  TORCH_INTERNAL_ASSERT(
      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) {
  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm);
}

static inline void NCCL_CHECK_TIMEOUT(
    ncclResult status,
    std::vector<ncclComm_t>& comms) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  ncclResult_t result = to_nccl_result(status);
  auto startTimepoint = std::chrono::steady_clock::now();
  if (result == ncclInProgress) {
    for (const auto i : c10::irange(comms.size())) {
      do {
        if (nccl_nonblocking_timeout() > 0) {
          auto currentTimepoint = std::chrono::steady_clock::now();
          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                                 currentTimepoint - startTimepoint)
                                 .count();
          if (timeElapsed > nccl_nonblocking_timeout()) {
            throw std::runtime_error("NCCL timeout.");
          }
        }
        ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
      } while (result == ncclInProgress);
      if (result != ncclSuccess) {
        break; /* fall through to failed case */
      }
    }
  }
  if (result != ncclSuccess) {
    throw_nccl_error(from_nccl_result(result));
  }
#else
  TORCH_INTERNAL_ASSERT(
      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(
    ncclResult_t result,
    std::vector<ncclComm_t>& comms) {
  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms);
}

void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
  std::ostringstream err;
  err << "NCCL Error " << static_cast<int>(status) << ": "
      << ncclGetErrorString(to_nccl_result(status));
  throw std::runtime_error(err.str());
}
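
// RAII owner for a clique of intra-process communicators: one ncclComm_t per
// local device, created by a single ncclCommInitAll call and torn down
// together (unless the CUDA driver has already been unloaded at exit).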

struct NcclCommList {
  std::unique_ptr<ncclComm_t[]> comms;
  int ndevices;
  NcclCommList(const std::vector<int>& devices)
      : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
    NCCL_CHECK(ncclCommInitAll(
        to_nccl_comm(comms.get()), devices.size(), devices.data()));
  }
  NcclCommList(NcclCommList&& foo) = default;
  ~NcclCommList() {
    if (comms) {
      for (const auto i : c10::irange(ndevices)) {
        int dummy_var;
        if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) {
          /* there are cases when this destructor is called after the
           CUDA driver is already unloaded from the process.
           In these cases, skip ncclCommDestroy */
          return;
        }
        comm_destroy(comms[i]);
      }
    }
  }
  ArrayRef<ncclComm_t> ref() const {
    return ArrayRef<ncclComm_t>(comms.get(), ndevices);
  }
};

using device_list = std::vector<int>;

// accesses to this object have to be guarded by THC's CudaFreeMutex
static std::unordered_map<device_list, NcclCommList, c10::hash<device_list>>
    _communicators;

ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  static auto get_device = [](const at::Tensor& t) -> int {
    return t.get_device();
  };
  device_list devices = fmap(inputs, get_device);
  auto it = _communicators.find(devices);
  if (it == _communicators.end()) {
    it = _communicators.emplace(devices, devices).first;
  }
  return it->second.ref();
}

static inline void check_tensor(
    const at::Tensor& input,
    const at::optional<at::Tensor>& output,
    int input_multiplier,
    int output_multiplier,
    int64_t ref_numel,
    ScalarType ref_dtype) {
  auto check_one = [&](const at::Tensor& tensor) {
    if (!tensor.is_cuda() || tensor.is_sparse()) {
      throw std::runtime_error(
          "input and output elements have to be cuda dense Tensors");
    }
    if (ref_dtype != tensor.scalar_type()) {
      throw std::runtime_error(
          "all inputs and outputs must be of the same Tensor dtype");
    }
    if (!tensor.is_contiguous()) {
      throw std::runtime_error("all inputs and outputs have to be contiguous");
    }
  };

  check_one(input);

  // all inputs must be same size
  if (input.numel() != ref_numel) {
    throw std::runtime_error(
        "all inputs must have the same number of elements");
  }

  if (output) {
    check_one(*output);

    // inputs and outputs must be on same device respectively
    if (input.get_device() != output->get_device()) {
      throw std::runtime_error("input and output must be on the same device");
    }

    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
      throw std::runtime_error(
          "output must be of size input_size * size_multiplier");
    }
  }
}

void check_inputs(
    TensorList inputs,
    TensorList outputs,
    int input_multiplier,
    int output_multiplier) {
  // len(inputs) == len(outputs)
  size_t len = inputs.size();

  if (len <= 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  if (len != outputs.size()) {
    std::stringstream err;
    err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
    throw std::runtime_error(err.str());
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];
    auto output = outputs[i];

    check_tensor(
        input, output, input_multiplier, output_multiplier, numel, dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}
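
// Variant used by rooted collectives (reduce): only the root's slot is
// paired with the output tensor; every other input is validated against
// at::nullopt.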

void check_inputs(
    TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier) {
  auto len = inputs.size();

  if (len <= 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];

    check_tensor(
        input,
        i == static_cast<std::remove_cv_t<decltype(i)>>(root)
            ? at::optional<at::Tensor>{output}
            : at::nullopt,
        input_multiplier,
        output_multiplier,
        numel,
        dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

} // namespace detail

AutoNcclGroup::AutoNcclGroup() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  // nccl < 2.0 cannot be called concurrently with cudaFree
  (c10::cuda::getFreeMutex())->lock();
#endif
  comm_nonblocking_ = false;
  comm_ = nullptr;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  // nccl < 2.0 cannot be called concurrently with cudaFree
  (c10::cuda::getFreeMutex())->lock();
#endif
  comm_ = comm;
  comm_nonblocking_ = comm_nonblocking;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  if (comm_nonblocking_ && comm_ != nullptr) {
    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comm_);
  } else {
    detail::NCCL_CHECK(ncclGroupEnd());
  }
#endif
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
  (c10::cuda::getFreeMutex())->unlock();
#endif
}

bool is_available(TensorList tensors) {
#ifdef USE_NCCL
  device_set devices;
  for (auto& tensor : tensors) {
    if (!tensor.is_cuda() || tensor.is_sparse())
      return false;
    if (!tensor.is_contiguous())
      return false;
    auto device = tensor.get_device();
    if (devices[device])
      return false;
    devices[device] = true;
  }
  return true;
#else
  return false;
#endif
}

std::uint64_t version() {
#if defined(NCCL_MAJOR)
  constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) |
      (((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH);
  return ver;
#elif defined(USE_NCCL)
  // return major version "1"
  return ((uint64_t)1) << 32;
#else
  return 0;
#endif
}

const char* version_suffix() {
#if defined(NCCL_SUFFIX)
  return NCCL_SUFFIX;
#else
  return "";
#endif
}

void get_unique_id(ncclUniqueId& id) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  ncclComm_t comm;
  ncclUniqueId id = comm_id;
  NCCL_CHECK(ncclCommInitRank(
      to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));
  return comm;
#else
  return nullptr;
#endif
}

void comm_destroy(ncclComm_t comm) {
  /*
   * TODO(T30279827) Temporarily disable calling ncclCommDestroy
   * Calling ncclCommDestroy while program exiting is undefined
   * according to Nvidia, and lead to segfault in NCCL 2
   * (whether it is called before or after the CUDA runtime destructor).
   * Temporarily disable it in destructor to avoid segfault.
   * Following up with Nvidia for long term solution.
   */
  return;

#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}

namespace {
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// So we use the following struct, which gets the type of the second argument
// of T, if T is a function type, with ncclBcast, to get that type statically
// and programmatically.
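// For example, with NCCL 2 the relevant signature is
//   ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype,
//                          int root, ncclComm_t comm, cudaStream_t stream);
// so GetSecondArgType<decltype(ncclBcast)>::type is size_t and count_max
// below becomes std::numeric_limits<size_t>::max(); NCCL 1 used a narrower
// integer count, so the cap was correspondingly smaller there.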
template <typename T>
struct GetSecondArgType;

template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  typedef typename std::decay<Arg1>::type type;
};

constexpr auto count_max =
    std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();

// Since NCCL 2.12.10, NCCL supports send/recv 0 byte:
// https://github.com/NVIDIA/nccl/issues/696. The issue of skipping send/recv
// is that it can cause deadlock when a rank send and recv 0 bytes so it's
// completely skipping the collective, causing mismatch across ranks
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR > 13)))
template <typename T>
constexpr bool _nccl_should_send_recv(C10_UNUSED T _unused_) {
  return true;
}
#else
// old NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
template <typename T>
inline bool _nccl_should_send_recv(T value) {
  return value != 0;
}
#endif
} // namespace

size_t get_max_count() {
  return count_max;
}

void broadcast(
    TensorList tensors,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(tensors, tensors, 1, 1);
  auto data_type = to_nccl_data_type(tensors[0]);
  int64_t numel = tensors[0].numel();

  const auto comms = user_comms.empty() ? get_communicators(tensors)
                                        : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
    auto device = tensors[i].get_device();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();
    TORCH_CHECK(
        static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
        "Broadcast tensor has ",
        numel,
        " elements, which exceeds the "
        "maximum NCCL supports (",
        count_max,
        ")");

    ncclComm_t comm = comms[i];
    NCCL_CHECK(ncclBcast(
        tensors[i].data_ptr(),
        numel,
        data_type,
        0,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
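
// Typical call site for broadcast() above (sketch, assuming the defaulted
// stream/comm arguments declared in torch/csrc/cuda/nccl.h): given one
// contiguous CUDA tensor per device,
//   torch::cuda::nccl::broadcast(tensors);
// replicates tensors[0] into every other tensor of the list on each device's
// current stream, with the root fixed at rank 0 of the cached clique.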

void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");

  check_inputs(inputs, output, root, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduce(
        inputs[i].data_ptr(),
        static_cast<std::remove_cv_t<decltype(i)>>(root) == i
            ? output.data_ptr()
            : nullptr,
        count,
        data_type,
        to_nccl_red_op(op),
        root,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
  reduce(inputs, /*output=*/inputs[root], root, op, streams, user_comms);
}

void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(inputs, outputs, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclAllReduce(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, 1, len);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel() / len;
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduceScatter(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
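
// Layout-wise, all_gather mirrors reduce_scatter above: check_inputs is
// called with input_multiplier = len and output_multiplier = 1, so every
// outputs[i] must hold len * inputs[i].numel() elements, i.e. the
// concatenation of all ranks' inputs.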

void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, len, 1);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    auto device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_comm(comm),
        stream));
#else
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        count,
        data_type,
        outputs[i].data_ptr(),
        to_nccl_comm(comm),
        stream));
#endif
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  int numranks;
  auto type = to_nccl_data_type(input);
  size_t count = input.numel() / size;
  size_t rankdiff = input.nbytes() / size;
  const auto* sendbuff = reinterpret_cast<const char*>(input.const_data_ptr());
  auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM)
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    if (_nccl_should_send_recv(count)) {
      NCCL_CHECK(
          ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
      NCCL_CHECK(
          ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType _type,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto type = to_nccl_data_type(_type);
  auto comm = to_nccl_comm(_comm);
  int numranks;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    if (_nccl_should_send_recv(sendcounts[r])) {
      NCCL_CHECK(ncclSend(
          ((char*)sendbuff) + senddispls[r] * size,
          sendcounts[r],
          type,
          r,
          comm,
          stream));
    }
    if (_nccl_should_send_recv(recvcounts[r])) {
      NCCL_CHECK(ncclRecv(
          ((char*)recvbuff) + recvdispls[r] * size,
          recvcounts[r],
          type,
          r,
          comm,
          stream));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
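
// Tensor-list all-to-all: this rank sends inputTensors[r] to peer r and
// receives outputTensors[r] from it. All sends and recvs are issued inside a
// single group so NCCL can pair them up; zero-element tensors are skipped
// only on NCCL versions where doing so cannot deadlock (see
// _nccl_should_send_recv above).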

void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);

  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(outputTensors.size())) {
    at::Tensor& input = inputTensors[r];
    at::Tensor& output = outputTensors[r];

    if (_nccl_should_send_recv(input.numel())) {
      NCCL_CHECK(ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          r,
          comm,
          stream.stream()));
    }
    if (_nccl_should_send_recv(output.numel())) {
      NCCL_CHECK(ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          r,
          comm,
          stream.stream()));
    }
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclSend(
      input.data_ptr(),
      input.numel(),
      to_nccl_data_type(input),
      dst,
      to_nccl_comm(comm),
      stream.stream()));
#else
  NCCL_CHECK_TIMEOUT(
      ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          dst,
          to_nccl_comm(comm),
          stream.stream()),
      comm);
#endif
#else
  AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclRecv(
      output.data_ptr(),
      output.numel(),
      to_nccl_data_type(output),
      src,
      to_nccl_comm(comm),
      stream.stream()));
#else
  NCCL_CHECK_TIMEOUT(
      ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          src,
          to_nccl_comm(comm),
          stream.stream()),
      comm);
#endif
#else
  AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

  size_t count = inputs.numel();
  auto type = to_nccl_data_type(inputs);
  const auto* sendbuff = reinterpret_cast<const char*>(inputs.const_data_ptr());

  NCCL_CHECK(ncclGroupStart());

  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
        NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
      } else {
        // on its own rank, simply copy from the input
        outputs[r].copy_(inputs);
      }
    }
  } else {
    NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif

#else
  AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
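
// scatter mirrors gather: the root posts one send per peer from inputs[r]
// (copying its own slice locally instead of sending to itself), while every
// non-root rank posts a single matching recv into `outputs`.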

void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
#else
  NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm);
  NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm);
#endif

  NCCL_CHECK(ncclGroupStart());
  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        size_t send_count = inputs[r].numel();
        auto send_type = to_nccl_data_type(inputs[r]);
        const auto* sendbuff =
            reinterpret_cast<const char*>(inputs[r].const_data_ptr());
        NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
      } else {
        // on its own rank, simply copy it to the output
        outputs.copy_(inputs[r]);
      }
    }
  } else {
    size_t recv_count = outputs.numel();
    auto recv_type = to_nccl_data_type(outputs);
    auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
    NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
  }
#ifndef NCCL_HAS_COMM_NONBLOCKING
  NCCL_CHECK(ncclGroupEnd());
#else
  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
  AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace torch::cuda::nccl