1 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
2 #include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
3
4 #include <c10/cuda/CUDAException.h>
5 #include <c10/cuda/driver_api.h>
6
7 #include <cuda_runtime.h>
8 #include <nvml.h>
9
10 namespace {
11
// Upper bound (exclusive) on the NVLink indices probed per device when
// enumerating links.
constexpr int max_nvlinks{64};
13
get_bus_id(int device_idx)14 std::string get_bus_id(int device_idx) {
15 char bus_id[80];
16 cudaDeviceProp prop{};
17 C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx));
18 snprintf(
19 bus_id,
20 sizeof(bus_id),
21 NVML_DEVICE_PCI_BUS_ID_FMT,
22 prop.pciDomainID,
23 prop.pciBusID,
24 prop.pciDeviceID);
25 return std::string(bus_id);
26 }
27
28 struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector {
detect__anonde6897900111::NVLinkDetector29 c10::intrusive_ptr<c10d::DMAConnectivity> detect() override {
30 int num_devices;
31 C10_CUDA_CHECK(cudaGetDeviceCount(&num_devices));
32
33 std::vector<std::vector<int>> matrix;
34 matrix.reserve(num_devices);
35 for (int i = 0; i < num_devices; ++i) {
36 matrix.emplace_back(num_devices, 0);
37 }
38
39 // Obtain the bus_id for all visible devices
40 std::unordered_map<std::string, int> bus_id_to_device_idx;
41 std::vector<std::string> bus_ids;
42 bus_ids.reserve(num_devices);
43 for (int i = 0; i < num_devices; ++i) {
44 auto bus_id = get_bus_id(i);
45 bus_id_to_device_idx.emplace(bus_id, i);
46 bus_ids.push_back(std::move(bus_id));
47 }
48
49 // Obtain the nvml device for all bus_ids
50 auto driver_api = c10::cuda::DriverAPI::get();
51 std::vector<nvmlDevice_t> nvml_devices(num_devices, nullptr);
52 for (int i = 0; i < num_devices; ++i) {
53 TORCH_CHECK_EQ(
54 driver_api->nvmlDeviceGetHandleByPciBusId_v2_(
55 bus_ids[i].c_str(), &nvml_devices[i]),
56 NVML_SUCCESS);
57 }
58
59 std::vector<int> switch_link_count(num_devices, 0);
60 for (int i = 0; i < num_devices; ++i) {
61 for (int link = 0; link < max_nvlinks; ++link) {
62 nvmlReturn_t ret;
63 nvmlIntNvLinkDeviceType_t deviceType;
64 ret = driver_api->nvmlDeviceGetNvLinkRemoteDeviceType_(
65 nvml_devices[i], link, &deviceType);
66 if (ret != NVML_SUCCESS) {
67 // We've exhausted the NVLinks connected to this device. This error
68 // is benign. There doesn't seem to be a reliable way to obtain the
69 // maximum link value that can be passed to the API. Therefore, we
70 // simply increment the link value until the API fails or we reach a
71 // predefined maximum value.
72 break;
73 }
74 // Remote device is GPU
75 if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
76 nvmlPciInfo_t pciInfo;
77 TORCH_CHECK_EQ(
78 driver_api->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
79 nvml_devices[i], link, &pciInfo),
80 NVML_SUCCESS);
81 auto it = bus_id_to_device_idx.find(pciInfo.busId);
82 if (it != bus_id_to_device_idx.end()) {
83 if (i != it->second) {
84 matrix[i][it->second] += 1;
85 }
86 }
87 // Remote device is NVSwitch
88 } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
89 switch_link_count[i] += 1;
90 }
91 }
92 }
93
94 // Process NVSwitch connections.
95 // For simplicity, we assume that all NVSwitches are interconnected.
96 for (int i = 0; i < num_devices; ++i) {
97 for (int j = 0; j < num_devices; ++j) {
98 if (i == j) {
99 continue;
100 }
101 matrix[i][j] += std::min(switch_link_count[i], switch_link_count[j]);
102 }
103 }
104
105 return c10::make_intrusive<c10d::DMAConnectivity>(
106 c10::DeviceType::CUDA, "nvlink", std::move(matrix));
107 }
108 };
109
110 struct RegisterDetector {
RegisterDetector__anonde6897900111::RegisterDetector111 RegisterDetector() {
112 register_dma_connectivity_detector(
113 c10::DeviceType::CUDA, "nvlink", c10::make_intrusive<NVLinkDetector>());
114 }
115 };
116
117 static RegisterDetector register_detector_;
118
119 } // namespace
120 #endif
121