xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp>
2 
3 #ifdef USE_C10D_GLOO
4 
5 #include <cstdlib>
6 
7 #include <c10/util/Exception.h>
8 
9 #if GLOO_HAVE_TRANSPORT_TCP
10 #include <gloo/transport/tcp/device.h>
11 #endif
12 
13 #if GLOO_HAVE_TRANSPORT_TCP_TLS
14 #include <gloo/transport/tcp/tls/device.h>
15 #endif
16 
17 #if GLOO_HAVE_TRANSPORT_UV
18 #include <gloo/transport/uv/device.h>
19 #endif
20 
21 // On Linux, check that the tcp transport is available.
22 #ifdef __linux__
23 #if !GLOO_HAVE_TRANSPORT_TCP
24 #error "Expected the tcp transport to be available on Linux."
25 #endif
26 #endif
27 
28 // On macOS, check that the uv transport is available.
29 #ifdef __APPLE__
30 #if !GLOO_HAVE_TRANSPORT_UV
31 #error "Expected the uv transport to be available on macOS."
32 #endif
33 #endif
34 
35 namespace c10d {
36 
// Defines the shared registry of Gloo transport device creators. Creators are
// registered under a key with C10_REGISTER_CREATOR and are invoked through
// `GlooDeviceRegistry()->Create(key, interface, hostname)`; both creator
// arguments are passed as `const std::string&`.
37 C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING(
38     GlooDeviceRegistry,
39     ::gloo::transport::Device,
40     const std::string& /* interface */,
41     const std::string& /* hostname */);
42 
43 #if GLOO_HAVE_TRANSPORT_TCP
makeTCPDevice(const std::string & interfaceName,const std::string & hostname)44 static std::shared_ptr<::gloo::transport::Device> makeTCPDevice(
45     const std::string& interfaceName,
46     const std::string& hostname) {
47   TORCH_CHECK(
48       !interfaceName.empty() || !hostname.empty(),
49       "GlooDeviceFactory::makeTCPDevice(): interface or hostname "
50       "can't be empty");
51 
52   ::gloo::transport::tcp::attr attr;
53   if (!interfaceName.empty()) {
54     attr.iface = interfaceName;
55   } else {
56     attr.hostname = hostname;
57   }
58   return ::gloo::transport::tcp::CreateDevice(attr);
59 }
60 
61 // Registry priority is per key identifier. We register TCP to `LINUX` for
62 // the flexibility of other application to override by priority. Register
63 // TCP to `TCP` for env "GLOO_DEVICE_TRANSPORT" override.
64 C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice);
65 C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice);
66 #endif
67 
68 #if GLOO_HAVE_TRANSPORT_TCP_TLS
// Converts a possibly-null C string into a std::string.
// A null pointer yields an empty string; anything else is copied as-is.
static std::string cstr_to_std_string(const char* chars) {
  if (chars == nullptr) {
    return std::string();
  }
  return std::string(chars);
}
72 
makeTCPTLSDevice(const std::string & interface,const std::string & hostname)73 static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice(
74     const std::string& interface,
75     const std::string& hostname) {
76   TORCH_CHECK(
77       !interface.empty() || !hostname.empty(),
78       "GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname "
79       "can't be empty");
80 
81   ::gloo::transport::tcp::attr attr;
82   if (!interface.empty()) {
83     attr.iface = interface;
84   } else {
85     attr.hostname = hostname;
86   }
87   const auto pkey =
88       cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY"));
89   const auto cert =
90       cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT"));
91   const auto caFile =
92       cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE"));
93   const auto caPath =
94       cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH"));
95   return ::gloo::transport::tcp::tls::CreateDevice(
96       attr, pkey, cert, caFile, caPath);
97 }
98 
99 C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice);
100 #endif
101 
102 #if GLOO_HAVE_TRANSPORT_UV
makeUVDevice(const std::string & interfaceName,const std::string & hostname)103 static std::shared_ptr<::gloo::transport::Device> makeUVDevice(
104     const std::string& interfaceName,
105     const std::string& hostname) {
106   TORCH_CHECK(
107       !interfaceName.empty() || !hostname.empty(),
108       "GlooDeviceFactory::makeUVDevice(): interface or hostname "
109       "can't be empty");
110 
111   ::gloo::transport::uv::attr attr;
112   if (!interfaceName.empty()) {
113     attr.iface = interfaceName;
114   } else {
115     attr.hostname = hostname;
116   }
117   return ::gloo::transport::uv::CreateDevice(attr);
118 }
119 
120 // Registry priority is per key identifier. We register UV to `APPLE` for
121 // the flexibility of other application to override by priority. Register
122 // UV to `UV` for env "GLOO_DEVICE_TRANSPORT" override.
123 C10_REGISTER_CREATOR(GlooDeviceRegistry, APPLE, makeUVDevice);
124 C10_REGISTER_CREATOR(GlooDeviceRegistry, WIN32, makeUVDevice);
125 C10_REGISTER_CREATOR(GlooDeviceRegistry, UV, makeUVDevice);
126 #endif
127 
128 namespace {
makeGlooDevice(const std::string & interfaceName,const std::string & hostName)129 std::shared_ptr<::gloo::transport::Device> makeGlooDevice(
130     const std::string& interfaceName,
131     const std::string& hostName) {
132   static auto transportName = getenv("GLOO_DEVICE_TRANSPORT");
133   if (transportName) {
134     return GlooDeviceRegistry()->Create(transportName, interfaceName, hostName);
135   }
136 
137 #ifdef __linux__
138   return GlooDeviceRegistry()->Create("LINUX", interfaceName, hostName);
139 #endif
140 
141 #ifdef __APPLE__
142   return GlooDeviceRegistry()->Create("APPLE", interfaceName, hostName);
143 #endif
144 
145 #ifdef _WIN32
146   return GlooDeviceRegistry()->Create("WIN32", interfaceName, hostName);
147 #endif
148 
149   return nullptr;
150 }
151 } // anonymous namespace
152 
153 std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory::
makeDeviceForInterface(const std::string & interfaceName)154     makeDeviceForInterface(const std::string& interfaceName) {
155   auto device = makeGlooDevice(interfaceName, "");
156   if (!device) {
157     TORCH_CHECK(false, "makeDeviceForInterface(): unsupported gloo device");
158   }
159   return device;
160 }
161 
162 std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory::
makeDeviceForHostname(const std::string & hostname)163     makeDeviceForHostname(const std::string& hostname) {
164   auto device = makeGlooDevice("", hostname);
165   if (!device) {
166     TORCH_CHECK(false, "makeDeviceForHostname(): unsupported gloo device");
167   }
168   return device;
169 }
170 
171 } // namespace c10d
172 
173 #endif // USE_C10D_GLOO
174