1
2 #include <c10/util/irange.h>
3 #include <fcntl.h>
4 #include <algorithm>
5 #include <array>
6 #include <system_error>
7 #include <unordered_map>
8 #include <utility>
9
10 #ifdef _WIN32
11 #include <io.h>
12 #include <winsock2.h>
13 #else
14 #include <poll.h>
15 #include <unistd.h>
16 #endif
17
18 #include <c10/util/thread_name.h>
19 #include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
20 #include <torch/csrc/distributed/c10d/logging.h>
21
22 #ifdef _WIN32
23 #include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
24 #else
25 #include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
26 #endif
27
28 #include <torch/csrc/distributed/c10d/socket.h>
29
30 namespace c10d::detail {
31
32 // Background thread parent class methods
// Defaulted: the daemon thread is not created here, only later by start().
BackgroundThread::BackgroundThread() = default;

// Defaulted: stopping and joining the thread is done by dispose(), not here.
BackgroundThread::~BackgroundThread() = default;
36
37 // WARNING:
38 // Since we rely on the subclass for the daemon thread clean-up, we cannot
39 // destruct our member variables in the destructor. The subclass must call
40 // dispose() in its own destructor.
dispose()41 void BackgroundThread::dispose() {
42 // Stop the run
43 stop();
44 // Join the thread
45 daemonThread_.join();
46 }
47
start()48 void BackgroundThread::start() {
49 daemonThread_ = std::thread{&BackgroundThread::run, this};
50 is_running_.store(true);
51 }
52
53 // Separate thread that is only launched on master
54 class TCPStoreMasterDaemon : public BackgroundThread {
55 public:
56 explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);
57
58 ~TCPStoreMasterDaemon() override;
59
60 uint16_t port() const override;
61
62 protected:
63 void run() override;
64 void stop() override;
65
66 private:
67 void initStopSignal();
68 void closeStopSignal();
69
70 void queryFds(std::vector<struct pollfd>& fds);
71 void query(int socket);
72
73 void clearSocketWaitState(int socket);
74
75 // The master runs on a single thread so only
76 // one handler can be executed at a time
77 void validateHandler(int socket);
78 void pingHandler(int socket);
79 void setHandler(int socket);
80 void compareSetHandler(int socket);
81 void addHandler(int socket);
82 void getHandler(int socket) const;
83 void checkHandler(int socket) const;
84 void getNumKeysHandler(int socket) const;
85 void deleteHandler(int socket);
86 void waitHandler(int socket);
87 void appendHandler(int socket);
88 void multiGetHandler(int socket);
89 void multiSetHandler(int socket);
90 void cancelWaitHandler(int socket);
91 void addMiscellaneousSocket(int socket);
92 void removeMiscellaneousSocket(int socket);
93 bool isMiscellaneousSocket(int socket);
94
95 bool checkKeys(const std::vector<std::string>& keys) const;
96 // Helper function to alerts waiting workers, used in setHandler, getHandler
97 void wakeupWaitingClients(const std::string& key);
98 void doSet(const std::string& key, const std::vector<uint8_t>& newData);
99
100 std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
101 // From key -> the list of sockets waiting on the key
102 std::unordered_map<std::string, std::vector<int>> waitingSockets_;
103 // From socket -> number of keys awaited
104 std::unordered_map<int, size_t> keysAwaited_;
105 // miscellaneous sockets
106 std::unordered_set<int> miscellaneousSockets_;
107
108 Socket storeListenSocket_;
109 std::vector<Socket> sockets_{};
110 #ifdef _WIN32
111 const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10};
112 HANDLE ghStopEvent_{};
113 #else
114 std::array<int, 2> controlPipeFd_{{-1, -1}};
115 #endif
116 };
117
118 // Simply start the daemon thread
TCPStoreMasterDaemon(Socket && storeListenSocket)119 TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
120 : storeListenSocket_{std::move(storeListenSocket)} {
121 initStopSignal();
122 }
123
// Shuts the daemon down and releases what it owns. The order matters: the
// daemon thread must be stopped and joined before the resources it uses are
// torn down.
TCPStoreMasterDaemon::~TCPStoreMasterDaemon() {
  // Stop the run loop and join the daemon thread
  dispose();
  // it's now safe for us to cleanup
  // Close unclosed sockets
  sockets_.clear();
  // Now close what is left of the stop signal (e.g. the control pipe)
  closeStopSignal();
}
132
port() const133 std::uint16_t TCPStoreMasterDaemon::port() const {
134 return storeListenSocket_.port();
135 }
136
137 #ifdef _WIN32
initStopSignal()138 void TCPStoreMasterDaemon::initStopSignal() {
139 ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
140 if (ghStopEvent_ == NULL) {
141 TORCH_CHECK(
142 false,
143 "Failed to create the control pipe to start the "
144 "BackgroundThread run");
145 }
146 }
147
closeStopSignal()148 void TCPStoreMasterDaemon::closeStopSignal() {
149 CloseHandle(ghStopEvent_);
150 }
151
// Signals the stop event; run() checks this event to decide when to exit.
void TCPStoreMasterDaemon::stop() {
  SetEvent(ghStopEvent_);
}
155
156 #else
initStopSignal()157 void TCPStoreMasterDaemon::initStopSignal() {
158 if (pipe(controlPipeFd_.data()) == -1) {
159 TORCH_CHECK(
160 false,
161 "Failed to create the control pipe to start the "
162 "BackgroundThread run");
163 }
164 }
165
closeStopSignal()166 void TCPStoreMasterDaemon::closeStopSignal() {
167 for (int fd : controlPipeFd_) {
168 if (fd != -1) {
169 ::close(fd);
170 }
171 }
172 }
173
stop()174 void TCPStoreMasterDaemon::stop() {
175 if (controlPipeFd_[1] != -1) {
176 ssize_t written_bytes = -1;
177 while (true) {
178 written_bytes = ::write(controlPipeFd_[1], "\0", 1);
179 if (written_bytes < 0) {
180 if (errno == EAGAIN) {
181 continue;
182 }
183 TORCH_CHECK(false, "Failed to write the control pipe:", errno);
184 }
185 break;
186 }
187 if (written_bytes == 0) {
188 TORCH_CHECK(false, "Failed to write the control pipe");
189 }
190
191 // close the write end of the pipe
192 ::close(controlPipeFd_[1]);
193 controlPipeFd_[1] = -1;
194 }
195 }
196 #endif
197
queryFds(std::vector<struct pollfd> & fds)198 void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
199 // Skipping the fds[0] and fds[1],
200 // fds[0] is master's listening socket
201 // fds[1] is control pipe's reading fd, it is not for Windows platform
202 for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
203 if (fds[fdIdx].revents == 0) {
204 continue;
205 }
206
207 // Now query the socket that has the event
208 try {
209 query(fds[fdIdx].fd);
210 } catch (...) {
211 // There was an error when processing query. Probably an exception
212 // occurred in recv/send what would indicate that socket on the other
213 // side has been closed. If the closing was due to normal exit, then
214 // the store should continue executing. Otherwise, if it was different
215 // exception, other connections will get an exception once they try to
216 // use the store. We will go ahead and close this connection whenever
217 // we hit an exception here.
218 clearSocketWaitState(fds[fdIdx].fd);
219
220 fds.erase(fds.begin() + fdIdx);
221 sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
222 --fdIdx;
223 continue;
224 }
225 }
226 }
227
clearSocketWaitState(int socket)228 void TCPStoreMasterDaemon::clearSocketWaitState(int socket) {
229 // Remove all the tracking state of the close FD
230 for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
231 for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
232 if (*vecIt == socket) {
233 vecIt = it->second.erase(vecIt);
234 } else {
235 ++vecIt;
236 }
237 }
238 if (it->second.empty()) {
239 it = waitingSockets_.erase(it);
240 } else {
241 ++it;
242 }
243 }
244 for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
245 if (it->first == socket) {
246 it = keysAwaited_.erase(it);
247 } else {
248 ++it;
249 }
250 }
251 }
252
253 // query communicates with the worker. The format
254 // of the query is as follows:
255 // type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
256 // or, in the case of wait
257 // type of query | number of args | size of arg1 | arg1 | ...
query(int socket)258 void TCPStoreMasterDaemon::query(int socket) {
259 QueryType qt;
260 tcputil::recvBytes<QueryType>(socket, &qt, 1);
261
262 if (isMiscellaneousSocket(socket)) {
263 removeMiscellaneousSocket(socket);
264 if (qt == QueryType::VALIDATE) {
265 validateHandler(socket);
266 } else {
267 // real miscellaneous client: the first msg is not VALIDATE
268 TORCH_CHECK(
269 false, "Miscellaneous client without VALIDATE query is detected");
270 }
271
272 } else if (qt == QueryType::PING) {
273 pingHandler(socket);
274
275 } else if (qt == QueryType::SET) {
276 setHandler(socket);
277
278 } else if (qt == QueryType::COMPARE_SET) {
279 compareSetHandler(socket);
280
281 } else if (qt == QueryType::ADD) {
282 addHandler(socket);
283
284 } else if (qt == QueryType::GET) {
285 getHandler(socket);
286
287 } else if (qt == QueryType::CHECK) {
288 checkHandler(socket);
289
290 } else if (qt == QueryType::WAIT) {
291 waitHandler(socket);
292
293 } else if (qt == QueryType::GETNUMKEYS) {
294 getNumKeysHandler(socket);
295
296 } else if (qt == QueryType::DELETE_KEY) {
297 deleteHandler(socket);
298 } else if (qt == QueryType::APPEND) {
299 appendHandler(socket);
300 } else if (qt == QueryType::MULTI_GET) {
301 multiGetHandler(socket);
302 } else if (qt == QueryType::MULTI_SET) {
303 multiSetHandler(socket);
304 } else if (qt == QueryType::CANCEL_WAIT) {
305 cancelWaitHandler(socket);
306 } else {
307 TORCH_CHECK(false, "Unexpected query type");
308 }
309 }
310
wakeupWaitingClients(const std::string & key)311 void TCPStoreMasterDaemon::wakeupWaitingClients(const std::string& key) {
312 auto socketsToWait = waitingSockets_.find(key);
313 if (socketsToWait != waitingSockets_.end()) {
314 for (int socket : socketsToWait->second) {
315 if (--keysAwaited_[socket] == 0) {
316 tcputil::sendValue<WaitResponseType>(
317 socket, WaitResponseType::STOP_WAITING);
318 }
319 }
320 waitingSockets_.erase(socketsToWait);
321 }
322 }
323
doSet(const std::string & key,const std::vector<uint8_t> & newData)324 void TCPStoreMasterDaemon::doSet(
325 const std::string& key,
326 const std::vector<uint8_t>& newData) {
327 tcpStore_[key] = newData;
328 // On "set", wake up all clients that have been waiting
329 wakeupWaitingClients(key);
330 }
331
validateHandler(int socket)332 void TCPStoreMasterDaemon::validateHandler(int socket) {
333 uint32_t validateNumber = 0;
334 tcputil::recvBytes<uint32_t>(socket, &validateNumber, 1);
335 if (validateNumber != detail::validationMagicNumber) {
336 TORCH_CHECK(
337 false,
338 "Miscellaneous client with incorrect VALIDATE query is detected");
339 }
340 }
341
pingHandler(int socket)342 void TCPStoreMasterDaemon::pingHandler(int socket) {
343 uint32_t nonce = 0;
344 tcputil::recvBytes<uint32_t>(socket, &nonce, 1);
345 tcputil::sendValue<uint32_t>(socket, nonce);
346 }
347
setHandler(int socket)348 void TCPStoreMasterDaemon::setHandler(int socket) {
349 std::string key = tcputil::recvString(socket);
350 std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
351 doSet(key, newData);
352 }
353
compareSetHandler(int socket)354 void TCPStoreMasterDaemon::compareSetHandler(int socket) {
355 std::string key = tcputil::recvString(socket);
356 std::vector<uint8_t> currentValue = tcputil::recvVector<uint8_t>(socket);
357 std::vector<uint8_t> newValue = tcputil::recvVector<uint8_t>(socket);
358
359 auto pos = tcpStore_.find(key);
360 if (pos == tcpStore_.end()) {
361 if (currentValue.empty()) {
362 tcpStore_[key] = newValue;
363 tcputil::sendVector<uint8_t>(socket, newValue);
364 } else {
365 // TODO: This code path is not ideal as we are "lying" to the caller in
366 // case the key does not exist. We should come up with a working solution.
367 tcputil::sendVector<uint8_t>(socket, currentValue);
368 }
369 } else {
370 if (pos->second == currentValue) {
371 pos->second = std::move(newValue);
372 }
373 tcputil::sendVector<uint8_t>(socket, pos->second);
374 }
375 }
376
addHandler(int socket)377 void TCPStoreMasterDaemon::addHandler(int socket) {
378 std::string key = tcputil::recvString(socket);
379 int64_t addVal = tcputil::recvValue<int64_t>(socket);
380
381 auto it = tcpStore_.find(key);
382 if (it != tcpStore_.end()) {
383 auto buf = reinterpret_cast<const char*>(it->second.data());
384 auto len = it->second.size();
385 addVal += std::stoll(std::string(buf, len));
386 }
387 auto addValStr = std::to_string(addVal);
388 std::vector<uint8_t> newData =
389 std::vector<uint8_t>(addValStr.begin(), addValStr.end());
390 tcpStore_[key] = newData;
391 // Now send the new value
392 tcputil::sendValue<int64_t>(socket, addVal);
393 // On "add", wake up all clients that have been waiting
394 wakeupWaitingClients(key);
395 }
396
getHandler(int socket) const397 void TCPStoreMasterDaemon::getHandler(int socket) const {
398 std::string key = tcputil::recvString(socket);
399 auto data = tcpStore_.at(key);
400 tcputil::sendVector<uint8_t>(socket, data);
401 }
402
getNumKeysHandler(int socket) const403 void TCPStoreMasterDaemon::getNumKeysHandler(int socket) const {
404 tcputil::sendValue<int64_t>(socket, tcpStore_.size());
405 }
406
deleteHandler(int socket)407 void TCPStoreMasterDaemon::deleteHandler(int socket) {
408 std::string key = tcputil::recvString(socket);
409 auto numDeleted = tcpStore_.erase(key);
410 tcputil::sendValue<int64_t>(socket, numDeleted);
411 }
412
checkHandler(int socket) const413 void TCPStoreMasterDaemon::checkHandler(int socket) const {
414 SizeType nargs = 0;
415 tcputil::recvBytes<SizeType>(socket, &nargs, 1);
416 std::vector<std::string> keys(nargs);
417 for (const auto i : c10::irange(nargs)) {
418 keys[i] = tcputil::recvString(socket);
419 }
420 // Now we have received all the keys
421 if (checkKeys(keys)) {
422 tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
423 } else {
424 tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
425 }
426 }
427
waitHandler(int socket)428 void TCPStoreMasterDaemon::waitHandler(int socket) {
429 SizeType nargs = 0;
430 tcputil::recvBytes<SizeType>(socket, &nargs, 1);
431 std::vector<std::string> keys(nargs);
432 for (const auto i : c10::irange(nargs)) {
433 keys[i] = tcputil::recvString(socket);
434 }
435 if (checkKeys(keys)) {
436 tcputil::sendValue<WaitResponseType>(
437 socket, WaitResponseType::STOP_WAITING);
438 } else {
439 int numKeysToAwait = 0;
440 for (auto& key : keys) {
441 // Only count keys that have not already been set
442 if (tcpStore_.find(key) == tcpStore_.end()) {
443 waitingSockets_[key].push_back(socket);
444 numKeysToAwait++;
445 }
446 }
447 keysAwaited_[socket] = numKeysToAwait;
448 }
449 }
450
appendHandler(int socket)451 void TCPStoreMasterDaemon::appendHandler(int socket) {
452 std::string key = tcputil::recvString(socket);
453 std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
454 auto it = tcpStore_.find(key);
455 if (it != tcpStore_.end()) {
456 it->second.insert(it->second.end(), newData.begin(), newData.end());
457 } else {
458 tcpStore_[key] = newData;
459 }
460 // we should not have clients waiting if we're appending, so it's all fine
461 wakeupWaitingClients(key);
462 }
463
multiGetHandler(int socket)464 void TCPStoreMasterDaemon::multiGetHandler(int socket) {
465 SizeType nargs = 0;
466 tcputil::recvBytes<SizeType>(socket, &nargs, 1);
467 for (const auto i : c10::irange(nargs)) {
468 auto key = tcputil::recvString(socket);
469 auto& data = tcpStore_.at(key);
470 tcputil::sendVector<uint8_t>(socket, data, i < (nargs - 1));
471 }
472 }
473
multiSetHandler(int socket)474 void TCPStoreMasterDaemon::multiSetHandler(int socket) {
475 SizeType nargs = 0;
476 tcputil::recvBytes<SizeType>(socket, &nargs, 1);
477 for (auto _ : c10::irange(nargs)) {
478 (void)_; // Suppress unused variable warning
479 auto key = tcputil::recvString(socket);
480 auto value = tcputil::recvVector<uint8_t>(socket);
481 doSet(key, value);
482 }
483 }
484
cancelWaitHandler(int socket)485 void TCPStoreMasterDaemon::cancelWaitHandler(int socket) {
486 clearSocketWaitState(socket);
487
488 // Send update to TCPStoreWorkerDaemon on client
489 tcputil::sendValue<WaitResponseType>(
490 socket, detail::WaitResponseType::WAIT_CANCELED);
491 }
492
checkKeys(const std::vector<std::string> & keys) const493 bool TCPStoreMasterDaemon::checkKeys(
494 const std::vector<std::string>& keys) const {
495 return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
496 return tcpStore_.count(s) > 0;
497 });
498 }
499
addMiscellaneousSocket(int socket)500 void TCPStoreMasterDaemon::addMiscellaneousSocket(int socket) {
501 if (miscellaneousSockets_.find(socket) == miscellaneousSockets_.end()) {
502 miscellaneousSockets_.insert(socket);
503 }
504 }
505
removeMiscellaneousSocket(int socket)506 void TCPStoreMasterDaemon::removeMiscellaneousSocket(int socket) {
507 auto it = miscellaneousSockets_.find(socket);
508 if (it != miscellaneousSockets_.end()) {
509 miscellaneousSockets_.erase(it);
510 }
511 }
512
isMiscellaneousSocket(int socket)513 bool TCPStoreMasterDaemon::isMiscellaneousSocket(int socket) {
514 return miscellaneousSockets_.find(socket) != miscellaneousSockets_.end();
515 }
516
517 #ifdef _WIN32
run()518 void TCPStoreMasterDaemon::run() {
519 std::vector<struct pollfd> fds;
520 tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
521
522 // receive the queries
523 bool finished = false;
524 while (!finished) {
525 for (const auto i : c10::irange(sockets_.size())) {
526 fds[i].revents = 0;
527 }
528
529 int res;
530 SYSCHECK_ERR_RETURN_NEG1(
531 res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
532 if (res == 0) {
533 auto rv = WaitForSingleObject(ghStopEvent_, 0);
534 if (rv != WAIT_TIMEOUT) {
535 finished = true;
536 break;
537 }
538 continue;
539 }
540
541 // TCPStore's listening socket has an event and it should now be able to
542 // accept new connections.
543 if (fds[0].revents != 0) {
544 if (!(fds[0].revents & POLLIN)) {
545 C10_THROW_ERROR(
546 DistStoreError,
547 "Unexpected poll revent on the master's listening socket: " +
548 std::to_string(fds[0].revents));
549 }
550 Socket socket = storeListenSocket_.accept();
551 int rawSocket = socket.handle();
552 sockets_.emplace_back(std::move(socket));
553 tcputil::addPollfd(fds, rawSocket, POLLIN);
554 addMiscellaneousSocket(rawSocket);
555 }
556 queryFds(fds);
557 }
558 }
559 #else
run()560 void TCPStoreMasterDaemon::run() {
561 try {
562 c10::setThreadName("pt_tcpstore");
563
564 std::vector<struct pollfd> fds;
565 tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
566 // Although we haven't found any documentation or literature describing
567 // this, we've seen cases that, under certain circumstances, the read end of
568 // the pipe won't receive POLLHUP when the write end is closed. However,
569 // under the same circumstances, writing to the pipe will guarantee POLLIN
570 // to be received on the read end.
571 //
572 // For more reliable termination, the main thread will write a byte to the
573 // pipe before closing it, and the background thread will poll for both
574 // POLLIN and POLLHUP.
575 tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
576
577 // receive the queries
578 bool finished = false;
579 while (!finished) {
580 for (const auto i : c10::irange(sockets_.size())) {
581 fds[i].revents = 0;
582 }
583
584 SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
585
586 // TCPStore's listening socket has an event and it should now be able to
587 // accept new connections.
588 if (fds[0].revents != 0) {
589 if (fds[0].revents ^ POLLIN) {
590 C10_THROW_ERROR(
591 DistStoreError,
592 "Unexpected poll revent on the master's listening socket: " +
593 std::to_string(fds[0].revents));
594 }
595 Socket socket = storeListenSocket_.accept();
596 int rawSocket = socket.handle();
597 sockets_.emplace_back(std::move(socket));
598 tcputil::addPollfd(fds, rawSocket, POLLIN);
599 // all clients are miscellaneous before getting its validation query
600 addMiscellaneousSocket(rawSocket);
601 }
602
603 // The pipe receives an event which tells us to shutdown the daemon
604 if (fds[1].revents != 0) {
605 // The main thread will write a byte to the pipe then close it before
606 // joining the background thread
607 if (fds[1].revents & ~(POLLIN | POLLHUP)) {
608 C10_THROW_ERROR(
609 DistStoreError,
610 "Unexpected poll revent on the control pipe's reading fd: " +
611 std::to_string(fds[1].revents));
612 }
613 finished = true;
614 break;
615 }
616 queryFds(fds);
617 }
618 } catch (const std::exception& ex) {
619 C10D_ERROR(
620 "TCPStoreMasterDaemon::run() failed with exception: ", ex.what());
621 throw;
622 } catch (...) {
623 C10D_ERROR("TCPStoreMasterDaemon::run() failed with unknown exception");
624 throw;
625 }
626 }
627 #endif
628
create_tcpstore_backend(const TCPStoreOptions & opts)629 std::unique_ptr<BackgroundThread> create_tcpstore_backend(
630 const TCPStoreOptions& opts) {
631 Socket socket = opts.masterListenFd.has_value()
632 ? Socket::listenFromFd(*opts.masterListenFd, opts.port)
633 : Socket::listen(opts.port);
634
635 return std::make_unique<TCPStoreMasterDaemon>(std::move(socket));
636 }
637
638 } // namespace c10d::detail
639