#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

#include <numeric>

namespace {

using namespace c10d::symmetric_memory;

static bool is_finalizing_ = false;

class AllocatorMap {
 public:
  static AllocatorMap& get() {
    static AllocatorMap instance;
    return instance;
  }

  void register_allocator(
      c10::DeviceType device_type,
      c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
    map_[device_type] = std::move(allocator);
  }

  c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
      c10::DeviceType device_type) {
    auto it = map_.find(device_type);
    TORCH_CHECK(
        it != map_.end(),
        "SymmetricMemory does not support device type ",
        device_type);
    return it->second;
  }

  ~AllocatorMap() {
    is_finalizing_ = true;
  }

 private:
  AllocatorMap() = default;
  AllocatorMap(const AllocatorMap&) = delete;
  AllocatorMap& operator=(const AllocatorMap&) = delete;

  std::unordered_map<
      c10::DeviceType,
      c10::intrusive_ptr<SymmetricMemoryAllocator>>
      map_;
};
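
// The map above is a process-wide registry keyed by device type. A backend
// makes its SymmetricMemoryAllocator available by calling register_allocator()
// (exposed below in c10d::symmetric_memory), typically from a static
// initializer in its own translation unit. A minimal sketch of that pattern,
// where MyDeviceSymmetricMemoryAllocator is a hypothetical backend class and
// the CUDA device type is only an example:
//
//   struct RegisterMyDeviceAllocator {
//     RegisterMyDeviceAllocator() {
//       c10d::symmetric_memory::register_allocator(
//           c10::DeviceType::CUDA,
//           c10::make_intrusive<MyDeviceSymmetricMemoryAllocator>());
//     }
//   };
//   static RegisterMyDeviceAllocator register_my_device_allocator;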

static std::unordered_map<std::string, GroupInfo> group_info_map{};

// Data structures for tracking persistent allocations
static std::unordered_map<uint64_t, void*> alloc_id_to_dev_ptr{};
static std::unordered_map<uint64_t, c10::weak_intrusive_ptr<c10::StorageImpl>>
    alloc_id_to_storage{};

static at::Tensor empty_strided_p2p_persistent(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    uint64_t alloc_id) {
  // Fail the allocation if a previous allocation with the same alloc_id is
  // still active.
  auto storage = alloc_id_to_storage.find(alloc_id);
  if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) {
    TORCH_CHECK(
        false,
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "cannot allocate with alloc_id == ",
        alloc_id,
        " because a previous allocation with the same alloc_id "
        "is still active.");
  }

  const size_t numel = std::accumulate(
      size.begin(),
      size.end(),
      static_cast<size_t>(1),
      std::multiplies<size_t>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = nullptr;
  if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) {
    dev_ptr = alloc_id_to_dev_ptr[alloc_id];
    TORCH_CHECK(
        alloc_size == allocator->get_alloc_size(dev_ptr),
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "requested allocation size (",
        alloc_size,
        ") is different from the size of a previous allocation ",
        "with the same alloc_id ",
        allocator->get_alloc_size(dev_ptr));
  } else {
    dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);
    alloc_id_to_dev_ptr[alloc_id] = dev_ptr;
  }

  auto options = at::TensorOptions().dtype(dtype).device(device);
  auto allocated = at::from_blob(dev_ptr, size, stride, options);

  // Track the allocation's liveness via a weak reference to its storage so
  // that a later call with the same alloc_id can tell whether it is still in
  // use.
  alloc_id_to_storage.erase(alloc_id);
  alloc_id_to_storage.emplace(
      alloc_id, allocated.storage().getWeakStorageImpl());
  return allocated;
}
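
// Note on the persistent path above: the device buffer bound to an alloc_id is
// not freed here; the weak storage reference alone decides whether the id may
// be reused. A rough illustration of the intended call pattern (the shape and
// alloc_id below are made up for the example; the public entry point is
// empty_strided_p2p() with an explicit alloc_id):
//
//   // First call for alloc_id 42 allocates a fresh symmetric buffer.
//   auto t0 = empty_strided_p2p_persistent(
//       {1024}, {1}, at::kFloat, device, group_name, /*alloc_id=*/42);
//   // While t0's storage is alive, another call with alloc_id 42 fails.
//   t0 = at::Tensor(); // drop the storage
//   // Now alloc_id 42 may be reused; the same device pointer is handed back,
//   // provided the requested allocation size matches the original one.
//   auto t1 = empty_strided_p2p_persistent(
//       {1024}, {1}, at::kFloat, device, group_name, /*alloc_id=*/42);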

} // namespace

namespace c10d {
namespace symmetric_memory {

bool is_finalizing() {
  return is_finalizing_;
}

void register_allocator(
    c10::DeviceType device_type,
    c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
  return AllocatorMap::get().register_allocator(
      device_type, std::move(allocator));
}

c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
    c10::DeviceType device_type) {
  return AllocatorMap::get().get_allocator(device_type);
}

void set_group_info(
    const std::string& group_name,
    int rank,
    int world_size,
    c10::intrusive_ptr<Store> store) {
  TORCH_CHECK(
      group_info_map.find(group_name) == group_info_map.end(),
      "set_group_info: group info already exists for group name ",
      group_name);
  GroupInfo group_info;
  group_info.rank = rank;
  group_info.world_size = world_size;
  group_info.store = std::move(store);
  group_info_map.emplace(group_name, std::move(group_info));
}

const GroupInfo& get_group_info(const std::string& group_name) {
  auto it = group_info_map.find(group_name);
  TORCH_CHECK(
      it != group_info_map.end(),
      "get_group_info: no group info associated with the group name ",
      group_name);
  return it->second;
}

at::Tensor empty_strided_p2p(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    std::optional<uint64_t> alloc_id) {
  if (alloc_id.has_value()) {
    return empty_strided_p2p_persistent(
        size, stride, dtype, device, group_name, *alloc_id);
  }
  const size_t numel = std::accumulate(
      size.begin(),
      size.end(),
      static_cast<size_t>(1),
      std::multiplies<size_t>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);

  auto options = at::TensorOptions().dtype(dtype).device(device);
  return at::from_blob(
      dev_ptr,
      size,
      stride,
      [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); },
      options);
}

TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  return allocator->rendezvous(tensor.storage().data_ptr().get());
}

c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  TORCH_CHECK(
      allocator->is_rendezvous_completed(tensor.data_ptr()),
      "SymmetricMemory: must invoke rendezvous on a tensor ",
      "before calling get_symmetric_memory on it");
  return allocator->rendezvous(tensor.data_ptr());
}
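
// A rough sketch of how these entry points are meant to be combined on each
// rank of a group (hedged illustration; "my_group", the store, and the tensor
// shape are placeholders rather than values defined in this file):
//
//   // 1. Seed the group metadata once per group name.
//   set_group_info("my_group", rank, world_size, store);
//   // 2. Allocate a tensor backed by symmetric memory on every rank.
//   auto t = empty_strided_p2p(
//       {1024}, {1}, at::kFloat, device, "my_group", std::nullopt);
//   // 3. Establish the peer mapping; this is collective across the group.
//   auto symm_mem = rendezvous(t);
//   // 4. Once rendezvous has completed, get_symmetric_memory(t) can be used
//   //    to retrieve the handle again.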

TORCH_API bool has_multicast_support(
    c10::DeviceType device_type,
    int device_idx) {
  auto allocator = get_allocator(device_type);
  return allocator->has_multicast_support(device_idx);
}

} // namespace symmetric_memory
} // namespace c10d