xref: /aosp_15_r20/external/pytorch/torch/lib/libshm/manager.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <fcntl.h>
2 #include <poll.h>
3 #include <sys/mman.h>
4 #include <unistd.h>
5 #include <algorithm>
6 #include <cerrno>
7 #include <memory>
8 #include <set>
9 #include <unordered_map>
10 #include <vector>
11 
12 #include <c10/util/tempfile.h>
13 
14 #include <libshm/err.h>
15 #include <libshm/socket.h>
16 
// poll() timeout in milliseconds, used only while no client is connected.
// If it expires with no activity, the manager leaves its event loop and exits.
constexpr int SHUTDOWN_TIMEOUT = 2000; // 2s
18 
#ifdef DEBUG_LOG
#include <cstdio>
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// Prints a highlighted, newline-terminated message to stderr.
// The trailing '\n' is passed as an extra vararg, so DEBUG("literal") with no
// format arguments still expands to a well-formed fprintf call.
// The helper avoids a reserved identifier (leading double underscore). It has
// no trailing ';', so DEBUG(...) behaves as a single expression statement,
// including in an unbraced if/else.
#define DEBUG_IMPL_(msg, ...) fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
#define DEBUG(...) DEBUG_IMPL_(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
27 
28 struct ClientSession {
ClientSessionClientSession29   ClientSession(ManagerSocket s) : socket(std::move(s)), pid(0) {}
30 
31   ManagerSocket socket;
32   pid_t pid;
33 };
34 
35 std::vector<struct pollfd> pollfds;
36 std::unordered_map<int, ClientSession> client_sessions;
37 // TODO: check if objects have been freed from time to time
38 std::set<std::string> used_objects;
39 
register_fd(int fd)40 void register_fd(int fd) {
41   struct pollfd pfd = {0};
42   pfd.fd = fd;
43   pfd.events = POLLIN;
44   pollfds.push_back(pfd);
45 }
46 
unregister_fd(int fd)47 void unregister_fd(int fd) {
48   pollfds.erase(
49       std::remove_if(
50           pollfds.begin(),
51           pollfds.end(),
52           [fd](const struct pollfd& pfd) { return pfd.fd == fd; }),
53       pollfds.end());
54   client_sessions.erase(fd);
55 }
56 
print_init_message(std::string_view message)57 void print_init_message(std::string_view message) {
58   ssize_t written_bytes = -1;
59   while (!message.empty()) {
60     // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
61     SYSCHECK_ERR_RETURN_NEG1(
62         written_bytes = write(1, message.data(), message.size()));
63     message.remove_prefix(written_bytes);
64   }
65   written_bytes = 0;
66   while (written_bytes != 1) {
67     // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
68     SYSCHECK_ERR_RETURN_NEG1(written_bytes = write(1, "\n", 1));
69   }
70 }
71 
// Reports whether a POSIX shared-memory object named `name` exists.
//
// Only ENOENT from shm_open() is taken as proof that the object is gone.
// Any other failure is reported as "exists". Examples are EACCES (the object
// is there but unreadable) and EMFILE/ENFILE (this process is out of
// descriptors).
// Why: a false result makes free_used_object() stop tracking the name, and an
// untracked object is never shm_unlink()ed at shutdown. A spurious "true"
// merely keeps the name tracked, and the manager ignores shm_unlink()'s
// result at exit.
bool object_exists(const char* name) {
  const int fd = shm_open(name, O_RDONLY, 0);
  if (fd < 0) {
    return errno != ENOENT;
  }
  close(fd);
  return true;
}
81 
free_used_object(const std::string & name)82 void free_used_object(const std::string& name) {
83   if (!object_exists(name.c_str())) {
84     DEBUG("object %s appears to have been freed", name.c_str());
85     used_objects.erase(name);
86   } else {
87     DEBUG("object %s still exists", name.c_str());
88   }
89 }
90 
91 // NOLINTNEXTLINE(bugprone-exception-escape)
main(int argc,char * argv[])92 int main(int argc, char* argv[]) {
93   setsid(); // Daemonize the process
94 
95   std::unique_ptr<ManagerServerSocket> srv_socket;
96   std::optional<c10::TempDir> tempdir;
97   try {
98     tempdir = c10::try_make_tempdir(/*name_prefix=*/"torch-shm-dir-");
99     if (!tempdir.has_value()) {
100       throw std::runtime_error(
101           "could not generate a random directory for manager socket");
102     }
103 
104     std::string tempfile = tempdir->name + "/manager.sock";
105 
106     srv_socket = std::make_unique<ManagerServerSocket>(tempfile);
107     register_fd(srv_socket->socket_fd);
108     print_init_message(tempfile.c_str());
109     DEBUG("opened socket %s", tempfile.c_str());
110   } catch (const std::exception& e) {
111     std::string message("ERROR: ");
112     message += e.what();
113     print_init_message(message.c_str());
114     return 1;
115   } catch (...) {
116     print_init_message("ERROR: unhandled exception");
117     return 1;
118   }
119 
120   int timeout = -1;
121   std::vector<int> to_add;
122   std::vector<int> to_remove;
123   for (;;) {
124     int nevents = -1;
125     if (client_sessions.empty())
126       timeout = SHUTDOWN_TIMEOUT;
127     // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
128     SYSCHECK_ERR_RETURN_NEG1(
129         nevents = poll(pollfds.data(), pollfds.size(), timeout));
130     timeout = -1;
131     if (nevents == 0 && client_sessions.empty())
132       break;
133 
134     for (auto& pfd : pollfds) {
135       if (pfd.revents & (POLLERR | POLLHUP)) {
136         // some process died
137         DEBUG("detaching process");
138         auto& session = client_sessions.at(pfd.fd);
139         (void)session;
140         DEBUG("%d has died", session.pid);
141         to_remove.push_back(pfd.fd);
142       } else if (pfd.revents & POLLIN) {
143         if (pfd.fd == srv_socket->socket_fd) {
144           // someone is joining
145           DEBUG("registered new client");
146           auto client = srv_socket->accept();
147           int fd = client.socket_fd;
148           to_add.push_back(fd);
149           client_sessions.emplace(fd, std::move(client));
150         } else {
151           // someone wants to register a segment
152           DEBUG("got alloc info");
153           auto& session = client_sessions.at(pfd.fd);
154           AllocInfo info = session.socket.receive();
155           session.pid = info.pid;
156           DEBUG(
157               "got alloc info: %d %d %s",
158               (int)info.free,
159               info.pid,
160               info.filename);
161           if (info.free) {
162             free_used_object(info.filename);
163           } else {
164             used_objects.insert(info.filename);
165             DEBUG("registered object %s", info.filename);
166             session.socket.confirm();
167           }
168         }
169       }
170     }
171 
172     for (int fd : to_add)
173       register_fd(fd);
174     to_add.clear();
175 
176     for (int fd : to_remove)
177       unregister_fd(fd);
178     to_remove.clear();
179   }
180 
181   for (auto& obj_name : used_objects) {
182     DEBUG("freeing %s", obj_name.c_str());
183     shm_unlink(obj_name.c_str());
184   }
185 
186   // Clean up file descriptors
187   for (auto& pfd : pollfds) {
188     unregister_fd(pfd.fd);
189   }
190   // Clean up manager.sock
191   srv_socket->remove();
192   // Clean up directory automatically
193 
194   DEBUG("manager done");
195   return 0;
196 }
197