xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/distributed/c10d/Store.hpp>
5 #include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
6 
7 namespace c10d {
8 namespace symmetric_memory {
9 
// Handle type used by the declarations in this header.
//
// ROCm builds, and builds where PYTORCH_C10_DRIVER_API_SUPPORTED is not
// defined, use an untyped pointer. Only the remaining configuration (non-ROCm
// with the driver API available) uses the CUDA driver's
// CUmemGenericAllocationHandle.
#if defined(USE_ROCM) || !defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = void*;
#else
using HandleType = CUmemGenericAllocationHandle;
#endif
15 
16 class CUDASymmetricMemory : public SymmetricMemory {
17  public:
18   CUDASymmetricMemory(
19       std::vector<HandleType> handles,
20       size_t block_size,
21       std::vector<void*> buffers,
22       std::vector<void*> signal_pads,
23       HandleType mc_handle,
24       void* mc_addr,
25       size_t buffer_size,
26       int local_device_idx,
27       int rank,
28       int world_size);
29 
30   ~CUDASymmetricMemory() override;
31 
32   std::vector<void*> get_buffer_ptrs() override;
33   std::vector<void*> get_signal_pad_ptrs() override;
34   void** get_buffer_ptrs_dev() override;
35   void** get_signal_pad_ptrs_dev() override;
36   size_t get_buffer_size() override;
37   size_t get_signal_pad_size() override;
38 
39   bool has_multicast_support() override;
40   void* get_multicast_ptr() override;
41 
42   at::Tensor get_buffer(
43       int rank,
44       c10::IntArrayRef sizes,
45       c10::ScalarType dtype,
46       int64_t storage_offset) override;
47 
48   void barrier(int channel) override;
49   void put_signal(int dst_rank, int channel) override;
50   void wait_signal(int src_rank, int channel) override;
51 
52   int get_rank() override;
53   int get_world_size() override;
54 
55  private:
56   std::vector<HandleType> handles_;
57   size_t block_size_;
58   std::vector<void*> buffers_;
59   std::vector<void*> signal_pads_;
60   HandleType mc_handle_;
61   void* mc_addr_;
62   size_t buffer_size_;
63   int local_device_idx_;
64   int rank_;
65   int world_size_;
66   void** buffers_dev_;
67   void** signal_pads_dev_;
68   std::optional<std::function<void(void)>> finalizer_;
69 };
70 
71 struct Block : public c10::intrusive_ptr_target {
72   HandleType handle;
73   int device_idx;
74   size_t block_size;
75   size_t buffer_size;
76   size_t signal_pad_offset;
77   std::string group_name;
78   c10::intrusive_ptr<CUDASymmetricMemory> symm_mem = nullptr;
79 
Blockc10d::symmetric_memory::Block80   Block(
81       HandleType handle,
82       int device_idx,
83       size_t block_size,
84       size_t buffer_size,
85       size_t signal_pad_offset,
86       const std::string& group_name)
87       : handle(handle),
88         device_idx(device_idx),
89         block_size(block_size),
90         buffer_size(buffer_size),
91         signal_pad_offset(signal_pad_offset),
92         group_name(group_name),
93         symm_mem(nullptr) {}
94 };
95 
96 class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
97  public:
98   void* alloc(size_t size, int device_idx, const std::string& group_name)
99       override;
100 
101   void free(void* ptr) override;
102   size_t get_alloc_size(void* ptr) override;
103   c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
104   bool is_rendezvous_completed(void* ptr) override;
105   bool has_multicast_support(int device_idx) override;
106 
107  private:
108   c10::intrusive_ptr<Block> find_block(void* ptr);
109 
110   std::shared_mutex mutex_;
111   std::unordered_map<void*, c10::intrusive_ptr<Block>> ptr_to_block_;
112 };
113 
114 } // namespace symmetric_memory
115 } // namespace c10d
116