xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/intra_node_comm.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
2 
3 #include <ATen/cuda/CUDAContext.h>
4 #include <c10/cuda/CUDAGuard.h>
5 #include <c10/util/Logging.h>
6 #include <torch/csrc/distributed/c10d/Utils.hpp>
7 
8 #include <iostream>
9 #include <utility>
10 
11 #include <fcntl.h>
12 #include <pthread.h>
13 #include <semaphore.h>
14 #include <sys/mman.h>
15 #include <sys/stat.h>
16 #include <unistd.h>
17 
18 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
19 #include <c10/cuda/driver_api.h>
20 #include <nvml.h>
21 #endif
22 
23 #include <cuda_runtime.h>
24 
25 namespace c10d::intra_node_comm {
26 
// Env var consulted by isEnabled(); IntraNodeComm is opt-in and defaults to
// disabled (see getCvarBool(ENABLE_INTRA_NODE_COMM, false) below).
static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
    "ENABLE_INTRA_NODE_COMM"};
// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
// IntraNodeComm can be used even without NVLink connection. This is only used
// for testing purposes.
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};

// Monotonically increasing counter used in rendezvous() to derive a unique
// symmetric-memory group name per rendezvous.
static int intraNodeCommIdx = 0;
35 
////////////////////////////////////////////////////////////////////////////////
// CUDA Functions
////////////////////////////////////////////////////////////////////////////////

// Forward declarations of functions defined in a separate (CUDA) translation
// unit — presumably intra_node_comm.cu; confirm against the build files.

// Whether the current device/build supports intra-node comm at all.
bool isIntraNodeCommSupported();

// Returns the hybrid-cube-mesh structure for nvlMesh, or std::nullopt if the
// mesh does not form one. Used by detectTopology() below.
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);

void* initP2pState();

void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 // Topology Detection
50 ////////////////////////////////////////////////////////////////////////////////
51 
operator <<(std::ostream & os,const NvlMesh & nvlMesh)52 static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
53   std::ostringstream oss;
54   for (size_t i = 0; i < kMaxDevices; ++i) {
55     for (size_t j = 0; j < kMaxDevices; ++j) {
56       oss << nvlMesh[i][j] << " ";
57     }
58     oss << '\n';
59   }
60   os << oss.str();
61   return os;
62 }
63 
isSame(NvlMesh lhs,NvlMesh rhs)64 static bool isSame(NvlMesh lhs, NvlMesh rhs) {
65   for (size_t i = 0; i < kMaxDevices; ++i) {
66     for (size_t j = 0; j < kMaxDevices; ++j) {
67       if (lhs[i][j] != rhs[i][j]) {
68         return false;
69       }
70     }
71   }
72   return true;
73 }
74 
/**
 * Query the nvlink connection among devices.
 *
 * rankToBusId maps rank -> PCI bus ID of that rank's device. The returned
 * NvlMesh holds at [i][j] the number of NVLink lanes from rank i's device to
 * rank j's device (direct GPU links plus assumed NVSwitch-fabric links).
 * Returns an all-zero mesh when the driver API is unavailable (or compiled
 * out), and an empty mesh on an unexpected NVML error.
 */
static NvlMesh getNvlMesh(const std::vector<std::string>& rankToBusId) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  using namespace c10::cuda;

  NvlMesh nvlMesh = {};
  auto driverApi = DriverAPI::get();
  if (driverApi == nullptr) {
    // NVML not usable at runtime; report no connectivity.
    return nvlMesh;
  }

  const auto worldSize = rankToBusId.size();
  std::vector<nvmlDevice_t> devices(worldSize, nullptr);
  std::unordered_map<std::string, size_t> busIdToRank;
  // Per-rank count of NVLink lanes that terminate at an NVSwitch.
  std::vector<size_t> switchLinkCount(worldSize, 0);

  // Resolve each rank's bus ID to an NVML device handle.
  for (size_t r = 0; r < worldSize; ++r) {
    busIdToRank.emplace(rankToBusId[r], r);
    TORCH_CHECK(
        driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
            rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
  }

  // TODO: find a better way to determine this
  constexpr size_t kMaxNvLinks = 20;

  // For each device, loop over devices connected to it via NVLink
  for (size_t idx = 0; idx < worldSize; ++idx) {
    for (size_t link = 0; link < kMaxNvLinks; ++link) {
      nvmlReturn_t ret;
      nvmlIntNvLinkDeviceType_t deviceType;
      ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
          devices[idx], link, &deviceType);
      if (ret != NVML_SUCCESS) {
        // We've exhausted the NVLinks connected to this device.
        // This error is benign. There doesn't seem to be a reliable
        // way to obtain the maximum link value that can be passed to
        // the API, so we simply increment the link value until the
        // API fails or we hit a predefined maximum value.
        break;
      }
      // Remote device is GPU
      if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
        nvmlPciInfo_t pciInfo;
        ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
            devices[idx], link, &pciInfo);
        if (ret != NVML_SUCCESS) {
          // Unexpected error. Return an empty NvlMesh
          return {};
        }
        // Only count links whose remote end belongs to a participating rank,
        // and skip self-links.
        auto it = busIdToRank.find(pciInfo.busId);
        if (it != busIdToRank.end()) {
          if (idx != it->second) {
            nvlMesh[idx][it->second] += 1;
          }
        }
        // Remote device is NVSwitch
      } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
        switchLinkCount[idx] += 1;
      }
    }
  }
  // Process NVSwitch connections. For simplicity, we assume
  // all NVSwitches are interconnected.
  for (size_t i = 0; i < worldSize; ++i) {
    for (size_t j = 0; j < worldSize; ++j) {
      if (i == j) {
        continue;
      }
      // Through-switch bandwidth is bounded by the endpoint with fewer
      // switch links, hence std::min.
      nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
    }
  }
  return nvlMesh;
