xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <mutex>
2 #include <shared_mutex>
3 #include <sstream>
4 #include <tuple>
5 #include <unordered_map>
6 
7 #include <ATen/core/interned_strings.h>
8 #include <caffe2/utils/threadpool/WorkersPool.h>
9 #include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
10 #include <torch/csrc/distributed/c10d/logging.h>
11 
12 // NS: TODO: Use `std::filesystem` regardless of OS when it's possible
13 // to use it without leaking symbols on PRECXX11 ABI Linux OSes
14 // See https://github.com/pytorch/pytorch/issues/133437 for more details
15 #ifdef _WIN32
16 #include <filesystem>
17 #else
18 #include <sys/stat.h>
19 #endif
20 
21 namespace c10d {
22 namespace control_plane {
23 
24 namespace {
25 class RequestImpl : public Request {
26  public:
RequestImpl(const httplib::Request & req)27   RequestImpl(const httplib::Request& req) : req_(req) {}
28 
body()29   const std::string& body() override {
30     return req_.body;
31   }
32 
33  private:
34   const httplib::Request& req_;
35 };
36 
37 class ResponseImpl : public Response {
38  public:
ResponseImpl(httplib::Response & res)39   ResponseImpl(httplib::Response& res) : res_(res) {}
40 
setStatus(int status)41   void setStatus(int status) override {
42     res_.status = status;
43   }
44 
setContent(std::string && content,const std::string & content_type)45   void setContent(std::string&& content, const std::string& content_type)
46       override {
47     res_.set_content(std::move(content), content_type);
48   }
49 
50  private:
51   httplib::Response& res_;
52 };
53 
// Escapes `str` so it can be embedded between double quotes in a JSON
// document.
//
// - `"` and `\` are backslash-escaped.
// - \b, \f, \n, \r and \t use their short JSON escapes.
// - Any other byte in 0x00..0x1F is emitted as a lower-case `\u00XX` escape.
// - All remaining bytes (including UTF-8 multi-byte sequences) pass through
//   unchanged.
//
// Hex digits are written by hand instead of with <iomanip> manipulators. This
// avoids depending on an un-included header and leaves no sticky stream state.
std::string jsonStrEscape(const std::string& str) {
  static constexpr char kHexDigits[] = "0123456789abcdef";

  std::string out;
  out.reserve(str.size());
  for (const char ch : str) {
    switch (ch) {
      case '"':
        out += "\\\"";
        break;
      case '\\':
        out += "\\\\";
        break;
      case '\b':
        out += "\\b";
        break;
      case '\f':
        out += "\\f";
        break;
      case '\n':
        out += "\\n";
        break;
      case '\r':
        out += "\\r";
        break;
      case '\t':
        out += "\\t";
        break;
      default: {
        // Compare as unsigned so bytes >= 0x80 are never mistaken for
        // control characters, regardless of whether `char` is signed.
        const auto uch = static_cast<unsigned char>(ch);
        if (uch < 0x20) {
          out += "\\u00";
          out += kHexDigits[uch >> 4];
          out += kHexDigits[uch & 0x0F];
        } else {
          out += ch;
        }
        break;
      }
    }
  }
  return out;
}
80 
// Returns true if something exists at `path`, without following symlinks.
// A dangling symlink therefore counts as existing.
//
// Used to refuse binding a UNIX socket over an existing path. Errors (e.g.
// permission denied on a parent directory) are reported as "does not exist"
// instead of throwing.
bool file_exists(const std::string& path) {
#ifdef _WIN32
  // symlink_status() mirrors lstat() below (no symlink following), so both
  // platforms agree. The error_code overload avoids throwing
  // std::filesystem::filesystem_error.
  std::error_code ec;
  return std::filesystem::exists(std::filesystem::symlink_status(path, ec));
#else
  struct stat rc {};
  return lstat(path.c_str(), &rc) == 0;
#endif
}
89 } // namespace
90 
WorkerServer(const std::string & hostOrFile,int port)91 WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
92   server_.Get("/", [](const httplib::Request& req, httplib::Response& res) {
93     res.set_content(
94         R"BODY(<h1>torch.distributed.WorkerServer</h1>
95 <a href="/handler/">Handler names</a>
96 )BODY",
97         "text/html");
98   });
99   server_.Get(
100       "/handler/", [](const httplib::Request& req, httplib::Response& res) {
101         std::ostringstream body;
102         body << "[";
103         bool first = true;
104         for (const auto& name : getHandlerNames()) {
105           if (!first) {
106             body << ",";
107           }
108           first = false;
109 
110           body << "\"" << jsonStrEscape(name) << "\"";
111         }
112         body << "]";
113 
114         res.set_content(body.str(), "application/json");
115       });
116   server_.Post(
117       "/handler/:handler",
118       [](const httplib::Request& req, httplib::Response& res) {
119         auto handler_name = req.path_params.at("handler");
120         HandlerFunc handler;
121         try {
122           handler = getHandler(handler_name);
123         } catch (const std::exception& e) {
124           res.status = 404;
125           res.set_content(
126               fmt::format("Handler {} not found: {}", handler_name, e.what()),
127               "text/plain");
128           return;
129         }
130         RequestImpl torchReq{req};
131         ResponseImpl torchRes{res};
132 
133         try {
134           handler(torchReq, torchRes);
135         } catch (const std::exception& e) {
136           res.status = 500;
137           res.set_content(
138               fmt::format("Handler {} failed: {}", handler_name, e.what()),
139               "text/plain");
140           return;
141         } catch (...) {
142           res.status = 500;
143           res.set_content(
144               fmt::format(
145                   "Handler {} failed with unknown exception", handler_name),
146               "text/plain");
147           return;
148         }
149       });
150 
151   // adjust keep alives as it stops the server from shutting down quickly
152   server_.set_keep_alive_timeout(1); // second, default is 5
153   server_.set_keep_alive_max_count(
154       30); // wait max 30 seconds before closing socket
155 
156   if (port == -1) {
157     // using unix sockets
158     server_.set_address_family(AF_UNIX);
159 
160     if (file_exists(hostOrFile)) {
161       throw std::runtime_error(fmt::format("{} already exists", hostOrFile));
162     }
163 
164     C10D_WARNING("Server listening to UNIX {}", hostOrFile);
165     if (!server_.bind_to_port(hostOrFile, 80)) {
166       throw std::runtime_error(fmt::format("Error binding to {}", hostOrFile));
167     }
168   } else {
169     C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
170     if (!server_.bind_to_port(hostOrFile, port)) {
171       throw std::runtime_error(
172           fmt::format("Error binding to {}:{}", hostOrFile, port));
173     }
174   }
175 
176   serverThread_ = std::thread([this]() {
177     try {
178       if (!server_.listen_after_bind()) {
179         throw std::runtime_error("failed to listen");
180       }
181     } catch (std::exception& e) {
182       C10D_ERROR("Error while running server: {}", e.what());
183       throw;
184     }
185     C10D_WARNING("Server exited");
186   });
187 }
188 
shutdown()189 void WorkerServer::shutdown() {
190   C10D_WARNING("Server shutting down");
191   server_.stop();
192   serverThread_.join();
193 }
194 
~WorkerServer()195 WorkerServer::~WorkerServer() {
196   if (serverThread_.joinable()) {
197     C10D_WARNING("WorkerServer destructor called without shutdown");
198     shutdown();
199   }
200 }
201 
202 } // namespace control_plane
203 } // namespace c10d
204