1 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
2
3 #include <ATen/cuda/CUDAContext.h>
4 #include <c10/cuda/CUDAGuard.h>
5 #include <c10/util/Logging.h>
6 #include <torch/csrc/distributed/c10d/Utils.hpp>
7
8 #include <iostream>
9 #include <utility>
10
11 #include <fcntl.h>
12 #include <pthread.h>
13 #include <semaphore.h>
14 #include <sys/mman.h>
15 #include <sys/stat.h>
16 #include <unistd.h>
17
18 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
19 #include <c10/cuda/driver_api.h>
20 #include <nvml.h>
21 #endif
22
23 #include <cuda_runtime.h>
24
25 namespace c10d::intra_node_comm {
26
// Env var (read via getCvarBool) that opts into intra-node communication.
// Disabled by default.
static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
    "ENABLE_INTRA_NODE_COMM"};
// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
// IntraNodeComm can be used even without NVLink connection. This is only used
// for testing purposes.
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};

// Monotonically increasing counter used to derive a unique symmetric-memory
// group name per IntraNodeComm instance (see rendezvous()).
static int intraNodeCommIdx = 0;
35
36 ////////////////////////////////////////////////////////////////////////////////
37 // CUDA Functions
38 ////////////////////////////////////////////////////////////////////////////////
39
// Declarations only — presumably defined in the companion CUDA source file
// (not visible in this chunk).

// Whether the intra-node comm implementation can run on this build/device.
bool isIntraNodeCommSupported();

// Derives hybrid-cube-mesh info from nvlMesh; std::nullopt when the mesh does
// not form a hybrid cube mesh (see usage in detectTopology()).
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);

// Returns a pointer to initialized P2P state (opaque here; see definition).
void* initP2pState();

// Returns a pointer to initialized topology info for the given topology,
// NVLink mesh, and rank (opaque here; see definition).
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
47
48 ////////////////////////////////////////////////////////////////////////////////
49 // Topology Detection
50 ////////////////////////////////////////////////////////////////////////////////
51
operator <<(std::ostream & os,const NvlMesh & nvlMesh)52 static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
53 std::ostringstream oss;
54 for (size_t i = 0; i < kMaxDevices; ++i) {
55 for (size_t j = 0; j < kMaxDevices; ++j) {
56 oss << nvlMesh[i][j] << " ";
57 }
58 oss << '\n';
59 }
60 os << oss.str();
61 return os;
62 }
63
isSame(NvlMesh lhs,NvlMesh rhs)64 static bool isSame(NvlMesh lhs, NvlMesh rhs) {
65 for (size_t i = 0; i < kMaxDevices; ++i) {
66 for (size_t j = 0; j < kMaxDevices; ++j) {
67 if (lhs[i][j] != rhs[i][j]) {
68 return false;
69 }
70 }
71 }
72 return true;
73 }
74
/**
 * Query the nvlink connection among devices.
 *
 * Returns an NvlMesh where entry [i][j] counts the NVLink lanes from the
 * device of rank i to the device of rank j (diagonal stays 0). Returns an
 * all-zero mesh when the driver API is unavailable, and an empty mesh on
 * unexpected NVML errors.
 */
