// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>

#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/un.h>
#include <unistd.h>

#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#define CUDART_SUPPORTS_MULTICAST
#endif

namespace {

bool device_has_multicast_support(int device_idx) {
#if defined(CUDART_SUPPORTS_MULTICAST)
  if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
    return false;
  }
  // Multicast support requirements:
  // - CUDA Runtime version >= 12030: Checked at compile time using
  // CUDART_VERSION.
  // - Driver version >= 535: Checked at runtime by verifying the existence of
  // cuMulticastCreate_.
  // - Device support: Determined by querying
  // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
  auto driver_api = c10::cuda::DriverAPI::get();
  int multicast_supported;
  C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_(
      &multicast_supported,
      CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
      device_idx));
  return driver_api->cuMulticastCreate_ != nullptr && multicast_supported;
#else
  return false;
#endif
}

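// IpcChannel exchanges POSIX file descriptors between the processes in a
// group. Each process binds a Unix domain datagram socket named after its
// own pid, addresses peers by their pids, and passes fds via SCM_RIGHTS
// ancillary messages (an fd of -1 denotes "no fd to share").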
class IpcChannel {
 public:
  IpcChannel() : socket_name_(get_socket_name(getpid())) {
    TORCH_CHECK(
        (socket_ = socket(AF_UNIX, SOCK_DGRAM, 0)) != -1,
52         "Failed to create socket: ",
53         strerror(errno));
54 
55     struct sockaddr_un addr = {.sun_family = AF_UNIX};
56     std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);
57 
58     TORCH_CHECK(
59         bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0,
60         "Failed to bind socket: ",
61         strerror(errno));
62   }
63 
~IpcChannel()64   ~IpcChannel() {
65     close(socket_);
66     unlink(socket_name_.c_str());
67   }
68 
send_fd(int dst_pid,int fd)69   void send_fd(int dst_pid, int fd) {
70     struct sockaddr_un addr = {.sun_family = AF_UNIX};
71     auto socket_name = get_socket_name(dst_pid);
72     std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);
73 
74     struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};
75 
76     char cbuf[CMSG_SPACE(sizeof(int))];
77     memset(cbuf, 0, sizeof(cbuf));
78 
79     struct msghdr msg {
80       .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
81       .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
82       .msg_controllen = sizeof(cbuf)
83     };
84 
85     auto cmsg = CMSG_FIRSTHDR(&msg);
86     cmsg->cmsg_len = CMSG_LEN(sizeof(int));
87     cmsg->cmsg_level = SOL_SOCKET;
88     cmsg->cmsg_type = SCM_RIGHTS;
89 
90     if (fd != -1) {
      // Copy the fd into the control message data (equivalent to
      // memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd))).
      std::copy(
          reinterpret_cast<const char*>(&fd),
          reinterpret_cast<const char*>(&fd) + sizeof(fd),
          reinterpret_cast<char*>(CMSG_DATA(cmsg)));
    } else {
      msg.msg_controllen = 0;
    }

    TORCH_CHECK(
        sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
  }

  int recv_fd() {
    char buf[2];
    struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};

    char cbuf[CMSG_SPACE(sizeof(int))];
    memset(cbuf, 0, sizeof(cbuf));

    struct msghdr msg = {
        .msg_iov = &io,
        .msg_iovlen = 1,
        .msg_control = cbuf,
        .msg_controllen = sizeof(cbuf)};

    TORCH_CHECK(
        recvmsg(socket_, &msg, 0) > 0,
        "Failed to receive fd: ",
        strerror(errno));

    if (msg.msg_controllen == 0) {
      return -1;
    }

    auto cmsg = CMSG_FIRSTHDR(&msg);
    TORCH_CHECK(cmsg != NULL);
    TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
    TORCH_CHECK(
        cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS);
    return *reinterpret_cast<int*>(CMSG_DATA(cmsg));
  }

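  // Gathers one fd from every rank using a ring exchange: at each of the
  // world_size - 1 steps, every rank forwards the most recently received fd
  // (starting with its own) to the next rank and receives, from the previous
  // rank, the fd that originated at a progressively more distant source.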
  std::vector<int> all_gather_fds(
      int rank,
      const std::vector<int>& pids,
      int fd) {
    size_t world_size = pids.size();
    std::vector<int> fds(pids.size());
    fds[rank] = fd;

    int dst_rank = (rank + 1) % world_size;
    for (size_t step = 1; step < world_size; ++step) {
      int src_rank = (rank + world_size - step) % world_size;
      send_fd(pids[dst_rank], fd);
      fd = recv_fd();
      fds[src_rank] = fd;
    }
    return fds;
  }

  int broadcast_fds(
      int rank,
      int src_rank,
      const std::vector<int>& pids,
      int fd) {
    size_t world_size = pids.size();

    if (rank == src_rank) {
      for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
        if (dst_rank == rank) {
          continue;
        }
        send_fd(pids[dst_rank], fd);
      }
      return fd;
    }
    return recv_fd();
  }

 private:
  static std::string get_socket_name(int pid) {
    const char* tmp_dir = "/tmp";
    for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) {
      if (const char* path = getenv(env_var)) {
        tmp_dir = path;
        break;
      }
    }
    std::ostringstream oss;
    oss << tmp_dir << "/symm_mem-" << pid;
    return oss.str();
  }

  std::string socket_name_;
  int socket_;
};

constexpr size_t signal_pad_size = 2048;
const std::string store_comm_prefix = "CUDASymmetricMemory";

static size_t store_comm_seq_id = 0;

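// Performs an all-gather of a trivially copyable value through the c10d
// store: each rank publishes its bytes under a sequence-numbered key, then
// waits for and reads every peer's key. Also used (with a dummy value) as a
// store-based barrier below.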
template <typename T>
std::vector<T> store_all_gather(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size,
    T val) {
  static_assert(std::is_trivially_copyable_v<T>);

  std::vector<std::string> peer_keys;
  for (int r = 0; r < world_size; ++r) {
    std::ostringstream oss;
    oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r;
    peer_keys.push_back(oss.str());
  }
  ++store_comm_seq_id;

  {
    std::vector<uint8_t> payload(
        reinterpret_cast<uint8_t*>(&val),
        reinterpret_cast<uint8_t*>(&val) + sizeof(T));
    store->set(peer_keys[rank], payload);
  }

  std::vector<T> peer_vals;
  for (int r = 0; r < world_size; ++r) {
    if (r == rank) {
      peer_vals.push_back(val);
      continue;
    }
    store->wait({peer_keys[r]});
    auto payload = store->get(peer_keys[r]);
    TORCH_CHECK(payload.size() == sizeof(T));
    T peer_val{};
    std::memcpy(&peer_val, payload.data(), sizeof(T));
    peer_vals.push_back(peer_val);
  }
  return peer_vals;
}

void store_barrier(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size) {
  store_all_gather(store, rank, world_size, 0);
}

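// Maps a physical allocation handle into a freshly reserved virtual address
// range on the given device and enables read/write access to it, using the
// CUDA virtual memory management driver API.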
void map_block(
    void** ptr,
    c10d::symmetric_memory::HandleType handle,
    size_t size,
    int device_idx) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  auto dev_ptr = reinterpret_cast<CUdeviceptr*>(ptr);
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL));
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL));

  CUmemAccessDesc desc;
  desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  desc.location.id = static_cast<int>(device_idx);
  desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

} // namespace

namespace c10d {
namespace symmetric_memory {

CUDASymmetricMemory::CUDASymmetricMemory(
    std::vector<HandleType> handles,
    size_t block_size,
    std::vector<void*> buffers,
    std::vector<void*> signal_pads,
    HandleType mc_handle,
    void* mc_addr,
    size_t buffer_size,
    int local_device_idx,
    int rank,
    int world_size)
    : handles_(std::move(handles)),
      block_size_(block_size),
      buffers_(std::move(buffers)),
      signal_pads_(std::move(signal_pads)),
      mc_handle_(mc_handle),
      mc_addr_(mc_addr),
      buffer_size_(buffer_size),
      local_device_idx_(local_device_idx),
      rank_(rank),
      world_size_(world_size) {
  const size_t arr_size = sizeof(void*) * world_size_;
  buffers_dev_ = reinterpret_cast<void**>(
      c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));
  signal_pads_dev_ = reinterpret_cast<void**>(
      c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));

  c10::cuda::CUDAGuard guard(local_device_idx);
  AT_CUDA_CHECK(cudaMemcpy(
      buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice));
  AT_CUDA_CHECK(cudaMemcpy(
      signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice));
}

CUDASymmetricMemory::~CUDASymmetricMemory() {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  // Leak the cuda allocations during static deinitialization
  if (is_finalizing()) {
    return;
  }
  c10::cuda::CUDAGuard guard(local_device_idx_);
  C10_CUDA_CHECK(cudaDeviceSynchronize());

  auto driver_api = c10::cuda::DriverAPI::get();
  for (int r = 0; r < world_size_; ++r) {
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
        reinterpret_cast<CUdeviceptr>(buffers_[r]), block_size_));
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r]));
  }
  c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_);
  c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_);
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

std::vector<void*> CUDASymmetricMemory::get_buffer_ptrs() {
  return buffers_;
}

std::vector<void*> CUDASymmetricMemory::get_signal_pad_ptrs() {
  return signal_pads_;
}

void** CUDASymmetricMemory::get_buffer_ptrs_dev() {
  return buffers_dev_;
}

void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() {
  return signal_pads_dev_;
}

size_t CUDASymmetricMemory::get_buffer_size() {
  return buffer_size_;
}

size_t CUDASymmetricMemory::get_signal_pad_size() {
  return signal_pad_size;
}

bool CUDASymmetricMemory::has_multicast_support() {
  return mc_addr_ != nullptr;
}

void* CUDASymmetricMemory::get_multicast_ptr() {
  return mc_addr_;
}

at::Tensor CUDASymmetricMemory::get_buffer(
    int rank,
    c10::IntArrayRef sizes,
    c10::ScalarType dtype,
    int64_t storage_offset) {
  const auto numel =
      std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int>());
  const auto element_size = c10::elementSize(dtype);
  const auto req_size = (numel + storage_offset) * element_size;
  TORCH_CHECK(
      req_size <= buffer_size_,
      "CUDASymmetricMemory::get_buffer: the requested size (",
      req_size,
      " bytes) exceeds the allocated size (",
      buffer_size_,
      " bytes)");
  auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_);
  auto options = at::TensorOptions().dtype(dtype).device(device);
  return at::for_blob(buffers_[rank], sizes)
      .storage_offset(storage_offset)
      .options(options)
      .target_device(device)
      .make_tensor();
}

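// Each rank's signal pad is an array of uint32_t flags organized as
// [channel][peer rank]: a channel occupies world_size consecutive slots, so
// the number of usable channels is signal_pad_size / sizeof(uint32_t) /
// world_size.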
void check_channel(int channel, int world_size) {
  TORCH_CHECK(
      channel >= 0,
      "channel for barrier(), put_signal() and wait_signal() ",
      "must be nonnegative (got ",
      channel,
      ")");
  const size_t num_channels = signal_pad_size / sizeof(uint32_t) / world_size;
  TORCH_CHECK(
      static_cast<size_t>(channel) < num_channels,
      "The maximum supported channel for barrier(), put_signal() and wait_signal() is ",
      num_channels - 1,
      " (got ",
      channel,
      ")");
}

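// release_signal spins until it atomically flips the flag at addr from 0 to
// 1; acquire_signal spins until it flips it from 1 back to 0. System-scope
// atomics are used so the flags stay coherent across peer-mapped memory.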
__device__ __forceinline__ void release_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  volatile uint32_t* signal = addr;
  uint32_t val;
  do {
    val = *signal;
  } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0);
#endif
}

__device__ __forceinline__ void acquire_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  volatile uint32_t* signal = addr;
  uint32_t val;
  do {
    val = *signal;
  } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1);
#endif
}

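// One thread per peer: thread t releases the local rank's slot in peer t's
// signal pad for this channel, then waits until peer t has released its slot
// in the local pad, so the kernel only completes once every rank has arrived.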
static __global__ void barrier_kernel(
    uint32_t** signal_pads,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    release_signal(signal_pads[target_rank] + world_size * channel + rank);
    acquire_signal(signal_pads[rank] + world_size * channel + target_rank);
  }
}

void CUDASymmetricMemory::barrier(int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

static __global__ void put_signal_kernel(
    uint32_t** signal_pads,
    int dst_rank,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x == 0) {
    release_signal(signal_pads[dst_rank] + world_size * channel + rank);
  }
}

void CUDASymmetricMemory::put_signal(int dst_rank, int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      dst_rank,
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

static __global__ void wait_signal_kernel(
    uint32_t** signal_pads,
    int src_rank,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x == 0) {
    acquire_signal(signal_pads[rank] + world_size * channel + src_rank);
  }
  __threadfence_system();
}

void CUDASymmetricMemory::wait_signal(int src_rank, int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      src_rank,
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

int CUDASymmetricMemory::get_rank() {
  return rank_;
}

int CUDASymmetricMemory::get_world_size() {
  return world_size_;
}

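// Allocates one "block" per call: the user buffer (16-byte aligned) followed
// by a fixed-size signal pad, rounded up to the recommended allocation
// granularity. The block is backed by an fd-exportable physical allocation
// (cuMemCreate), mapped into this process, zeroed, and tracked so that
// rendezvous() can later share it with peer ranks.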
void* CUDASymmetricMemoryAllocator::alloc(
    size_t size,
    int device_idx,
    const std::string& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();

  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  prop.location.id = device_idx;
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t signal_pad_offset = at::round_up(size, 16UL);
  size_t block_size = signal_pad_offset + signal_pad_size;

  size_t granularity;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
      &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
  block_size = at::round_up(block_size, granularity);

  HandleType handle;
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemCreate_(&handle, block_size, &prop, 0));

  void* ptr = nullptr;
  map_block(&ptr, handle, block_size, device_idx);

  c10::cuda::CUDAGuard guard(device_idx);
  AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size));

  auto block = c10::make_intrusive<Block>(
      handle, device_idx, block_size, size, signal_pad_offset, group_name);
  {
    std::unique_lock lock(mutex_);
    ptr_to_block_.emplace(ptr, std::move(block));
  }
  return ptr;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

void CUDASymmetricMemoryAllocator::free(void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto block = find_block(ptr);
  // Leak the cuda allocations during static deinitialization
  if (block == nullptr || is_finalizing()) {
    return;
  }
  // Initializing CUDASymmetricMemory with an allocation transfers its
  // ownership to the CUDASymmetricMemory object.
  if (block->symm_mem == nullptr) {
    auto driver_api = c10::cuda::DriverAPI::get();
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
        reinterpret_cast<CUdeviceptr>(ptr), block->block_size));
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle));
  }
  {
    std::unique_lock lock(mutex_);
    ptr_to_block_.erase(ptr);
  }
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) {
  auto block = find_block(ptr);
  TORCH_CHECK(
      block != nullptr,
      "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ",
      "via CUDASymmetricMemoryAllocator::alloc");
  return block->buffer_size;
}

struct RendezvousRequest {
  int device_idx;
  int pid;
  size_t block_size;
  size_t buffer_size;
  size_t signal_pad_offset;
  bool has_multicast_support;
};

void validate_rendezvous_requests(
    const std::vector<RendezvousRequest>& reqs,
    int world_size) {
  TORCH_CHECK(reqs.size() == (size_t)world_size);

  std::unordered_set<int> device_indices;
  device_indices.reserve(world_size);
  for (auto req : reqs) {
    device_indices.insert(req.device_idx);
  }
  if (device_indices.size() < (size_t)world_size) {
    TORCH_CHECK(
        false,
        "CUDASymmetricMemoryAllocator::rendezvous: ",
        "detected allocations from overlapping devices ",
        "from different ranks.");
  }

  for (int r = 1; r < world_size; ++r) {
    TORCH_CHECK(reqs[r].block_size == reqs[0].block_size);
    TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size);
    TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset);
  }
}

static bool check_group_multicast_support(
    const std::vector<RendezvousRequest>& reqs) {
  std::vector<size_t> ranks_with_multicast_support;
  for (size_t r = 0; r < reqs.size(); ++r) {
    if (reqs[r].has_multicast_support) {
      ranks_with_multicast_support.push_back(r);
    }
  }
  if (ranks_with_multicast_support.size() == reqs.size()) {
    return true;
  } else {
    // We don't expect this to happen, but we want to let the user know if it
    // does.
    if (ranks_with_multicast_support.size() != 0) {
      LOG(WARNING)
          << "Only a subset of ranks in the group has multicast support: "
          << ranks_with_multicast_support << " (world_size=" << reqs.size()
          << "). Skipping multicast initialization because this is unexpected.";
    }
    return false;
  }
}

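// Sets up a multicast object covering the group's blocks: rank 0 creates it
// (cuMulticastCreate), exports it as a POSIX fd and broadcasts the fd to the
// other ranks (-1 signals that creation failed and multicast should be
// skipped), and the other ranks import it. Every rank then adds its device,
// binds its physical allocation, and maps the multicast address into its VA
// space.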
static void init_multicast_for_block(
    HandleType& mc_handle,
    void*& mc_addr,
    const c10::intrusive_ptr<Block>& block,
    IpcChannel& ipc_channel,
    const std::vector<int>& pids,
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && \
    defined(CUDART_SUPPORTS_MULTICAST)
  auto driver_api = c10::cuda::DriverAPI::get();
  if (rank == 0) {
    CUmulticastObjectProp mc_prop{};
    mc_prop.numDevices = world_size;
    mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
    mc_prop.size = block->block_size;

    auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
    if (err != CUDA_SUCCESS) {
      const char* err_str;
      CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str);
      if (get_error_str_err != CUDA_SUCCESS) {
        err_str = "unknown cuda driver error";
      }
      LOG(WARNING)
          << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str
          << "\". Gracefully skipping multicast initialization. "
          << "However, this is unexpected. Please report the issue on GitHub.";
      // Allow peers to gracefully skip multicast initialization by sending -1
      ipc_channel.broadcast_fds(rank, 0, pids, -1);
      return;
    }

    int mc_fd;
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
        &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
    ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
    // The fd's ref count is incremented as soon as the SCM_RIGHTS send
    // happens, so the local copy can be closed immediately.
    close(mc_fd);
  } else {
    int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
    if (mc_fd == -1) {
      return;
    }
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
        &mc_handle,
        (void*)(uintptr_t)mc_fd,
        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
    close(mc_fd);
  }

  // Each rank adds its physical allocation to the multicast object
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
  C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
      mc_handle, 0, block->handle, 0, block->block_size, 0));

  map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
  store_barrier(store, rank, world_size);
#endif
}

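// Establishes a symmetric-memory group over a local block: ranks exchange
// RendezvousRequests through the store (validating that block geometry
// matches), swap the blocks' exportable fds over the IpcChannel, import and
// map every peer's allocation locally, optionally set up multicast, and wrap
// the result in a CUDASymmetricMemory object cached on the block.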
c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
    void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto block = find_block(ptr);
  if (block == nullptr) {
    return nullptr;
  }

  if (block->symm_mem != nullptr) {
    return block->symm_mem;
  }

  IpcChannel ipc_channel;
  auto group_info = get_group_info(block->group_name);
  auto store = group_info.store;
  int rank = group_info.rank;
  int world_size = group_info.world_size;

  auto driver_api = c10::cuda::DriverAPI::get();
  int block_fd;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
      &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));

  auto local_req = RendezvousRequest{
      .device_idx = block->device_idx,
      .pid = getpid(),
      .block_size = block->block_size,
      .buffer_size = block->buffer_size,
      .signal_pad_offset = block->signal_pad_offset,
      .has_multicast_support = device_has_multicast_support(block->device_idx)};
  auto reqs = store_all_gather(store, rank, world_size, local_req);
  validate_rendezvous_requests(reqs, world_size);

  std::vector<int> pids(world_size);
  for (int r = 0; r < world_size; ++r) {
    pids[r] = reqs[r].pid;
  }
  auto imported_fds = ipc_channel.all_gather_fds(rank, pids, block_fd);

  std::vector<HandleType> handles(world_size);
  std::vector<void*> buffers(world_size, nullptr);
  std::vector<void*> signal_pads(world_size, nullptr);

  for (int r = 0; r < world_size; ++r) {
    if (r == rank) {
      handles[r] = block->handle;
      buffers[r] = ptr;
      signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset);
      continue;
    }
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
        &handles[r],
        (void*)(uintptr_t)imported_fds[r],
        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
    map_block(&buffers[r], handles[r], block->block_size, block->device_idx);
    signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset);
    close(imported_fds[r]);
  }
  store_barrier(store, rank, world_size);
  close(block_fd);

  HandleType mc_handle{};
  void* mc_addr = nullptr;
  bool group_has_multicast_support = check_group_multicast_support(reqs);
  if (group_has_multicast_support) {
    init_multicast_for_block(
        mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
  }

  // Initializing CUDASymmetricMemory with an allocation transfers its
  // ownership to the CUDASymmetricMemory object, so that outstanding
  // references to the CUDASymmetricMemory object can keep the allocation
  // alive.
  block->symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
      std::move(handles),
      block->block_size,
      std::move(buffers),
      std::move(signal_pads),
      mc_handle,
      mc_addr,
      block->buffer_size,
      block->device_idx,
      group_info.rank,
      group_info.world_size);
  return block->symm_mem;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
  auto block = find_block(ptr);
  TORCH_CHECK(
      block != nullptr,
      "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ",
      "via CUDASymmetricMemoryAllocator::alloc");
  return block->symm_mem != nullptr;
}

bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
  return device_has_multicast_support(device_idx);
}

c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
  std::shared_lock lock(mutex_);
  auto it = ptr_to_block_.find(ptr);
  if (it == ptr_to_block_.end()) {
    return nullptr;
  }
  return it->second;
}

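// Registers the allocator for the CUDA device type at static initialization
// time, so that symmetric-memory allocations on CUDA devices are routed to
// CUDASymmetricMemoryAllocator.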
struct RegisterCUDASymmetricMemoryAllocator {
  RegisterCUDASymmetricMemoryAllocator() {
    register_allocator(
        c10::DeviceType::CUDA,
        c10::make_intrusive<CUDASymmetricMemoryAllocator>());
  }
};

static RegisterCUDASymmetricMemoryAllocator register_allocator_;

} // namespace symmetric_memory
} // namespace c10d