xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/DMAConnectivity.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <optional>
4 
5 #include <ATen/ATen.h>
6 
7 namespace c10d {
8 
9 struct TORCH_API DMAConnectivity : c10::intrusive_ptr_target {
10   c10::DeviceType device_type;
11   std::string connection_type;
12 
13   // This is an NxN matrix representing the connectivity between N devices,
14   // where each element matrix[i][j] indicates the connectivity between device
15   // i and device j. A value of 0 denotes that there is no connection between
16   // device i and j. The meaning of non-zero values are specific to the
17   // connection type (e.g., for NVLink it represents the number of NVLinks).
18   std::vector<std::vector<int>> matrix;
19 
20   explicit DMAConnectivity(
21       c10::DeviceType device_type,
22       std::string connection_type,
23       std::vector<std::vector<int>> matrix);
24 };
25 
26 struct DMAConnectivityDetector : c10::intrusive_ptr_target {
27   virtual c10::intrusive_ptr<DMAConnectivity> detect() = 0;
~DMAConnectivityDetectorc10d::DMAConnectivityDetector28   virtual ~DMAConnectivityDetector() {}
29 };
30 
31 C10_EXPORT void register_dma_connectivity_detector(
32     c10::DeviceType device_type,
33     const std::string& connection_type,
34     c10::intrusive_ptr<DMAConnectivityDetector> detector);
35 
36 TORCH_API c10::intrusive_ptr<DMAConnectivity> detect_dma_connectivity(
37     c10::DeviceType device_type,
38     const std::string& connection_type);
39 
40 } // namespace c10d
41