#else
  return {};
#endif
}
154 
155 /**
156  * Determine if the devices form a hybrid cube mesh
157  * topology given a NvlMesh.
158  */
isHybridCubeMesh(const NvlMesh nvlMesh)159 static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
160   std::array<size_t, kMaxDevices> numNeighbors = {};
161   for (size_t i = 0; i < kMaxDevices; ++i) {
162     for (size_t j = 0; j < kMaxDevices; ++j) {
163       if (nvlMesh[i][j] > 0) {
164         numNeighbors[i] += 1;
165       }
166     }
167   }
168   for (size_t i = 0; i < kMaxDevices; ++i) {
169     // TODO: this is insufficent and needs revisit
170     if (numNeighbors[i] != 4) {
171       return false;
172     }
173   }
174   return true;
175 }
176 
177 /**
178  * Detech topology given a NvlMesh.
179  */
detectTopology(const NvlMesh nvlMesh,size_t worldSize)180 static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
181   if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
182     return Topology::FULLY_CONNECTED;
183   }
184   bool fullyConnected = true;
185   for (size_t i = 0; i < worldSize - 1; ++i) {
186     for (size_t j = i + 1; j < worldSize; ++j) {
187       if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
188         fullyConnected = false;
189       }
190     }
191   }
192   if (fullyConnected) {
193     LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
194     return Topology::FULLY_CONNECTED;
195   }
196   if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
197     LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
198     return Topology::HYBRID_CUBE_MESH;
199   }
200   LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
201   return Topology::UNKNOWN;
202 };
203 
204 ////////////////////////////////////////////////////////////////////////////////
205 // Rendezvous and Initialization
206 ////////////////////////////////////////////////////////////////////////////////
207 
IntraNodeComm(c10::intrusive_ptr<c10d::Store> store,size_t rank,size_t worldSize,std::optional<size_t> bufferSize)208 IntraNodeComm::IntraNodeComm(
209     c10::intrusive_ptr<c10d::Store> store,
210     size_t rank,
211     size_t worldSize,
212     std::optional<size_t> bufferSize)
213     : store_(std::move(store)),
214       rank_(rank),
215       worldSize_(worldSize),
216       bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize),
217       barrierReady_(at::cuda::CUDAEvent()) {}
218 
~IntraNodeComm()219 IntraNodeComm::~IntraNodeComm() {
220   if (!isInitialized_) {
221     return;
222   }
223   auto allocator = get_allocator(c10::DeviceType::CUDA);
224   allocator->free(symmetricMemoryPtr_);
225 }
226 
// Whether intra-node comm is opted in via the ENABLE_INTRA_NODE_COMM env
// var; disabled by default.
bool IntraNodeComm::isEnabled() {
  return getCvarBool(ENABLE_INTRA_NODE_COMM, false);
}
230 
231 /**
232  * Use c10d::Store to perform allgather on a trivially copyable type.
233  */
234 template <typename T>
storeAllGather(const c10::intrusive_ptr<c10d::Store> & store,const std::string & prefix,size_t rank,size_t worldSize,T val)235 std::vector<T> storeAllGather(
236     const c10::intrusive_ptr<c10d::Store>& store,
237     const std::string& prefix,
238     size_t rank,
239     size_t worldSize,
240     T val) {
241   static_assert(std::is_trivially_copyable_v<T>);
242 
243   std::vector<std::string> peerKeys;
244   for (size_t r = 0; r < worldSize; ++r) {
245     std::ostringstream oss;
246     oss << prefix << "-" << r;
247     peerKeys.push_back(oss.str());
248   }
249 
250   {
251     std::vector<uint8_t> payload(
252         reinterpret_cast<uint8_t*>(&val),
253         reinterpret_cast<uint8_t*>(&val) + sizeof(T));
254     store->set(peerKeys[rank], payload);
255   }
256 
257   std::vector<T> peerVals;
258   for (size_t r = 0; r < worldSize; ++r) {
259     if (r == rank) {
260       peerVals.push_back(val);
261       continue;
262     }
263     store->wait({peerKeys[r]});
264     auto payload = store->get(peerKeys[r]);
265     TORCH_CHECK(payload.size() == sizeof(T));
266     T peerVal{};
267     std::memcpy(&peerVal, payload.data(), sizeof(T));
268     peerVals.push_back(peerVal);
269   }
270   return peerVals;
271 }
272 
// Performs the full intra-node rendezvous: hostname/bus-ID handshake via the
// store, NVLink topology detection, and symmetric-memory allocation.
// Returns true on success (idempotently true on repeat calls), false when
// intra-node comm is unsupported, the world size is out of range, the ranks
// span multiple hosts, or the build lacks driver-API support.
bool IntraNodeComm::rendezvous() {
  // Idempotent: a prior successful rendezvous short-circuits.
  if (isInitialized_) {
    return true;
  }
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  // Requires device-side support and 2..kMaxDevices participants.
  if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
      worldSize_ > kMaxDevices) {
    return false;
  }

  // Bind the rest of the setup to the caller's current CUDA device.
  deviceIdx_ = at::cuda::current_device();
  c10::cuda::CUDAGuard guard(deviceIdx_);

  // First handshake: exchange hostname and device bus ID
  struct DevInfo {
    char hostname[HOST_NAME_MAX + 1];
    char busId[80];
  };

  DevInfo devInfo{};
  gethostname(devInfo.hostname, sizeof(devInfo.hostname));
  cudaDeviceProp prop{};
  AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx_));
  // Format the PCI bus ID in NVML's canonical format so getNvlMesh() can
  // look the device up via NVML.
  snprintf(
      devInfo.busId,
      sizeof(devInfo.busId),
      NVML_DEVICE_PCI_BUS_ID_FMT,
      prop.pciDomainID,
      prop.pciBusID,
      prop.pciDeviceID);

  auto peerDevInfos =
      storeAllGather(store_, "handshake-0", rank_, worldSize_, devInfo);

  // All participants must be on the same host; bail out (best-effort, with a
  // warning) if any hostname differs from rank 0's.
  std::vector<std::string> rankToBusId;
  for (const auto& info : peerDevInfos) {
    if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
      LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
                      "participants are not on the same host ("
                   << info.hostname << ", " << devInfo.hostname << ")";
      return false;
    }
    rankToBusId.emplace_back(info.busId);
  }

  // Verify unique devices
  {
    std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
    TORCH_CHECK(
        uniqueBusIds.size() == worldSize_,
        "IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
        "Please properly set device via torch.cuda.set_device() before "
        "initiating rendezvous.");
  }

  // Query nvlink connection
  auto nvlMesh = getNvlMesh(rankToBusId);

  // Detect topology
  Topology topology = detectTopology(nvlMesh, worldSize_);

  // Allocate and rendezvous the symmetric-memory buffer under a group name
  // made unique by the global intraNodeCommIdx counter.
  auto groupName = "IntraNodeComm" + std::to_string(intraNodeCommIdx++);
  set_group_info(groupName, rank_, worldSize_, store_);
  auto allocator = get_allocator(c10::DeviceType::CUDA);
  symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName);
  symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_);
  // The signal pad is reused as p2p state storage; it must be large enough.
  TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize);

  void* topoInfo = initTopoInfo(topology, nvlMesh, rank_);

  // Publish member state only after every step above succeeded.
  isInitialized_ = true;
  topology_ = topology;
  p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev();
  buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev();
  topoInfo_ = topoInfo;
  return true;
#endif
  // Reached only when driver-API support is compiled out (or on ROCm).
  return false;
}
352 
353 } // namespace c10d::intra_node_comm
354