xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
2 #include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
3 
4 #include <c10/cuda/CUDAException.h>
5 #include <c10/cuda/driver_api.h>
6 
7 #include <cuda_runtime.h>
8 #include <nvml.h>
9 
10 namespace {
11 
12 constexpr int max_nvlinks = 64;
13 
get_bus_id(int device_idx)14 std::string get_bus_id(int device_idx) {
15   char bus_id[80];
16   cudaDeviceProp prop{};
17   C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx));
18   snprintf(
19       bus_id,
20       sizeof(bus_id),
21       NVML_DEVICE_PCI_BUS_ID_FMT,
22       prop.pciDomainID,
23       prop.pciBusID,
24       prop.pciDeviceID);
25   return std::string(bus_id);
26 }
27 
28 struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector {
detect__anonde6897900111::NVLinkDetector29   c10::intrusive_ptr<c10d::DMAConnectivity> detect() override {
30     int num_devices;
31     C10_CUDA_CHECK(cudaGetDeviceCount(&num_devices));
32 
33     std::vector<std::vector<int>> matrix;
34     matrix.reserve(num_devices);
35     for (int i = 0; i < num_devices; ++i) {
36       matrix.emplace_back(num_devices, 0);
37     }
38 
39     // Obtain the bus_id for all visible devices
40     std::unordered_map<std::string, int> bus_id_to_device_idx;
41     std::vector<std::string> bus_ids;
42     bus_ids.reserve(num_devices);
43     for (int i = 0; i < num_devices; ++i) {
44       auto bus_id = get_bus_id(i);
45       bus_id_to_device_idx.emplace(bus_id, i);
46       bus_ids.push_back(std::move(bus_id));
47     }
48 
49     // Obtain the nvml device for all bus_ids
50     auto driver_api = c10::cuda::DriverAPI::get();
51     std::vector<nvmlDevice_t> nvml_devices(num_devices, nullptr);
52     for (int i = 0; i < num_devices; ++i) {
53       TORCH_CHECK_EQ(
54           driver_api->nvmlDeviceGetHandleByPciBusId_v2_(
55               bus_ids[i].c_str(), &nvml_devices[i]),
56           NVML_SUCCESS);
57     }
58 
59     std::vector<int> switch_link_count(num_devices, 0);
60     for (int i = 0; i < num_devices; ++i) {
61       for (int link = 0; link < max_nvlinks; ++link) {
62         nvmlReturn_t ret;
63         nvmlIntNvLinkDeviceType_t deviceType;
64         ret = driver_api->nvmlDeviceGetNvLinkRemoteDeviceType_(
65             nvml_devices[i], link, &deviceType);
66         if (ret != NVML_SUCCESS) {
67           // We've exhausted the NVLinks connected to this device. This error
68           // is benign. There doesn't seem to be a reliable way to obtain the
69           // maximum link value that can be passed to the API. Therefore, we
70           // simply increment the link value until the API fails or we reach a
71           // predefined maximum value.
72           break;
73         }
74         // Remote device is GPU
75         if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
76           nvmlPciInfo_t pciInfo;
77           TORCH_CHECK_EQ(
78               driver_api->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
79                   nvml_devices[i], link, &pciInfo),
80               NVML_SUCCESS);
81           auto it = bus_id_to_device_idx.find(pciInfo.busId);
82           if (it != bus_id_to_device_idx.end()) {
83             if (i != it->second) {
84               matrix[i][it->second] += 1;
85             }
86           }
87           // Remote device is NVSwitch
88         } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
89           switch_link_count[i] += 1;
90         }
91       }
92     }
93 
94     // Process NVSwitch connections.
95     // For simplicity, we assume that all NVSwitches are interconnected.
96     for (int i = 0; i < num_devices; ++i) {
97       for (int j = 0; j < num_devices; ++j) {
98         if (i == j) {
99           continue;
100         }
101         matrix[i][j] += std::min(switch_link_count[i], switch_link_count[j]);
102       }
103     }
104 
105     return c10::make_intrusive<c10d::DMAConnectivity>(
106         c10::DeviceType::CUDA, "nvlink", std::move(matrix));
107   }
108 };
109 
110 struct RegisterDetector {
RegisterDetector__anonde6897900111::RegisterDetector111   RegisterDetector() {
112     register_dma_connectivity_detector(
113         c10::DeviceType::CUDA, "nvlink", c10::make_intrusive<NVLinkDetector>());
114   }
115 };
116 
117 static RegisterDetector register_detector_;
118 
119 } // namespace
120 #endif
121