#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>

#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace c10d {
namespace intra_node_comm {

static constexpr size_t kBytesPerThread = 16;
static constexpr size_t kMaxAllReduceBlocks = 24;
static constexpr size_t kThreadsPerBlock = 1024;
static constexpr size_t kWarpSize = 32;

static constexpr size_t kHcmThreshBytes = 256 * 1024;
static constexpr size_t kOneShotThreshBytes = 256 * 1024;
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;

#if defined(USE_ROCM)
using __nv_bfloat162 = uint32_t;
#endif

struct __align__(16) bf16x8 {
  __nv_bfloat162 vals[4];
};

#define DEVICE_INLINE __device__ inline __attribute__((always_inline))

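// Vectorized bf16 helpers: bf16x8 packs eight bf16 values (four
// __nv_bfloat162 pairs). bf16hadd2 wraps __hadd2 and traps via
// CUDA_KERNEL_ASSERT on ROCm and on architectures below sm_80; add_bf16x8
// uses it to add eight values at a time.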
DEVICE_INLINE __nv_bfloat162
bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(USE_ROCM)
  CUDA_KERNEL_ASSERT(false);
  return 0;
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
  __nv_bfloat162 res;
  return res;
#else
  return __hadd2(x, y);
#endif
}

DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
  bf16x8 c;
  c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]);
  c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]);
  c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]);
  c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]);
  return c;
}

/**
 * NOTE [cross device memory synchronization]
 *
 * The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes
 * of a thread to be visible to threads with the same block/thread ID on other
 * devices. To satisfy CUDA's memory consistency model, every thread has to
 * release its writes at the system scope, and the consuming thread has to
 * acquire the writes at the system scope. This incurs high overhead, and
 * attempts at optimizing this process are prone to race conditions.
 *
 * Instead, we bypass the cache entirely by having each thread:
 *
 * - Directly write to global memory via st.cs (cache-streaming).
 * - Synchronize with threads within the block.
 * - Perform cross device synchronization at block level (via system scope
 *   atomic ops).
 * - Synchronize with threads within the block.
 * - Directly read from global memory via ld.nc (non-coherent/non-cached).
 */
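// The helpers below implement the steps above: streamStore128/streamLoad128
// issue the st.cs/ld.nc accesses, and releaseSignal/acquireSignal implement
// the block-level cross-device handshake via system-scope atomics.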
template <typename T>
DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  unsigned long long int low, high;
  asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
      : "=l"(low), "=l"(high)
      : "l"(addr));
  reinterpret_cast<unsigned long long int*>(&val)[0] = low;
  reinterpret_cast<unsigned long long int*>(&val)[1] = high;
#endif
}

__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  unsigned long long int low, high;
  low = reinterpret_cast<const unsigned long long int*>(&val)[0];
  high = reinterpret_cast<const unsigned long long int*>(&val)[1];
  asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high));
#endif
}

template <typename T>
DEVICE_INLINE void load128(bf16x8& val, const T* addr) {
  *reinterpret_cast<uint4*>(&val) = reinterpret_cast<const uint4*>(addr)[0];
}

template <typename T>
DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
  *reinterpret_cast<uint4*>(addr) = reinterpret_cast<const uint4*>(&val)[0];
}

DEVICE_INLINE void releaseSignal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  atomicAdd_system(addr, 1);
#endif
}

DEVICE_INLINE void acquireSignal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  volatile uint32_t* signal = addr;
  uint32_t val;
  do {
    val = *signal;
  } while (val == 0 || atomicCAS_system(addr, val, val - 1) != val);
#endif
}

////////////////////////////////////////////////////////////////////////////////
// Fully Connected Algos
////////////////////////////////////////////////////////////////////////////////

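// Per-block signal pads used for cross-device synchronization. Every kernel
// barriers on signals0; the two-shot allreduce additionally uses signals1 so
// that its second synchronization point does not interfere with the first.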
struct P2pState {
  uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices];
  uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
};

static_assert(sizeof(P2pState) <= kP2pStateSize);

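// One-shot allreduce: every rank stages its input in its symmetric buffer,
// then each rank reads all ranks' buffers and reduces the full tensor locally
// before writing the result back to `input`.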
template <uint32_t kWorldSize, bool kAligned>
static __global__ void oneShotAllReduceKernel(
    at::BFloat16* input,
    size_t N,
    size_t N_aligned,
    P2pState** p2pStates,
    at::BFloat16** buffers,
    size_t rank,
    bool fuseInputCopy) {
  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
  const size_t offset =
      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
  const size_t stride = blockDim.x * gridDim.x * numelPerThread;

  if (fuseInputCopy) {
    for (size_t i = offset; i < N_aligned; i += stride) {
      bf16x8 val;
      streamLoad128(val, &input[i]);
      streamStore128(&buffers[rank][i], val);
    }
  }

  // Wait for all other ranks to enter the kernel
  if (threadIdx.x < kWorldSize) {
    auto targetRank = threadIdx.x;
    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
  }
  __syncthreads();

  // The source pointers, rotated round-robin so that different ranks start
  // reading from different peers.
  const at::BFloat16* srcs[kWorldSize];
#pragma unroll kWorldSize
  for (int ii = 0; ii < kWorldSize; ++ii) {
    int srcRank = (rank + ii) % kWorldSize;
    srcs[ii] = buffers[srcRank];
  }

  for (size_t i = offset; i < N_aligned; i += stride) {
    bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      // Make sure the values in `vals` are ordered by rank so that the
      // reduction results are consistent across ranks.
      int srcRank = (ii + kWorldSize - rank) % kWorldSize;
      streamLoad128(vals[srcRank], &srcs[ii][i]);
    }

    bf16x8 sums;
    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));

#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      sums = add_bf16x8(sums, vals[ii]);
    }
    if constexpr (kAligned) {
      streamStore128(&input[i], sums);
    } else {
      for (size_t ii = 0; ii < numelPerThread; ++ii) {
        if (i + ii < N) {
          input[i + ii] = reinterpret_cast<at::BFloat16*>(&sums)[ii];
        }
      }
    }
  }
}

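// Two-shot allreduce: first each rank reduces its own 1/kWorldSize slice of
// the data across all ranks' buffers (reduce-scatter), then, after a second
// barrier, each rank copies the other ranks' reduced slices into `input`
// (all-gather).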
template <uint32_t kWorldSize>
static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel(
    at::BFloat16* input,
    size_t N_aligned,
    P2pState** p2pStates,
    at::BFloat16** buffers,
    size_t rank) {
  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
  const size_t offset =
      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
  const size_t stride = blockDim.x * gridDim.x * numelPerThread;
  const size_t N_per_rank = N_aligned / kWorldSize;
  const size_t N_start = N_per_rank * rank;

  // Wait for all other ranks to enter the kernel
  if (threadIdx.x < kWorldSize) {
    auto targetRank = threadIdx.x;
    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
  }
  __syncthreads();

  // The source pointers, rotated round-robin so that different ranks start
  // reading from different peers.
  at::BFloat16* srcs[kWorldSize];
  size_t srcRanks[kWorldSize];
#pragma unroll kWorldSize
  for (int ii = 0; ii < kWorldSize; ++ii) {
    int srcRank = (rank + ii) % kWorldSize;
    srcs[ii] = buffers[srcRank];
    srcRanks[ii] = srcRank;
  }

  for (size_t i = offset; i < N_per_rank; i += stride) {
    bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      // Make sure the values in `vals` are ordered by rank so that the
      // reduction results are consistent across ranks.
      int srcRank = (ii + kWorldSize - rank) % kWorldSize;
      streamLoad128(vals[srcRank], &srcs[ii][N_start + i]);
    }

    bf16x8 sums;
    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));

#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      sums = add_bf16x8(sums, vals[ii]);
    }
    streamStore128(&srcs[0][N_start + i], sums);
    // Store local sums into input now so we can avoid
    // a global memory access later for it.
    streamStore128(&input[N_start + i], sums);
  }
  __syncthreads();

  if (threadIdx.x < kWorldSize) {
    auto targetRank = threadIdx.x;
    releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]);
  }
  __syncthreads();

  for (size_t i = offset; i < N_per_rank; i += stride) {
#pragma unroll kWorldSize - 1
    for (size_t ii = 1; ii < kWorldSize; ++ii) {
      size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank;
      bf16x8 val;
      streamLoad128(val, &srcs[ii][k]);
      streamStore128(&input[k], val);
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// Hybrid Cube Mesh Algos
////////////////////////////////////////////////////////////////////////////////

/**
 * NOTE [hybrid cube mesh]
 *
 * In a hybrid cube mesh topology, every device has exactly 4 neighbors
 * (directly connected via NVLink). For every device X, there is exactly 1
 * neighbor Y that is a neighbor of the 3 non-neighbors of X. We call Y the
 * relay neighbor of X. This property is symmetric: X is also guaranteed to
 * be the relay neighbor of Y.
 *
 * With this property, we can perform a variant of the one-shot allreduce algo
 * that only moves data across NVLinks:
 *
 * - Each device performs a one-shot allreduce among itself and its 3
 *   non-relay neighbors.
 * - Each device exchanges data with its relay neighbor.
 *
 * HybridCubeMesh is a data structure for describing the topology:
 *
 * - hcm[X][0:3] are the 3 neighbors of X.
 * - hcm[X][3] is the relay neighbor of X.
 * - For load balancing purposes, we also ensure that if hcm[X][k] = Y,
 *   then hcm[Y][k] = X.
 */
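// In other words, hcm[X] = {A, B, C, R} means A, B, and C are the NVLink
// neighbors device X reduces with directly, and R is X's relay neighbor
// (hcmInfo[3] in the kernel below).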
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh) {
  std::array<std::unordered_set<size_t>, kMaxDevices> neighbors = {};
  std::array<size_t, kMaxDevices> neighborMasks = {};
  for (size_t i = 0; i < kMaxDevices; ++i) {
    for (size_t j = 0; j < kMaxDevices; ++j) {
      if (nvlMesh[i][j] > 0) {
        neighbors[i].insert(j);
        neighborMasks[i] |= (1ul << j);
      }
    }
  }
  HybridCubeMesh hcm = {};
  for (auto& row : hcm) {
    row.fill(-1);
  }
  // A topology is an HCM if:
  // - Every device has exactly 4 neighbors.
  // - Every device has exactly 1 relay neighbor that is
  //   a neighbor of the device's 3 non-neighbors.
  for (size_t i = 0; i < kMaxDevices; ++i) {
    // Condition 1: check the number of neighbors
    if (neighbors[i].size() != 4) {
      return std::nullopt;
    }
    std::vector<size_t> relayNeighbors;
    for (size_t j = 0; j < kMaxDevices; ++j) {
      if ((neighborMasks[i] & neighborMasks[j]) == 0) {
        relayNeighbors.push_back(j);
      }
    }
    // Condition 2: check the number of relay neighbors
    if (relayNeighbors.size() != 1) {
      return std::nullopt;
    }
    neighbors[i].erase(relayNeighbors[0]);
    hcm[i][3] = relayNeighbors[0];
  }

  for (size_t i = 0; i < kMaxDevices; ++i) {
    for (size_t k = 0; k < 3; ++k) {
      // We can only fill hcm[i][k] with j if hcm[j][k] is not filled
      for (size_t j : neighbors[i]) {
        if (hcm[j][k] == -1) {
          hcm[i][k] = j;
          hcm[j][k] = i;
          break;
        }
      }
      TORCH_CHECK(hcm[i][k] != -1);
      neighbors[i].erase(hcm[i][k]);
    }
  }
  return hcm;
}

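// HCM allreduce: each rank one-shot allreduces with its 3 non-relay
// neighbors, stashes the partial sums in the second half of its own buffer,
// then exchanges those partial sums with its relay neighbor and adds them to
// produce the final result.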
template <bool kAligned>
static __global__ void hybridCubeMeshAllReduceKernel(
    at::BFloat16* input,
    size_t N,
    size_t N_aligned,
    P2pState** p2pStates,
    at::BFloat16** buffers,
    int hcmInfo[4],
    size_t bufferSize,
    size_t rank) {
  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
  const size_t offset =
      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
  const size_t stride = blockDim.x * gridDim.x * numelPerThread;
  const int relayRank = hcmInfo[3];

  // Wait for HCM neighbors to enter the kernel
  if (threadIdx.x < 3) {
    auto targetRank = hcmInfo[threadIdx.x];
    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
  }
  __syncthreads();

  const at::BFloat16* srcs[4] = {
      buffers[rank],
      buffers[hcmInfo[0]],
      buffers[hcmInfo[1]],
      buffers[hcmInfo[2]],
  };
  // Use the second half of the buffer as the relay area
  at::BFloat16* localRelay =
      buffers[rank] + (bufferSize / sizeof(at::BFloat16) / 2);
  at::BFloat16* remoteRelay =
      buffers[relayRank] + (bufferSize / sizeof(at::BFloat16) / 2);

  for (size_t i = offset; i < N_aligned; i += stride) {
    bf16x8 vals[4];

#pragma unroll 4
    for (size_t ii = 0; ii < 4; ++ii) {
      streamLoad128(vals[ii], &srcs[ii][i]);
    }

    bf16x8 sums;
    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));

#pragma unroll 4
    for (size_t ii = 0; ii < 4; ++ii) {
      sums = add_bf16x8(sums, vals[ii]);
    }
    // Cached store for local sums
    store128(&localRelay[i], sums);
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]);
  }
  __syncthreads();

  for (size_t i = offset; i < N_aligned; i += stride) {
    bf16x8 localSum, remoteSum;
    // Cached load for local sums
    load128(localSum, &localRelay[i]);
    streamLoad128(remoteSum, &remoteRelay[i]);
    localSum = add_bf16x8(localSum, remoteSum);
    if constexpr (kAligned) {
      streamStore128(&input[i], localSum);
    } else {
      for (size_t ii = 0; ii < numelPerThread; ++ii) {
        if (i + ii < N) {
          input[i + ii] = reinterpret_cast<at::BFloat16*>(&localSum)[ii];
        }
      }
    }
  }
}

static inline size_t divUp(uint32_t a, uint32_t b) {
  return (a + b - 1) / b;
}

static inline size_t alignUp(uint32_t a, uint32_t b) {
  return divUp(a, b) * b;
}

static void checkInput(const at::Tensor& input, int deviceIdx) {
  TORCH_CHECK(
      input.dtype() == at::kBFloat16,
      "IntraNodeComm: only bf16 is supported for now");
  TORCH_CHECK(input.is_non_overlapping_and_dense());
  TORCH_CHECK(input.device().is_cuda());
  TORCH_CHECK(
      input.get_device() == deviceIdx,
      "IntraNodeComm: expect input to be on device ",
      deviceIdx,
      ", got device ",
      input.get_device());
}

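// Pick a launch configuration that covers N_aligned elements at
// kBytesPerThread bytes per thread: a single partially-filled block (in whole
// warps) for small inputs, otherwise up to kMaxAllReduceBlocks blocks of up
// to kThreadsPerBlock threads.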
static void getLaunchConfig(
    size_t N_aligned,
    size_t elemSize,
    dim3& blocks,
    dim3& threads) {
  blocks = dim3(0, 1, 1);
  threads = dim3(0, 1, 1);

  const auto numelPerThread = kBytesPerThread / elemSize;
  const auto numelPerWarp = numelPerThread * kWarpSize;
  TORCH_CHECK(N_aligned % numelPerThread == 0);
  TORCH_CHECK(N_aligned % numelPerWarp == 0);
  if (N_aligned < numelPerThread * kThreadsPerBlock) {
    threads.x = N_aligned / numelPerWarp * kWarpSize;
    blocks.x = 1;
  } else {
    auto warpsRequired = N_aligned / numelPerWarp;
    auto threadsRequired = N_aligned / numelPerThread;
    blocks.x =
        std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
    auto warpsPerBlock = divUp(warpsRequired, blocks.x);
    threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize);
  }
}

bool isIntraNodeCommSupported() {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  return false;
#else
  return true;
#endif
}

void* initP2pState() {
  void* state = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState)));
  AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState)));
  return state;
}

void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
  void* topoInfo = nullptr;
  if (topology != Topology::HYBRID_CUBE_MESH) {
    return topoInfo;
  }
  auto hcm = getHybridCubeMesh(nvlMesh);
  int hcmInfo[4];
  std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo);
  AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo)));
  AT_CUDA_CHECK(
      cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice));
  return topoInfo;
}

at::Tensor IntraNodeComm::oneShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  const size_t numelPerWarp =
      kBytesPerThread / input.element_size() * kWarpSize;
  const size_t N_aligned = alignUp(input.numel(), numelPerWarp);
  const bool isAligned = (N_aligned == static_cast<size_t>(input.numel()));
  TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());

  dim3 blocks, threads;
  getLaunchConfig(N_aligned, input.element_size(), blocks, threads);

  at::cuda::OptionalCUDAGuard guard(input.get_device());

  // When the input is small, copying it inside the kernel is faster, because
  // for such sizes the launch overhead of cudaMemcpyAsync outweighs its
  // bandwidth advantage. We consider the input small if the copy loop can
  // finish in a single iteration.
  const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks;
  if (!fuseInputCopy) {
    AT_CUDA_CHECK(cudaMemcpyAsync(
        symmetricMemory_->get_buffer_ptrs()[rank_],
        input.data_ptr(),
        input.numel() * input.element_size(),
        cudaMemcpyDeviceToDevice,
        stream));
  }

#define X(kWorldSize, kAligned)                            \
  if (worldSize_ == kWorldSize) {                          \
    oneShotAllReduceKernel<kWorldSize, kAligned>           \
        <<<blocks, threads, 0, stream>>>(                  \
            input.data_ptr<at::BFloat16>(),                \
            input.numel(),                                 \
            N_aligned,                                     \
            reinterpret_cast<P2pState**>(p2pStatesDev_),   \
            reinterpret_cast<at::BFloat16**>(buffersDev_), \
            rank_,                                         \
            fuseInputCopy);                                \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                        \
  }

#define DISPATCH_ALL_WORLD_SIZES(kAligned) \
  X(2, kAligned);                          \
  X(3, kAligned);                          \
  X(4, kAligned);                          \
  X(5, kAligned);                          \
  X(6, kAligned);                          \
  X(7, kAligned);                          \
  X(8, kAligned);

  if (isAligned) {
    DISPATCH_ALL_WORLD_SIZES(true);
  } else {
    DISPATCH_ALL_WORLD_SIZES(false);
  }

#undef DISPATCH_ALL_WORLD_SIZES
#undef X
  return input;
}

at::Tensor IntraNodeComm::twoShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
  size_t N_aligned = alignUp(input.numel(), worldSize_ * numelPerWarp);
  size_t N_per_rank = N_aligned / worldSize_;
  TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());

  dim3 blocks, threads;
  getLaunchConfig(N_per_rank, input.element_size(), blocks, threads);

  auto output = N_aligned == static_cast<size_t>(input.numel())
      ? input
      : input.new_empty(N_aligned);

  at::cuda::OptionalCUDAGuard guard(input.get_device());
  AT_CUDA_CHECK(cudaMemcpyAsync(
      symmetricMemory_->get_buffer_ptrs()[rank_],
      input.data_ptr(),
      input.numel() * input.element_size(),
      cudaMemcpyDeviceToDevice,
      stream));

#define X(kWorldSize)                                                   \
  if (worldSize_ == kWorldSize) {                                       \
    twoShotAllReduceKernel<kWorldSize><<<blocks, threads, 0, stream>>>( \
        output.data_ptr<at::BFloat16>(),                                \
        N_aligned,                                                      \
        reinterpret_cast<P2pState**>(p2pStatesDev_),                    \
        reinterpret_cast<at::BFloat16**>(buffersDev_),                  \
        rank_);                                                         \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                     \
  }
  X(2);
  X(3);
  X(4);
  X(5);
  X(6);
  X(7);
  X(8);
#undef X

  if (output.data_ptr() != input.data_ptr()) {
    AT_CUDA_CHECK(cudaMemcpyAsync(
        input.data_ptr(),
        output.data_ptr(),
        input.numel() * input.element_size(),
        cudaMemcpyDeviceToDevice,
        stream));
  }
  return input;
}

at::Tensor IntraNodeComm::hybridCubeMeshAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
  size_t N_aligned = alignUp(input.numel(), numelPerWarp);
  TORCH_CHECK(N_aligned * 2 <= bufferSize_ / input.element_size());

  dim3 blocks, threads;
  getLaunchConfig(N_aligned, input.element_size(), blocks, threads);

  at::cuda::OptionalCUDAGuard guard(input.get_device());
  AT_CUDA_CHECK(cudaMemcpyAsync(
      symmetricMemory_->get_buffer_ptrs()[rank_],
      input.data_ptr(),
      input.numel() * input.element_size(),
      cudaMemcpyDeviceToDevice,
      stream));

#define X(kAligned)                                                        \
  hybridCubeMeshAllReduceKernel<kAligned><<<blocks, threads, 0, stream>>>( \
      input.data_ptr<at::BFloat16>(),                                      \
      input.numel(),                                                       \
      N_aligned,                                                           \
      reinterpret_cast<P2pState**>(p2pStatesDev_),                         \
      reinterpret_cast<at::BFloat16**>(buffersDev_),                       \
      static_cast<int*>(topoInfo_),                                        \
      bufferSize_,                                                         \
      rank_);                                                              \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  if (N_aligned == static_cast<size_t>(input.numel())) {
    X(true);
  } else {
    X(false);
  }
#undef X
  return input;
}

AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
  // Only support bf16 for now
  if (input.dtype() != at::kBFloat16) {
    return AllReduceAlgo::NONE;
  }
  const auto inputSize = input.numel() * input.element_size();
  const auto bytesPerWarp = kBytesPerThread * kWarpSize;

  if (topology_ == Topology::HYBRID_CUBE_MESH) {
    TORCH_CHECK(
        worldSize_ == 8,
        "hybridCubeMeshAllReduce only supports exactly 8 GPUs");
    const auto hcmInputSize = alignUp(inputSize, bytesPerWarp);
    const auto hcmBufferSizeReq = hcmInputSize * 2;
    if (hcmInputSize <= kHcmThreshBytes && hcmBufferSizeReq <= bufferSize_) {
      return AllReduceAlgo::HCM;
    }
  }
  if (topology_ == Topology::FULLY_CONNECTED) {
    const auto oneShotInputSize = alignUp(inputSize, bytesPerWarp);
    const auto oneShotBufferSizeReq = oneShotInputSize;
    if (oneShotInputSize <= kOneShotThreshBytes &&
        oneShotBufferSizeReq <= bufferSize_) {
      return AllReduceAlgo::ONE_SHOT;
    }

    const auto twoShotInputSize = alignUp(inputSize, bytesPerWarp * worldSize_);
    const auto twoShotBufferSizeReq = twoShotInputSize;
    if (twoShotInputSize <= kTwoShotThreshBytes &&
        twoShotBufferSizeReq <= bufferSize_) {
      return AllReduceAlgo::TWO_SHOT;
    }
  }
  return AllReduceAlgo::NONE;
}

static int64_t usageCounter = 0;

at::Tensor IntraNodeComm::allReduce(
    const at::Tensor& input,
    AllReduceAlgo algo) {
  // Report usage for testing purposes.
  // We don't care about overflowing.
  ++usageCounter;
  auto stream = at::cuda::getCurrentCUDAStream();
  c10::cuda::CUDACachingAllocator::recordStream(
      input.storage().data_ptr(), stream);
  switch (algo) {
    case AllReduceAlgo::ONE_SHOT:
      return oneShotAllReduce(input, stream);
    case AllReduceAlgo::TWO_SHOT:
      return twoShotAllReduce(input, stream);
    case AllReduceAlgo::HCM:
      return hybridCubeMeshAllReduce(input, stream);
    default:
      C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
  }
}

int64_t getIntraNodeCommUsageCounter() {
  return usageCounter;
}

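// Block-scope barrier across the ranks selected by `mask`: each participating
// thread releases a signal to its peer rank and then blocks until the
// matching signal from that peer arrives.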
static __global__ void barrierKernel(
    P2pState** p2pStates,
    uint64_t mask,
    size_t rank,
    size_t worldSize) {
  if (threadIdx.x < worldSize && (mask & (1ULL << threadIdx.x))) {
    auto targetRank = threadIdx.x;
    releaseSignal(&p2pStates[targetRank]->signals0[0][rank]);
    acquireSignal(&p2pStates[rank]->signals0[0][targetRank]);
  }
}

void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
  barrierReady_.block(at::cuda::getCurrentCUDAStream());
  if (!ranks.has_value()) {
    ranks = std::vector<int64_t>(worldSize_);
    std::iota(ranks->begin(), ranks->end(), 0);
  }
  uint64_t mask = 0;
  for (const auto& r : ranks.value()) {
    TORCH_CHECK(r >= 0 && r < static_cast<int64_t>(worldSize_));
    mask |= (1ULL << r);
  }
  barrierKernel<<<1, kWarpSize, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<P2pState**>(p2pStatesDev_), mask, rank_, worldSize_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  barrierReady_.record();
}

at::Tensor IntraNodeComm::getBuffer(
    size_t rank,
    const std::vector<int64_t>& sizes,
    c10::ScalarType dtype,
    int64_t storageOffset) {
  return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset);
}

} // namespace intra_node_comm
} // namespace c10d