static NvlMesh getNvlMesh(const std::vector<std::string>& rankToBusId) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  using namespace c10::cuda;

  NvlMesh nvlMesh = {};
  auto driverApi = DriverAPI::get();
  if (driverApi == nullptr) {
    // Driver API not loadable at runtime — report no connectivity.
    return nvlMesh;
  }

  const auto worldSize = rankToBusId.size();
  std::vector<nvmlDevice_t> devices(worldSize, nullptr);
  std::unordered_map<std::string, size_t> busIdToRank;
  // Per-device count of NVLink lanes that terminate at an NVSwitch.
  std::vector<size_t> switchLinkCount(worldSize, 0);

  // Resolve each rank's NVML device handle from its PCI bus ID.
  for (size_t r = 0; r < worldSize; ++r) {
    busIdToRank.emplace(rankToBusId[r], r);
    TORCH_CHECK(
        driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
            rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
  }

  // TODO: find a better way to determine this
  constexpr size_t kMaxNvLinks = 20;

  // For each device, loop over devices connected to it via NVLink
  for (size_t idx = 0; idx < worldSize; ++idx) {
    for (size_t link = 0; link < kMaxNvLinks; ++link) {
      nvmlReturn_t ret;
      nvmlIntNvLinkDeviceType_t deviceType;
      ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
          devices[idx], link, &deviceType);
      if (ret != NVML_SUCCESS) {
        // We've exhausted the NVLinks connected to this device.
        // This error is benign. There doesn't seem to be a reliable
        // way to obtain the maximum link value that can be passed to
        // the API, so we simply increment the link value until the
        // API fails or we hit a predefined maximum value.
        break;
      }
      // Remote device is GPU
      if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
        nvmlPciInfo_t pciInfo;
        ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
            devices[idx], link, &pciInfo);
        if (ret != NVML_SUCCESS) {
          // Unexpected error. Return an empty NvlMesh
          return {};
        }
        // Only count links whose remote end belongs to a participating rank.
        auto it = busIdToRank.find(pciInfo.busId);
        if (it != busIdToRank.end()) {
          if (idx != it->second) {
            nvlMesh[idx][it->second] += 1;
          }
        }
        // Remote device is NVSwitch
      } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
        switchLinkCount[idx] += 1;
      }
    }
  }
  // Process NVSwitch connections. For simplicity, we assume
  // all NVSwitches are interconnected.
  for (size_t i = 0; i < worldSize; ++i) {
    for (size_t j = 0; j < worldSize; ++j) {
      if (i == j) {
        continue;
      }
      // Effective switch-routed connectivity between i and j is bounded by
      // whichever side has fewer links into the switch fabric.
      nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
    }
  }
  return nvlMesh;
#else
  return {};
#endif
}
154
155 /**
156 * Determine if the devices form a hybrid cube mesh
157 * topology given a NvlMesh.
158 */
isHybridCubeMesh(const NvlMesh nvlMesh)159 static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
160 std::array<size_t, kMaxDevices> numNeighbors = {};
161 for (size_t i = 0; i < kMaxDevices; ++i) {
162 for (size_t j = 0; j < kMaxDevices; ++j) {
163 if (nvlMesh[i][j] > 0) {
164 numNeighbors[i] += 1;
165 }
166 }
167 }
168 for (size_t i = 0; i < kMaxDevices; ++i) {
169 // TODO: this is insufficent and needs revisit
170 if (numNeighbors[i] != 4) {
171 return false;
172 }
173 }
174 return true;
175 }
176
177 /**
178 * Detech topology given a NvlMesh.
179 */
detectTopology(const NvlMesh nvlMesh,size_t worldSize)180 static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
181 if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
182 return Topology::FULLY_CONNECTED;
183 }
184 bool fullyConnected = true;
185 for (size_t i = 0; i < worldSize - 1; ++i) {
186 for (size_t j = i + 1; j < worldSize; ++j) {
187 if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
188 fullyConnected = false;
189 }
190 }
191 }
192 if (fullyConnected) {
193 LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
194 return Topology::FULLY_CONNECTED;
195 }
196 if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
197 LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
198 return Topology::HYBRID_CUBE_MESH;
199 }
200 LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
201 return Topology::UNKNOWN;
202 };
203
204 ////////////////////////////////////////////////////////////////////////////////
205 // Rendezvous and Initialization
206 ////////////////////////////////////////////////////////////////////////////////
207
// Construction only records configuration; allocation and the peer handshake
// are deferred to rendezvous().
IntraNodeComm::IntraNodeComm(
    c10::intrusive_ptr<c10d::Store> store,
    size_t rank,
    size_t worldSize,
    std::optional<size_t> bufferSize)
    : store_(std::move(store)),
      rank_(rank),
      worldSize_(worldSize),
      // Fall back to the compiled-in default when no buffer size is given.
      bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize),
      barrierReady_(at::cuda::CUDAEvent()) {}
