1 #include <fcntl.h>
2 #include <poll.h>
3 #include <sys/mman.h>
4 #include <unistd.h>
5 #include <algorithm>
6 #include <cerrno>
7 #include <memory>
8 #include <set>
9 #include <unordered_map>
10 #include <vector>
11
12 #include <c10/util/tempfile.h>
13
14 #include <libshm/err.h>
15 #include <libshm/socket.h>
16
// How long poll() waits for a new client, in milliseconds, while no client is
// connected; if that wait expires with no events the manager shuts down.
constexpr int SHUTDOWN_TIMEOUT = 2000; // 2s
18
// Debug logging.
//
// With DEBUG_LOG defined, DEBUG(fmt, ...) prints a printf-style message to
// stderr wrapped in the COLOR/RESET escape sequences and followed by a
// newline. Without it, DEBUG(...) expands to a no-op and its arguments are
// never evaluated.
#ifdef DEBUG_LOG
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// `msg` must be a string literal: it is pasted between the escape sequences
// and a trailing "%c" that receives the '\n' appended by DEBUG. The helper
// deliberately has no trailing semicolon, so `DEBUG(...);` is exactly one
// statement and stays valid inside an unbraced if/else. Its name avoids the
// double underscore, which is reserved for the implementation.
#define DEBUG_IMPL(msg, ...) \
  (void)fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
#define DEBUG(...) DEBUG_IMPL(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
27
28 struct ClientSession {
ClientSessionClientSession29 ClientSession(ManagerSocket s) : socket(std::move(s)), pid(0) {}
30
31 ManagerSocket socket;
32 pid_t pid;
33 };
34
35 std::vector<struct pollfd> pollfds;
36 std::unordered_map<int, ClientSession> client_sessions;
37 // TODO: check if objects have been freed from time to time
38 std::set<std::string> used_objects;
39
register_fd(int fd)40 void register_fd(int fd) {
41 struct pollfd pfd = {0};
42 pfd.fd = fd;
43 pfd.events = POLLIN;
44 pollfds.push_back(pfd);
45 }
46
unregister_fd(int fd)47 void unregister_fd(int fd) {
48 pollfds.erase(
49 std::remove_if(
50 pollfds.begin(),
51 pollfds.end(),
52 [fd](const struct pollfd& pfd) { return pfd.fd == fd; }),
53 pollfds.end());
54 client_sessions.erase(fd);
55 }
56
print_init_message(std::string_view message)57 void print_init_message(std::string_view message) {
58 ssize_t written_bytes = -1;
59 while (!message.empty()) {
60 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
61 SYSCHECK_ERR_RETURN_NEG1(
62 written_bytes = write(1, message.data(), message.size()));
63 message.remove_prefix(written_bytes);
64 }
65 written_bytes = 0;
66 while (written_bytes != 1) {
67 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
68 SYSCHECK_ERR_RETURN_NEG1(written_bytes = write(1, "\n", 1));
69 }
70 }
71
/// Reports whether the shared-memory object `name` can be opened read-only.
///
/// Any shm_open() failure, whatever its cause, is reported as "does not
/// exist". The probe descriptor is closed before returning.
///
/// @param name  name passed verbatim to shm_open().
/// @return true if shm_open() succeeded, false otherwise.
bool object_exists(const char* name) {
  const int probe_fd = shm_open(name, O_RDONLY, 0);
  if (probe_fd < 0) {
    return false;
  }
  close(probe_fd);
  return true;
}
81
free_used_object(const std::string & name)82 void free_used_object(const std::string& name) {
83 if (!object_exists(name.c_str())) {
84 DEBUG("object %s appears to have been freed", name.c_str());
85 used_objects.erase(name);
86 } else {
87 DEBUG("object %s still exists", name.c_str());
88 }
89 }
90
91 // NOLINTNEXTLINE(bugprone-exception-escape)
main(int argc,char * argv[])92 int main(int argc, char* argv[]) {
93 setsid(); // Daemonize the process
94
95 std::unique_ptr<ManagerServerSocket> srv_socket;
96 std::optional<c10::TempDir> tempdir;
97 try {
98 tempdir = c10::try_make_tempdir(/*name_prefix=*/"torch-shm-dir-");
99 if (!tempdir.has_value()) {
100 throw std::runtime_error(
101 "could not generate a random directory for manager socket");
102 }
103
104 std::string tempfile = tempdir->name + "/manager.sock";
105
106 srv_socket = std::make_unique<ManagerServerSocket>(tempfile);
107 register_fd(srv_socket->socket_fd);
108 print_init_message(tempfile.c_str());
109 DEBUG("opened socket %s", tempfile.c_str());
110 } catch (const std::exception& e) {
111 std::string message("ERROR: ");
112 message += e.what();
113 print_init_message(message.c_str());
114 return 1;
115 } catch (...) {
116 print_init_message("ERROR: unhandled exception");
117 return 1;
118 }
119
120 int timeout = -1;
121 std::vector<int> to_add;
122 std::vector<int> to_remove;
123 for (;;) {
124 int nevents = -1;
125 if (client_sessions.empty())
126 timeout = SHUTDOWN_TIMEOUT;
127 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
128 SYSCHECK_ERR_RETURN_NEG1(
129 nevents = poll(pollfds.data(), pollfds.size(), timeout));
130 timeout = -1;
131 if (nevents == 0 && client_sessions.empty())
132 break;
133
134 for (auto& pfd : pollfds) {
135 if (pfd.revents & (POLLERR | POLLHUP)) {
136 // some process died
137 DEBUG("detaching process");
138 auto& session = client_sessions.at(pfd.fd);
139 (void)session;
140 DEBUG("%d has died", session.pid);
141 to_remove.push_back(pfd.fd);
142 } else if (pfd.revents & POLLIN) {
143 if (pfd.fd == srv_socket->socket_fd) {
144 // someone is joining
145 DEBUG("registered new client");
146 auto client = srv_socket->accept();
147 int fd = client.socket_fd;
148 to_add.push_back(fd);
149 client_sessions.emplace(fd, std::move(client));
150 } else {
151 // someone wants to register a segment
152 DEBUG("got alloc info");
153 auto& session = client_sessions.at(pfd.fd);
154 AllocInfo info = session.socket.receive();
155 session.pid = info.pid;
156 DEBUG(
157 "got alloc info: %d %d %s",
158 (int)info.free,
159 info.pid,
160 info.filename);
161 if (info.free) {
162 free_used_object(info.filename);
163 } else {
164 used_objects.insert(info.filename);
165 DEBUG("registered object %s", info.filename);
166 session.socket.confirm();
167 }
168 }
169 }
170 }
171
172 for (int fd : to_add)
173 register_fd(fd);
174 to_add.clear();
175
176 for (int fd : to_remove)
177 unregister_fd(fd);
178 to_remove.clear();
179 }
180
181 for (auto& obj_name : used_objects) {
182 DEBUG("freeing %s", obj_name.c_str());
183 shm_unlink(obj_name.c_str());
184 }
185
186 // Clean up file descriptors
187 for (auto& pfd : pollfds) {
188 unregister_fd(pfd.fd);
189 }
190 // Clean up manager.sock
191 srv_socket->remove();
192 // Clean up directory automatically
193
194 DEBUG("manager done");
195 return 0;
196 }
197