#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>

#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/un.h>
#include <unistd.h>

#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#define CUDART_SUPPORTS_MULTICAST
#endif

namespace {

bool device_has_multicast_support(int device_idx) {
#if defined(CUDART_SUPPORTS_MULTICAST)
  if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
    return false;
  }
  // Multicast support requirements:
  // - CUDA Runtime version >= 12030: Checked at compile time using
  //   CUDART_VERSION.
  // - Driver version >= 535: Checked at runtime by verifying the existence of
  //   cuMulticastCreate_.
  // - Device support: Determined by querying
  //   CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
  auto driver_api = c10::cuda::DriverAPI::get();
  int multicast_supported;
  C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_(
      &multicast_supported,
      CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
      device_idx));
  return driver_api->cuMulticastCreate_ != nullptr && multicast_supported;
#else
  return false;
#endif
}

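// IpcChannel provides process-to-process file descriptor exchange over a
// Unix domain datagram socket. Each process binds a socket named after its
// pid under the temp directory and passes CUDA shareable handles (POSIX fds)
// to its peers via SCM_RIGHTS ancillary messages.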
class IpcChannel {
 public:
  IpcChannel() : socket_name_(get_socket_name(getpid())) {
    // socket() returns -1 on failure.
    TORCH_CHECK(
        (socket_ = socket(AF_UNIX, SOCK_DGRAM, 0)) != -1,
        "Failed to create socket: ",
        strerror(errno));

    struct sockaddr_un addr = {.sun_family = AF_UNIX};
    std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);

    TORCH_CHECK(
        bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0,
        "Failed to bind socket: ",
        strerror(errno));
  }

  ~IpcChannel() {
    close(socket_);
    unlink(socket_name_.c_str());
  }

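  // Send `fd` to the process with pid `dst_pid` as SCM_RIGHTS ancillary data.
  // An fd of -1 means "no fd": the message is sent with an empty control
  // buffer, so the receiver observes msg_controllen == 0 and returns -1.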
  void send_fd(int dst_pid, int fd) {
    struct sockaddr_un addr = {.sun_family = AF_UNIX};
    auto socket_name = get_socket_name(dst_pid);
    std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);

    struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};

    char cbuf[CMSG_SPACE(sizeof(int))];
    memset(cbuf, 0, sizeof(cbuf));

    struct msghdr msg {
      .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
      .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
      .msg_controllen = sizeof(cbuf)
    };

    auto cmsg = CMSG_FIRSTHDR(&msg);
    cmsg->cmsg_len = CMSG_LEN(sizeof(int));
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SCM_RIGHTS;

    if (fd != -1) {
      // memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
      std::copy(
          reinterpret_cast<const char*>(&fd),
          reinterpret_cast<const char*>(&fd) + sizeof(fd),
          reinterpret_cast<char*>(CMSG_DATA(cmsg)));
    } else {
      msg.msg_controllen = 0;
    }

    TORCH_CHECK(
        sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
  }

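  // Receive a single fd sent by a peer via send_fd(). Returns -1 if the
  // sender passed no fd (empty control message).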
  int recv_fd() {
    char buf[2];
    struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};

    char cbuf[CMSG_SPACE(sizeof(int))];
    memset(cbuf, 0, sizeof(cbuf));

    struct msghdr msg = {
        .msg_iov = &io,
        .msg_iovlen = 1,
        .msg_control = cbuf,
        .msg_controllen = sizeof(cbuf)};

    TORCH_CHECK(
        recvmsg(socket_, &msg, 0) > 0,
        "Failed to receive fd: ",
        strerror(errno));

    if (msg.msg_controllen == 0) {
      return -1;
    }

    auto cmsg = CMSG_FIRSTHDR(&msg);
    TORCH_CHECK(cmsg != NULL);
    TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
    TORCH_CHECK(
        cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS);
    return *reinterpret_cast<int*>(CMSG_DATA(cmsg));
  }

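  // All-gather fds across ranks using a ring: at each of the world_size - 1
  // steps, every rank forwards the fd it currently holds to the next rank
  // and receives the fd that originated from a rank further upstream.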
  std::vector<int> all_gather_fds(
      int rank,
      const std::vector<int>& pids,
      int fd) {
    size_t world_size = pids.size();
    std::vector<int> fds(pids.size());
    fds[rank] = fd;

    int dst_rank = (rank + 1) % world_size;
    for (size_t step = 1; step < world_size; ++step) {
      int src_rank = (rank + world_size - step) % world_size;
      send_fd(pids[dst_rank], fd);
      fd = recv_fd();
      fds[src_rank] = fd;
    }
    return fds;
  }

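  // Broadcast an fd from src_rank to all other ranks. The source sends its
  // fd to every peer and returns it; every other rank receives exactly one
  // fd (its own `fd` argument is unused).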
  int broadcast_fds(
      int rank,
      int src_rank,
      const std::vector<int>& pids,
      int fd) {
    size_t world_size = pids.size();

    if (rank == src_rank) {
      for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
        if (dst_rank == rank) {
          continue;
        }
        send_fd(pids[dst_rank], fd);
      }
      return fd;
    }
    return recv_fd();
  }

 private:
  static std::string get_socket_name(int pid) {
    const char* tmp_dir = "/tmp";
    for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) {
      if (const char* path = getenv(env_var)) {
        tmp_dir = path;
        break;
      }
    }
    std::ostringstream oss;
    oss << tmp_dir << "/symm_mem-" << pid;
    return oss.str();
  }

  std::string socket_name_;
  int socket_;
};

constexpr size_t signal_pad_size = 2048;
const std::string store_comm_prefix = "CUDASymmetricMemory";

static size_t store_comm_seq_id = 0;

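// All-gather a trivially copyable value across ranks via the c10d store:
// each rank sets its value under a sequence-numbered key, then waits for and
// reads every peer's key. The sequence id keeps successive collectives from
// colliding on the same keys.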
template <typename T>
std::vector<T> store_all_gather(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size,
    T val) {
  static_assert(std::is_trivially_copyable_v<T>);

  std::vector<std::string> peer_keys;
  for (int r = 0; r < world_size; ++r) {
    std::ostringstream oss;
    oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r;
    peer_keys.push_back(oss.str());
  }
  ++store_comm_seq_id;

  {
    std::vector<uint8_t> payload(
        reinterpret_cast<uint8_t*>(&val),
        reinterpret_cast<uint8_t*>(&val) + sizeof(T));
    store->set(peer_keys[rank], payload);
  }

  std::vector<T> peer_vals;
  for (int r = 0; r < world_size; ++r) {
    if (r == rank) {
      peer_vals.push_back(val);
      continue;
    }
    store->wait({peer_keys[r]});
    auto payload = store->get(peer_keys[r]);
    TORCH_CHECK(payload.size() == sizeof(T));
    T peer_val{};
    std::memcpy(&peer_val, payload.data(), sizeof(T));
    peer_vals.push_back(peer_val);
  }
  return peer_vals;
}

void store_barrier(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size) {
  store_all_gather(store, rank, world_size, 0);
}

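// Map a physical memory handle into the caller's virtual address space:
// reserve a VA range of `size` bytes, map `handle` into it, and grant the
// device read/write access. On success, *ptr holds the mapped address.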
void map_block(
    void** ptr,
    c10d::symmetric_memory::HandleType handle,
    size_t size,
    int device_idx) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  auto dev_ptr = reinterpret_cast<CUdeviceptr*>(ptr);
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL));
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL));

  CUmemAccessDesc desc;
  desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  desc.location.id = static_cast<int>(device_idx);
  desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

} // namespace

namespace c10d {
namespace symmetric_memory {

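// The constructor takes ownership of the already-mapped peer buffers and
// signal pads, and additionally copies the two pointer arrays to device
// memory so kernels can index a peer's buffer or signal pad by rank.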
CUDASymmetricMemory::CUDASymmetricMemory(
    std::vector<HandleType> handles,
    size_t block_size,
    std::vector<void*> buffers,
    std::vector<void*> signal_pads,
    HandleType mc_handle,
    void* mc_addr,
    size_t buffer_size,
    int local_device_idx,
    int rank,
    int world_size)
    : handles_(std::move(handles)),
      block_size_(block_size),
      buffers_(std::move(buffers)),
      signal_pads_(std::move(signal_pads)),
      mc_handle_(mc_handle),
      mc_addr_(mc_addr),
      buffer_size_(buffer_size),
      local_device_idx_(local_device_idx),
      rank_(rank),
      world_size_(world_size) {
  const size_t arr_size = sizeof(void*) * world_size_;
  buffers_dev_ = reinterpret_cast<void**>(
      c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));
  signal_pads_dev_ = reinterpret_cast<void**>(
      c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));

  c10::cuda::CUDAGuard guard(local_device_idx);
  AT_CUDA_CHECK(cudaMemcpy(
      buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice));
  AT_CUDA_CHECK(cudaMemcpy(
      signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice));
}

CUDASymmetricMemory::~CUDASymmetricMemory() {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  // Leak the cuda allocations during static deinitialization
  if (is_finalizing()) {
    return;
  }
  c10::cuda::CUDAGuard guard(local_device_idx_);
  C10_CUDA_CHECK(cudaDeviceSynchronize());

  auto driver_api = c10::cuda::DriverAPI::get();
  for (int r = 0; r < world_size_; ++r) {
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
        reinterpret_cast<CUdeviceptr>(buffers_[r]), block_size_));
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r]));
  }
  c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_);
  c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_);
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

std::vector<void*> CUDASymmetricMemory::get_buffer_ptrs() {
  return buffers_;
}

std::vector<void*> CUDASymmetricMemory::get_signal_pad_ptrs() {
  return signal_pads_;
}

void** CUDASymmetricMemory::get_buffer_ptrs_dev() {
  return buffers_dev_;
}

void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() {
  return signal_pads_dev_;
}

size_t CUDASymmetricMemory::get_buffer_size() {
  return buffer_size_;
}

size_t CUDASymmetricMemory::get_signal_pad_size() {
  return signal_pad_size;
}

bool CUDASymmetricMemory::has_multicast_support() {
  return mc_addr_ != nullptr;
}

void* CUDASymmetricMemory::get_multicast_ptr() {
  return mc_addr_;
}

at::Tensor CUDASymmetricMemory::get_buffer(
    int rank,
    c10::IntArrayRef sizes,
    c10::ScalarType dtype,
    int64_t storage_offset) {
  const auto numel =
      std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int>());
  const auto element_size = c10::elementSize(dtype);
  const auto req_size = (numel + storage_offset) * element_size;
  TORCH_CHECK(
      req_size <= buffer_size_,
      "CUDASymmetricMemory::get_buffer: the requested size (",
      req_size,
      " bytes) exceeds the allocated size (",
      buffer_size_,
      " bytes)");
  auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_);
  auto options = at::TensorOptions().dtype(dtype).device(device);
  return at::for_blob(buffers_[rank], sizes)
      .storage_offset(storage_offset)
      .options(options)
      .target_device(device)
      .make_tensor();
}

void check_channel(int channel, int world_size) {
  TORCH_CHECK(
      channel >= 0,
      "channel for barrier(), put_signal() and wait_signal() ",
      "must be greater than or equal to 0 (got ",
      channel,
      ")");
  // Each channel occupies world_size uint32_t slots in a rank's signal pad,
  // so a pad of signal_pad_size bytes can hold
  // signal_pad_size / sizeof(uint32_t) / world_size channels.
  const size_t num_channels = signal_pad_size / sizeof(uint32_t) / world_size;
  TORCH_CHECK(
      static_cast<size_t>(channel) < num_channels,
      "The maximum supported channel for barrier(), put_signal() and wait_signal() is ",
      num_channels - 1,
      " (got ",
      channel,
      ")");
}

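// release_signal() spins until it can flip the peer-visible flag from 0 to 1;
// acquire_signal() spins until it can flip it back from 1 to 0. Both use
// system-scope atomicCAS so the flags stay coherent across devices; the
// ROCm and pre-SM80 builds trap instead.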
__device__ __forceinline__ void release_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  volatile uint32_t* signal = addr;
  uint32_t val;
  do {
    val = *signal;
  } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0);
#endif
}

__device__ __forceinline__ void acquire_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  volatile uint32_t* signal = addr;
  uint32_t val;
  do {
    val = *signal;
  } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1);
#endif
}

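// Barrier: thread t (t < world_size) raises this rank's flag in peer t's
// signal pad, then waits for peer t to raise its flag in this rank's pad.
// Flags for a given channel occupy slots [world_size * channel,
// world_size * channel + world_size) within each pad, one slot per peer rank.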
static __global__ void barrier_kernel(
    uint32_t** signal_pads,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    release_signal(signal_pads[target_rank] + world_size * channel + rank);
    acquire_signal(signal_pads[rank] + world_size * channel + target_rank);
  }
}

void CUDASymmetricMemory::barrier(int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

static __global__ void put_signal_kernel(
    uint32_t** signal_pads,
    int dst_rank,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x == 0) {
    release_signal(signal_pads[dst_rank] + world_size * channel + rank);
  }
}

void CUDASymmetricMemory::put_signal(int dst_rank, int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      dst_rank,
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

static __global__ void wait_signal_kernel(
    uint32_t** signal_pads,
    int src_rank,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x == 0) {
    acquire_signal(signal_pads[rank] + world_size * channel + src_rank);
  }
  __threadfence_system();
}

void CUDASymmetricMemory::wait_signal(int src_rank, int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      src_rank,
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

int CUDASymmetricMemory::get_rank() {
  return rank_;
}

int CUDASymmetricMemory::get_world_size() {
  return world_size_;
}

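// Allocation layout: a block is the user buffer of `size` bytes (rounded up
// to 16 bytes) followed by a fixed-size signal pad, with the total rounded up
// to the recommended allocation granularity. The physical memory is created
// via cuMemCreate as a POSIX-fd-shareable handle so peers can import it
// during rendezvous.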
void* CUDASymmetricMemoryAllocator::alloc(
    size_t size,
    int device_idx,
    const std::string& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();

  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  prop.location.id = device_idx;
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t signal_pad_offset = at::round_up(size, 16UL);
  size_t block_size = signal_pad_offset + signal_pad_size;

  size_t granularity;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
      &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
  block_size = at::round_up(block_size, granularity);

  HandleType handle;
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemCreate_(&handle, block_size, &prop, 0));

  void* ptr = nullptr;
  map_block(&ptr, handle, block_size, device_idx);

  c10::cuda::CUDAGuard guard(device_idx);
  AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size));

  auto block = c10::make_intrusive<Block>(
      handle, device_idx, block_size, size, signal_pad_offset, group_name);
  {
    std::unique_lock lock(mutex_);
    ptr_to_block_.emplace(ptr, std::move(block));
  }
  return ptr;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

void CUDASymmetricMemoryAllocator::free(void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto block = find_block(ptr);
  // Leak the cuda allocations during static deinitialization
  if (block == nullptr || is_finalizing()) {
    return;
  }
  // Initializing CUDASymmetricMemory with an allocation transfers its
  // ownership to the CUDASymmetricMemory object, so only release the
  // allocation here if rendezvous has not happened yet.
  if (block->symm_mem == nullptr) {
    auto driver_api = c10::cuda::DriverAPI::get();
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
        reinterpret_cast<CUdeviceptr>(ptr), block->block_size));
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle));
  }
  {
    std::unique_lock lock(mutex_);
    ptr_to_block_.erase(ptr);
  }
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) {
  auto block = find_block(ptr);
  TORCH_CHECK(
      block != nullptr,
      "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ",
      "via CUDASymmetricMemoryAllocator::alloc");
  return block->buffer_size;
}

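// Per-rank metadata exchanged through the store during rendezvous so every
// rank can validate its peers' allocations before importing them.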
struct RendezvousRequest {
  int device_idx;
  int pid;
  size_t block_size;
  size_t buffer_size;
  size_t signal_pad_offset;
  bool has_multicast_support;
};

void validate_rendezvous_requests(
    const std::vector<RendezvousRequest>& reqs,
    int world_size) {
  TORCH_CHECK(reqs.size() == (size_t)world_size);

  std::unordered_set<int> device_indices;
  device_indices.reserve(world_size);
  for (auto req : reqs) {
    device_indices.insert(req.device_idx);
  }
  if (device_indices.size() < (size_t)world_size) {
    TORCH_CHECK(
        false,
        "CUDASymmetricMemoryAllocator::rendezvous: ",
        "detected allocations from overlapping devices ",
        "from different ranks.");
  }

  for (int r = 1; r < world_size; ++r) {
    TORCH_CHECK(reqs[r].block_size == reqs[0].block_size);
    TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size);
    TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset);
  }
}

static bool check_group_multicast_support(
    const std::vector<RendezvousRequest>& reqs) {
  std::vector<size_t> ranks_with_multicast_support;
  for (size_t r = 0; r < reqs.size(); ++r) {
    if (reqs[r].has_multicast_support) {
      ranks_with_multicast_support.push_back(r);
    }
  }
  if (ranks_with_multicast_support.size() == reqs.size()) {
    return true;
  } else {
    // We don't expect this to happen. But we want to let the user know if
    // it does.
    if (ranks_with_multicast_support.size() != 0) {
      LOG(WARNING)
          << "Only a subset of ranks in the group has multicast support: "
          << ranks_with_multicast_support << " (world_size=" << reqs.size()
          << "). Skipping multicast initialization because this is unexpected.";
    }
    return false;
  }
}

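// Set up a multicast mapping for the block: rank 0 creates the multicast
// object and broadcasts its shareable fd to the peers (or -1 on failure so
// they can skip gracefully); every rank then adds its device, binds its
// physical allocation, and maps the multicast object into its address space.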
static void init_multicast_for_block(
    HandleType& mc_handle,
    void*& mc_addr,
    const c10::intrusive_ptr<Block>& block,
    IpcChannel& ipc_channel,
    const std::vector<int>& pids,
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && \
    defined(CUDART_SUPPORTS_MULTICAST)
  auto driver_api = c10::cuda::DriverAPI::get();
  if (rank == 0) {
    CUmulticastObjectProp mc_prop{};
    mc_prop.numDevices = world_size;
    mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
    mc_prop.size = block->block_size;

    auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
    if (err != CUDA_SUCCESS) {
      const char* err_str;
      CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str);
      if (get_error_str_err != CUDA_SUCCESS) {
        err_str = "unknown cuda driver error";
      }
      LOG(WARNING)
          << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str
          << "\". Gracefully skipping multicast initialization. "
          << "However, this is unexpected. Please report the issue on GitHub.";
      // Allow peers to gracefully skip multicast initialization by sending -1
      ipc_channel.broadcast_fds(rank, 0, pids, -1);
      return;
    }

    int mc_fd;
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
        &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
    ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
    // Ref count is incremented as soon as SCM_RIGHTS send happens
    close(mc_fd);
  } else {
    int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
    if (mc_fd == -1) {
      return;
    }
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
        &mc_handle,
        (void*)(uintptr_t)mc_fd,
        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
    close(mc_fd);
  }

  // Each rank adds its physical allocation to the multicast object
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
  C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
      mc_handle, 0, block->handle, 0, block->block_size, 0));

  map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
  store_barrier(store, rank, world_size);
#endif
}

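// Rendezvous: exchange per-rank allocation metadata through the store,
// exchange POSIX fds for the memory handles over the IpcChannel, map every
// peer's block into the local address space, optionally set up multicast,
// and cache the resulting CUDASymmetricMemory on the block.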
c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
    void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto block = find_block(ptr);
  if (block == nullptr) {
    return nullptr;
  }

  if (block->symm_mem != nullptr) {
    return block->symm_mem;
  }

  IpcChannel ipc_channel;
  auto group_info = get_group_info(block->group_name);
  auto store = group_info.store;
  int rank = group_info.rank;
  int world_size = group_info.world_size;

  auto driver_api = c10::cuda::DriverAPI::get();
  int block_fd;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
      &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));

  auto local_req = RendezvousRequest{
      .device_idx = block->device_idx,
      .pid = getpid(),
      .block_size = block->block_size,
      .buffer_size = block->buffer_size,
      .signal_pad_offset = block->signal_pad_offset,
      .has_multicast_support = device_has_multicast_support(block->device_idx)};
  auto reqs = store_all_gather(store, rank, world_size, local_req);
  validate_rendezvous_requests(reqs, world_size);

  std::vector<int> pids(world_size);
  for (int r = 0; r < world_size; ++r) {
    pids[r] = reqs[r].pid;
  }
  auto imported_fds = ipc_channel.all_gather_fds(rank, pids, block_fd);

  std::vector<HandleType> handles(world_size);
  std::vector<void*> buffers(world_size, nullptr);
  std::vector<void*> signal_pads(world_size, nullptr);

  for (int r = 0; r < world_size; ++r) {
    if (r == rank) {
      handles[r] = block->handle;
      buffers[r] = ptr;
      signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset);
      continue;
    }
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
        &handles[r],
        (void*)(uintptr_t)imported_fds[r],
        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
    map_block(&buffers[r], handles[r], block->block_size, block->device_idx);
    signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset);
    close(imported_fds[r]);
  }
  store_barrier(store, rank, world_size);
  close(block_fd);

  HandleType mc_handle{};
  void* mc_addr = nullptr;
  bool group_has_multicast_support = check_group_multicast_support(reqs);
  if (group_has_multicast_support) {
    init_multicast_for_block(
        mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
  }

  // Initializing CUDASymmetricMemory with an allocation transfers its
  // ownership to the CUDASymmetricMemory object, so that outstanding
  // references to the CUDASymmetricMemory object can keep the allocation
  // alive.
  block->symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
      std::move(handles),
      block->block_size,
      std::move(buffers),
      std::move(signal_pads),
      mc_handle,
      mc_addr,
      block->buffer_size,
      block->device_idx,
      group_info.rank,
      group_info.world_size);
  return block->symm_mem;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
  auto block = find_block(ptr);
  TORCH_CHECK(
      block != nullptr,
      "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ",
      "via CUDASymmetricMemoryAllocator::alloc");
  return block->symm_mem != nullptr;
}

bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
  return device_has_multicast_support(device_idx);
}

c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
  std::shared_lock lock(mutex_);
  auto it = ptr_to_block_.find(ptr);
  if (it == ptr_to_block_.end()) {
    return nullptr;
  }
  return it->second;
}

struct RegisterCUDASymmetricMemoryAllocator {
  RegisterCUDASymmetricMemoryAllocator() {
    register_allocator(
        c10::DeviceType::CUDA,
        c10::make_intrusive<CUDASymmetricMemoryAllocator>());
  }
};

static RegisterCUDASymmetricMemoryAllocator register_allocator_;

} // namespace symmetric_memory
} // namespace c10d