218
IntraNodeComm::~IntraNodeComm() {
  // Nothing was allocated unless rendezvous() completed successfully.
  if (!isInitialized_) {
    return;
  }
  // Release the symmetric-memory buffer allocated in rendezvous().
  auto allocator = get_allocator(c10::DeviceType::CUDA);
  allocator->free(symmetricMemoryPtr_);
}
226
isEnabled()227 bool IntraNodeComm::isEnabled() {
228 return getCvarBool(ENABLE_INTRA_NODE_COMM, false);
229 }
230
231 /**
232 * Use c10d::Store to perform allgather on a trivially copyable type.
233 */
234 template <typename T>
storeAllGather(const c10::intrusive_ptr<c10d::Store> & store,const std::string & prefix,size_t rank,size_t worldSize,T val)235 std::vector<T> storeAllGather(
236 const c10::intrusive_ptr<c10d::Store>& store,
237 const std::string& prefix,
238 size_t rank,
239 size_t worldSize,
240 T val) {
241 static_assert(std::is_trivially_copyable_v<T>);
242
243 std::vector<std::string> peerKeys;
244 for (size_t r = 0; r < worldSize; ++r) {
245 std::ostringstream oss;
246 oss << prefix << "-" << r;
247 peerKeys.push_back(oss.str());
248 }
249
250 {
251 std::vector<uint8_t> payload(
252 reinterpret_cast<uint8_t*>(&val),
253 reinterpret_cast<uint8_t*>(&val) + sizeof(T));
254 store->set(peerKeys[rank], payload);
255 }
256
257 std::vector<T> peerVals;
258 for (size_t r = 0; r < worldSize; ++r) {
259 if (r == rank) {
260 peerVals.push_back(val);
261 continue;
262 }
263 store->wait({peerKeys[r]});
264 auto payload = store->get(peerKeys[r]);
265 TORCH_CHECK(payload.size() == sizeof(T));
266 T peerVal{};
267 std::memcpy(&peerVal, payload.data(), sizeof(T));
268 peerVals.push_back(peerVal);
269 }
270 return peerVals;
271 }
272
// Performs the multi-rank handshake and allocates the shared communication
// buffer. Returns true once initialized (idempotent); returns false when
// intra-node comm cannot be used for this configuration.
bool IntraNodeComm::rendezvous() {
  if (isInitialized_) {
    return true;
  }
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  // Bail out (without error) when the feature can't apply: unsupported
  // device/driver, trivial world size, or more ranks than kMaxDevices.
  if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
      worldSize_ > kMaxDevices) {
    return false;
  }

  deviceIdx_ = at::cuda::current_device();
  c10::cuda::CUDAGuard guard(deviceIdx_);

  // First hand shake: exchange hostname and device bus ID
  struct DevInfo {
    char hostname[HOST_NAME_MAX + 1];
    char busId[80];
  };

  DevInfo devInfo{};
  gethostname(devInfo.hostname, sizeof(devInfo.hostname));
  cudaDeviceProp prop{};
  AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx_));
  // Format the PCI bus ID the way NVML reports it, so it can be used as a
  // lookup key in getNvlMesh().
  snprintf(
      devInfo.busId,
      sizeof(devInfo.busId),
      NVML_DEVICE_PCI_BUS_ID_FMT,
      prop.pciDomainID,
      prop.pciBusID,
      prop.pciDeviceID);

  auto peerDevInfos =
      storeAllGather(store_, "handshake-0", rank_, worldSize_, devInfo);

  // All participants must be on the same host; otherwise abort gracefully.
  std::vector<std::string> rankToBusId;
  for (const auto& info : peerDevInfos) {
    if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
      LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
                      "participants are not on the same host ("
                   << info.hostname << ", " << devInfo.hostname << ")";
      return false;
    }
    rankToBusId.emplace_back(info.busId);
  }

  // Verify unique devices
  {
    std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
    TORCH_CHECK(
        uniqueBusIds.size() == worldSize_,
        "IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
        "Please properly set device via torch.cuda.set_device() before "
        "initiating rendezvous.");
  }

  // Query nvlink connection
  auto nvlMesh = getNvlMesh(rankToBusId);

  // Detect topology
  Topology topology = detectTopology(nvlMesh, worldSize_);

  // Allocate the shared buffer through the symmetric-memory allocator under a
  // process-unique group name, then rendezvous to exchange peer mappings.
  auto groupName = "IntraNodeComm" + std::to_string(intraNodeCommIdx++);
  set_group_info(groupName, rank_, worldSize_, store_);
  auto allocator = get_allocator(c10::DeviceType::CUDA);
  symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName);
  symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_);
  // The signal pad is reused as the p2p state (see p2pStatesDev_ below), so
  // it must be at least kP2pStateSize bytes.
  TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize);

  void* topoInfo = initTopoInfo(topology, nvlMesh, rank_);

  // Publish results only after every step above succeeded.
  isInitialized_ = true;
  topology_ = topology;
  p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev();
  buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev();
  topoInfo_ = topoInfo;
  return true;
#endif
  // Reached only when compiled without driver-API support (e.g. ROCm).
  return false;
}
352
353 } // namespace c10d::intra_node_comm
354