1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <torch/csrc/distributed/c10d/Store.hpp> 5 #include <torch/csrc/distributed/c10d/SymmetricMemory.hpp> 6 7 namespace c10d { 8 namespace symmetric_memory { 9 10 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) 11 using HandleType = CUmemGenericAllocationHandle; 12 #else 13 using HandleType = void*; 14 #endif 15 16 class CUDASymmetricMemory : public SymmetricMemory { 17 public: 18 CUDASymmetricMemory( 19 std::vector<HandleType> handles, 20 size_t block_size, 21 std::vector<void*> buffers, 22 std::vector<void*> signal_pads, 23 HandleType mc_handle, 24 void* mc_addr, 25 size_t buffer_size, 26 int local_device_idx, 27 int rank, 28 int world_size); 29 30 ~CUDASymmetricMemory() override; 31 32 std::vector<void*> get_buffer_ptrs() override; 33 std::vector<void*> get_signal_pad_ptrs() override; 34 void** get_buffer_ptrs_dev() override; 35 void** get_signal_pad_ptrs_dev() override; 36 size_t get_buffer_size() override; 37 size_t get_signal_pad_size() override; 38 39 bool has_multicast_support() override; 40 void* get_multicast_ptr() override; 41 42 at::Tensor get_buffer( 43 int rank, 44 c10::IntArrayRef sizes, 45 c10::ScalarType dtype, 46 int64_t storage_offset) override; 47 48 void barrier(int channel) override; 49 void put_signal(int dst_rank, int channel) override; 50 void wait_signal(int src_rank, int channel) override; 51 52 int get_rank() override; 53 int get_world_size() override; 54 55 private: 56 std::vector<HandleType> handles_; 57 size_t block_size_; 58 std::vector<void*> buffers_; 59 std::vector<void*> signal_pads_; 60 HandleType mc_handle_; 61 void* mc_addr_; 62 size_t buffer_size_; 63 int local_device_idx_; 64 int rank_; 65 int world_size_; 66 void** buffers_dev_; 67 void** signal_pads_dev_; 68 std::optional<std::function<void(void)>> finalizer_; 69 }; 70 71 struct Block : public c10::intrusive_ptr_target { 72 HandleType handle; 73 int device_idx; 74 size_t block_size; 75 size_t buffer_size; 76 size_t signal_pad_offset; 77 std::string group_name; 78 c10::intrusive_ptr<CUDASymmetricMemory> symm_mem = nullptr; 79 Blockc10d::symmetric_memory::Block80 Block( 81 HandleType handle, 82 int device_idx, 83 size_t block_size, 84 size_t buffer_size, 85 size_t signal_pad_offset, 86 const std::string& group_name) 87 : handle(handle), 88 device_idx(device_idx), 89 block_size(block_size), 90 buffer_size(buffer_size), 91 signal_pad_offset(signal_pad_offset), 92 group_name(group_name), 93 symm_mem(nullptr) {} 94 }; 95 96 class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { 97 public: 98 void* alloc(size_t size, int device_idx, const std::string& group_name) 99 override; 100 101 void free(void* ptr) override; 102 size_t get_alloc_size(void* ptr) override; 103 c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override; 104 bool is_rendezvous_completed(void* ptr) override; 105 bool has_multicast_support(int device_idx) override; 106 107 private: 108 c10::intrusive_ptr<Block> find_block(void* ptr); 109 110 std::shared_mutex mutex_; 111 std::unordered_map<void*, c10::intrusive_ptr<Block>> ptr_to_block_; 112 }; 113 114 } // namespace symmetric_memory 115 } // namespace c10d 116