1 #include <mutex>
2 #include <shared_mutex>
3 #include <sstream>
4 #include <tuple>
5 #include <unordered_map>
6
7 #include <ATen/core/interned_strings.h>
8 #include <caffe2/utils/threadpool/WorkersPool.h>
9 #include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
10 #include <torch/csrc/distributed/c10d/logging.h>
11
12 // NS: TODO: Use `std::filesystem` regardless of OS when it's possible
13 // to use it without leaking symbols on PRECXX11 ABI Linux OSes
14 // See https://github.com/pytorch/pytorch/issues/133437 for more details
15 #ifdef _WIN32
16 #include <filesystem>
17 #else
18 #include <sys/stat.h>
19 #endif
20
21 namespace c10d {
22 namespace control_plane {
23
24 namespace {
25 class RequestImpl : public Request {
26 public:
RequestImpl(const httplib::Request & req)27 RequestImpl(const httplib::Request& req) : req_(req) {}
28
body()29 const std::string& body() override {
30 return req_.body;
31 }
32
33 private:
34 const httplib::Request& req_;
35 };
36
37 class ResponseImpl : public Response {
38 public:
ResponseImpl(httplib::Response & res)39 ResponseImpl(httplib::Response& res) : res_(res) {}
40
setStatus(int status)41 void setStatus(int status) override {
42 res_.status = status;
43 }
44
setContent(std::string && content,const std::string & content_type)45 void setContent(std::string&& content, const std::string& content_type)
46 override {
47 res_.set_content(std::move(content), content_type);
48 }
49
50 private:
51 httplib::Response& res_;
52 };
53
/// Escapes `str` so that it can be placed between double quotes in a JSON
/// document.
///
/// - `"` and `\` are prefixed with a backslash.
/// - Backspace, form feed, newline, carriage return and tab use their
///   two-character escapes (`\b`, `\f`, `\n`, `\r`, `\t`).
/// - Every other character in the range 0x00-0x1f is written as `\u00XX`
///   with lowercase hex digits.
/// - All remaining bytes (including bytes >= 0x80) are copied unchanged.
///
/// @param str The raw string to escape.
/// @return The escaped string, without surrounding quotes.
std::string jsonStrEscape(const std::string& str) {
  static constexpr char kHexDigits[] = "0123456789abcdef";

  std::string escaped;
  // Lower bound on the output size; escapes only ever make it longer.
  escaped.reserve(str.size());

  for (const char ch : str) {
    switch (ch) {
      case '"':
        escaped += "\\\"";
        break;
      case '\\':
        escaped += "\\\\";
        break;
      case '\b':
        escaped += "\\b";
        break;
      case '\f':
        escaped += "\\f";
        break;
      case '\n':
        escaped += "\\n";
        break;
      case '\r':
        escaped += "\\r";
        break;
      case '\t':
        escaped += "\\t";
        break;
      default: {
        // Compare as unsigned so the result does not depend on whether plain
        // `char` is signed: bytes >= 0x80 must never be treated as control
        // characters.
        const auto code = static_cast<unsigned char>(ch);
        if (code <= 0x1f) {
          escaped += "\\u00";
          escaped += kHexDigits[(code >> 4) & 0xf];
          escaped += kHexDigits[code & 0xf];
        } else {
          escaped += ch;
        }
        break;
      }
    }
  }
  return escaped;
}
80
/// Reports whether something exists at `path`.
///
/// On Windows this defers to `std::filesystem::exists`. Elsewhere the answer
/// is whether `lstat` succeeds on the path; `lstat` inspects a symbolic link
/// itself rather than its target.
///
/// @param path Filesystem path to probe.
/// @return true if the path exists, false otherwise.
bool file_exists(const std::string& path) {
#ifndef _WIN32
  struct stat info;
  const int status = lstat(path.c_str(), &info);
  return status == 0;
#else
  return std::filesystem::exists(path);
#endif
}
89 } // namespace
90
WorkerServer(const std::string & hostOrFile,int port)91 WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
92 server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
93 res.set_content(
94 R"BODY(<h1>torch.distributed.WorkerServer</h1>
95 <a href="/handler/">Handler names</a>
96 )BODY",
97 "text/html");
98 });
99 server_.Get(
100 "/handler/", [](const httplib::Request& req, httplib::Response& res) {
101 std::ostringstream body;
102 body << "[";
103 bool first = true;
104 for (const auto& name : getHandlerNames()) {
105 if (!first) {
106 body << ",";
107 }
108 first = false;
109
110 body << "\"" << jsonStrEscape(name) << "\"";
111 }
112 body << "]";
113
114 res.set_content(body.str(), "application/json");
115 });
116 server_.Post(
117 "/handler/:handler",
118 [](const httplib::Request& req, httplib::Response& res) {
119 auto handler_name = req.path_params.at("handler");
120 HandlerFunc handler;
121 try {
122 handler = getHandler(handler_name);
123 } catch (const std::exception& e) {
124 res.status = 404;
125 res.set_content(
126 fmt::format("Handler {} not found: {}", handler_name, e.what()),
127 "text/plain");
128 return;
129 }
130 RequestImpl torchReq{req};
131 ResponseImpl torchRes{res};
132
133 try {
134 handler(torchReq, torchRes);
135 } catch (const std::exception& e) {
136 res.status = 500;
137 res.set_content(
138 fmt::format("Handler {} failed: {}", handler_name, e.what()),
139 "text/plain");
140 return;
141 } catch (...) {
142 res.status = 500;
143 res.set_content(
144 fmt::format(
145 "Handler {} failed with unknown exception", handler_name),
146 "text/plain");
147 return;
148 }
149 });
150
151 // adjust keep alives as it stops the server from shutting down quickly
152 server_.set_keep_alive_timeout(1); // second, default is 5
153 server_.set_keep_alive_max_count(
154 30); // wait max 30 seconds before closing socket
155
156 if (port == -1) {
157 // using unix sockets
158 server_.set_address_family(AF_UNIX);
159
160 if (file_exists(hostOrFile)) {
161 throw std::runtime_error(fmt::format("{} already exists", hostOrFile));
162 }
163
164 C10D_WARNING("Server listening to UNIX {}", hostOrFile);
165 if (!server_.bind_to_port(hostOrFile, 80)) {
166 throw std::runtime_error(fmt::format("Error binding to {}", hostOrFile));
167 }
168 } else {
169 C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
170 if (!server_.bind_to_port(hostOrFile, port)) {
171 throw std::runtime_error(
172 fmt::format("Error binding to {}:{}", hostOrFile, port));
173 }
174 }
175
176 serverThread_ = std::thread([this]() {
177 try {
178 if (!server_.listen_after_bind()) {
179 throw std::runtime_error("failed to listen");
180 }
181 } catch (std::exception& e) {
182 C10D_ERROR("Error while running server: {}", e.what());
183 throw;
184 }
185 C10D_WARNING("Server exited");
186 });
187 }
188
shutdown()189 void WorkerServer::shutdown() {
190 C10D_WARNING("Server shutting down");
191 server_.stop();
192 serverThread_.join();
193 }
194
~WorkerServer()195 WorkerServer::~WorkerServer() {
196 if (serverThread_.joinable()) {
197 C10D_WARNING("WorkerServer destructor called without shutdown");
198 shutdown();
199 }
200 }
201
202 } // namespace control_plane
203 } // namespace c10d
204