xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/TCPStoreBackend.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 #include <c10/util/irange.h>
3 #include <fcntl.h>
4 #include <algorithm>
5 #include <array>
6 #include <system_error>
7 #include <unordered_map>
8 #include <utility>
9 
10 #ifdef _WIN32
11 #include <io.h>
12 #include <winsock2.h>
13 #else
14 #include <poll.h>
15 #include <unistd.h>
16 #endif
17 
18 #include <c10/util/thread_name.h>
19 #include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
20 #include <torch/csrc/distributed/c10d/logging.h>
21 
22 #ifdef _WIN32
23 #include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
24 #else
25 #include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
26 #endif
27 
28 #include <torch/csrc/distributed/c10d/socket.h>
29 
30 namespace c10d::detail {
31 
// Background thread parent class methods
// Both special members are defaulted. In particular the destructor does not
// stop or join the daemon thread: a subclass must call dispose() in its own
// destructor (destroying a still-joinable std::thread calls std::terminate).
BackgroundThread::BackgroundThread() = default;

BackgroundThread::~BackgroundThread() = default;
36 
37 // WARNING:
38 // Since we rely on the subclass for the daemon thread clean-up, we cannot
39 // destruct our member variables in the destructor. The subclass must call
40 // dispose() in its own destructor.
dispose()41 void BackgroundThread::dispose() {
42   // Stop the run
43   stop();
44   // Join the thread
45   daemonThread_.join();
46 }
47 
start()48 void BackgroundThread::start() {
49   daemonThread_ = std::thread{&BackgroundThread::run, this};
50   is_running_.store(true);
51 }
52 
53 // Separate thread that is only launched on master
54 class TCPStoreMasterDaemon : public BackgroundThread {
55  public:
56   explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);
57 
58   ~TCPStoreMasterDaemon() override;
59 
60   uint16_t port() const override;
61 
62  protected:
63   void run() override;
64   void stop() override;
65 
66  private:
67   void initStopSignal();
68   void closeStopSignal();
69 
70   void queryFds(std::vector<struct pollfd>& fds);
71   void query(int socket);
72 
73   void clearSocketWaitState(int socket);
74 
75   // The master runs on a single thread so only
76   // one handler can be executed at a time
77   void validateHandler(int socket);
78   void pingHandler(int socket);
79   void setHandler(int socket);
80   void compareSetHandler(int socket);
81   void addHandler(int socket);
82   void getHandler(int socket) const;
83   void checkHandler(int socket) const;
84   void getNumKeysHandler(int socket) const;
85   void deleteHandler(int socket);
86   void waitHandler(int socket);
87   void appendHandler(int socket);
88   void multiGetHandler(int socket);
89   void multiSetHandler(int socket);
90   void cancelWaitHandler(int socket);
91   void addMiscellaneousSocket(int socket);
92   void removeMiscellaneousSocket(int socket);
93   bool isMiscellaneousSocket(int socket);
94 
95   bool checkKeys(const std::vector<std::string>& keys) const;
96   // Helper function to alerts waiting workers, used in setHandler, getHandler
97   void wakeupWaitingClients(const std::string& key);
98   void doSet(const std::string& key, const std::vector<uint8_t>& newData);
99 
100   std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
101   // From key -> the list of sockets waiting on the key
102   std::unordered_map<std::string, std::vector<int>> waitingSockets_;
103   // From socket -> number of keys awaited
104   std::unordered_map<int, size_t> keysAwaited_;
105   // miscellaneous sockets
106   std::unordered_set<int> miscellaneousSockets_;
107 
108   Socket storeListenSocket_;
109   std::vector<Socket> sockets_{};
110 #ifdef _WIN32
111   const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10};
112   HANDLE ghStopEvent_{};
113 #else
114   std::array<int, 2> controlPipeFd_{{-1, -1}};
115 #endif
116 };
117 
118 // Simply start the daemon thread
TCPStoreMasterDaemon(Socket && storeListenSocket)119 TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
120     : storeListenSocket_{std::move(storeListenSocket)} {
121   initStopSignal();
122 }
123 
// Tears the daemon down in a fixed order that must not be changed:
//   1. dispose(): signal stop() and join the daemon thread, so nothing else
//      touches the members below any more;
//   2. destroy every accepted client Socket;
//   3. release the stop-signal resources (event handle / pipe descriptors),
//      which the daemon thread was polling until step 1 finished.
TCPStoreMasterDaemon::~TCPStoreMasterDaemon() {
  dispose();
  // it's now safe for us to cleanup
  // Close unclosed sockets
  sockets_.clear();
  // Now close the rest control pipe
  closeStopSignal();
}
132 
port() const133 std::uint16_t TCPStoreMasterDaemon::port() const {
134   return storeListenSocket_.port();
135 }
136 
137 #ifdef _WIN32
initStopSignal()138 void TCPStoreMasterDaemon::initStopSignal() {
139   ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
140   if (ghStopEvent_ == NULL) {
141     TORCH_CHECK(
142         false,
143         "Failed to create the control pipe to start the "
144         "BackgroundThread run");
145   }
146 }
147 
closeStopSignal()148 void TCPStoreMasterDaemon::closeStopSignal() {
149   CloseHandle(ghStopEvent_);
150 }
151 
stop()152 void TCPStoreMasterDaemon::stop() {
153   SetEvent(ghStopEvent_);
154 }
155 
156 #else
initStopSignal()157 void TCPStoreMasterDaemon::initStopSignal() {
158   if (pipe(controlPipeFd_.data()) == -1) {
159     TORCH_CHECK(
160         false,
161         "Failed to create the control pipe to start the "
162         "BackgroundThread run");
163   }
164 }
165 
closeStopSignal()166 void TCPStoreMasterDaemon::closeStopSignal() {
167   for (int fd : controlPipeFd_) {
168     if (fd != -1) {
169       ::close(fd);
170     }
171   }
172 }
173 
stop()174 void TCPStoreMasterDaemon::stop() {
175   if (controlPipeFd_[1] != -1) {
176     ssize_t written_bytes = -1;
177     while (true) {
178       written_bytes = ::write(controlPipeFd_[1], "\0", 1);
179       if (written_bytes < 0) {
180         if (errno == EAGAIN) {
181           continue;
182         }
183         TORCH_CHECK(false, "Failed to write the control pipe:", errno);
184       }
185       break;
186     }
187     if (written_bytes == 0) {
188       TORCH_CHECK(false, "Failed to write the control pipe");
189     }
190 
191     // close the write end of the pipe
192     ::close(controlPipeFd_[1]);
193     controlPipeFd_[1] = -1;
194   }
195 }
196 #endif
197 
queryFds(std::vector<struct pollfd> & fds)198 void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
199   // Skipping the fds[0] and fds[1],
200   // fds[0] is master's listening socket
201   // fds[1] is control pipe's reading fd, it is not for Windows platform
202   for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
203     if (fds[fdIdx].revents == 0) {
204       continue;
205     }
206 
207     // Now query the socket that has the event
208     try {
209       query(fds[fdIdx].fd);
210     } catch (...) {
211       // There was an error when processing query. Probably an exception
212       // occurred in recv/send what would indicate that socket on the other
213       // side has been closed. If the closing was due to normal exit, then
214       // the store should continue executing. Otherwise, if it was different
215       // exception, other connections will get an exception once they try to
216       // use the store. We will go ahead and close this connection whenever
217       // we hit an exception here.
218       clearSocketWaitState(fds[fdIdx].fd);
219 
220       fds.erase(fds.begin() + fdIdx);
221       sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
222       --fdIdx;
223       continue;
224     }
225   }
226 }
227 
clearSocketWaitState(int socket)228 void TCPStoreMasterDaemon::clearSocketWaitState(int socket) {
229   // Remove all the tracking state of the close FD
230   for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
231     for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
232       if (*vecIt == socket) {
233         vecIt = it->second.erase(vecIt);
234       } else {
235         ++vecIt;
236       }
237     }
238     if (it->second.empty()) {
239       it = waitingSockets_.erase(it);
240     } else {
241       ++it;
242     }
243   }
244   for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
245     if (it->first == socket) {
246       it = keysAwaited_.erase(it);
247     } else {
248       ++it;
249     }
250   }
251 }
252 
253 // query communicates with the worker. The format
254 // of the query is as follows:
255 // type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
256 // or, in the case of wait
257 // type of query | number of args | size of arg1 | arg1 | ...
query(int socket)258 void TCPStoreMasterDaemon::query(int socket) {
259   QueryType qt;
260   tcputil::recvBytes<QueryType>(socket, &qt, 1);
261 
262   if (isMiscellaneousSocket(socket)) {
263     removeMiscellaneousSocket(socket);
264     if (qt == QueryType::VALIDATE) {
265       validateHandler(socket);
266     } else {
267       // real miscellaneous client: the first msg is not VALIDATE
268       TORCH_CHECK(
269           false, "Miscellaneous client without VALIDATE query is detected");
270     }
271 
272   } else if (qt == QueryType::PING) {
273     pingHandler(socket);
274 
275   } else if (qt == QueryType::SET) {
276     setHandler(socket);
277 
278   } else if (qt == QueryType::COMPARE_SET) {
279     compareSetHandler(socket);
280 
281   } else if (qt == QueryType::ADD) {
282     addHandler(socket);
283 
284   } else if (qt == QueryType::GET) {
285     getHandler(socket);
286 
287   } else if (qt == QueryType::CHECK) {
288     checkHandler(socket);
289 
290   } else if (qt == QueryType::WAIT) {
291     waitHandler(socket);
292 
293   } else if (qt == QueryType::GETNUMKEYS) {
294     getNumKeysHandler(socket);
295 
296   } else if (qt == QueryType::DELETE_KEY) {
297     deleteHandler(socket);
298   } else if (qt == QueryType::APPEND) {
299     appendHandler(socket);
300   } else if (qt == QueryType::MULTI_GET) {
301     multiGetHandler(socket);
302   } else if (qt == QueryType::MULTI_SET) {
303     multiSetHandler(socket);
304   } else if (qt == QueryType::CANCEL_WAIT) {
305     cancelWaitHandler(socket);
306   } else {
307     TORCH_CHECK(false, "Unexpected query type");
308   }
309 }
310 
wakeupWaitingClients(const std::string & key)311 void TCPStoreMasterDaemon::wakeupWaitingClients(const std::string& key) {
312   auto socketsToWait = waitingSockets_.find(key);
313   if (socketsToWait != waitingSockets_.end()) {
314     for (int socket : socketsToWait->second) {
315       if (--keysAwaited_[socket] == 0) {
316         tcputil::sendValue<WaitResponseType>(
317             socket, WaitResponseType::STOP_WAITING);
318       }
319     }
320     waitingSockets_.erase(socketsToWait);
321   }
322 }
323 
doSet(const std::string & key,const std::vector<uint8_t> & newData)324 void TCPStoreMasterDaemon::doSet(
325     const std::string& key,
326     const std::vector<uint8_t>& newData) {
327   tcpStore_[key] = newData;
328   // On "set", wake up all clients that have been waiting
329   wakeupWaitingClients(key);
330 }
331 
validateHandler(int socket)332 void TCPStoreMasterDaemon::validateHandler(int socket) {
333   uint32_t validateNumber = 0;
334   tcputil::recvBytes<uint32_t>(socket, &validateNumber, 1);
335   if (validateNumber != detail::validationMagicNumber) {
336     TORCH_CHECK(
337         false,
338         "Miscellaneous client with incorrect VALIDATE query is detected");
339   }
340 }
341 
pingHandler(int socket)342 void TCPStoreMasterDaemon::pingHandler(int socket) {
343   uint32_t nonce = 0;
344   tcputil::recvBytes<uint32_t>(socket, &nonce, 1);
345   tcputil::sendValue<uint32_t>(socket, nonce);
346 }
347 
setHandler(int socket)348 void TCPStoreMasterDaemon::setHandler(int socket) {
349   std::string key = tcputil::recvString(socket);
350   std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
351   doSet(key, newData);
352 }
353 
compareSetHandler(int socket)354 void TCPStoreMasterDaemon::compareSetHandler(int socket) {
355   std::string key = tcputil::recvString(socket);
356   std::vector<uint8_t> currentValue = tcputil::recvVector<uint8_t>(socket);
357   std::vector<uint8_t> newValue = tcputil::recvVector<uint8_t>(socket);
358 
359   auto pos = tcpStore_.find(key);
360   if (pos == tcpStore_.end()) {
361     if (currentValue.empty()) {
362       tcpStore_[key] = newValue;
363       tcputil::sendVector<uint8_t>(socket, newValue);
364     } else {
365       // TODO: This code path is not ideal as we are "lying" to the caller in
366       // case the key does not exist. We should come up with a working solution.
367       tcputil::sendVector<uint8_t>(socket, currentValue);
368     }
369   } else {
370     if (pos->second == currentValue) {
371       pos->second = std::move(newValue);
372     }
373     tcputil::sendVector<uint8_t>(socket, pos->second);
374   }
375 }
376 
addHandler(int socket)377 void TCPStoreMasterDaemon::addHandler(int socket) {
378   std::string key = tcputil::recvString(socket);
379   int64_t addVal = tcputil::recvValue<int64_t>(socket);
380 
381   auto it = tcpStore_.find(key);
382   if (it != tcpStore_.end()) {
383     auto buf = reinterpret_cast<const char*>(it->second.data());
384     auto len = it->second.size();
385     addVal += std::stoll(std::string(buf, len));
386   }
387   auto addValStr = std::to_string(addVal);
388   std::vector<uint8_t> newData =
389       std::vector<uint8_t>(addValStr.begin(), addValStr.end());
390   tcpStore_[key] = newData;
391   // Now send the new value
392   tcputil::sendValue<int64_t>(socket, addVal);
393   // On "add", wake up all clients that have been waiting
394   wakeupWaitingClients(key);
395 }
396 
getHandler(int socket) const397 void TCPStoreMasterDaemon::getHandler(int socket) const {
398   std::string key = tcputil::recvString(socket);
399   auto data = tcpStore_.at(key);
400   tcputil::sendVector<uint8_t>(socket, data);
401 }
402 
getNumKeysHandler(int socket) const403 void TCPStoreMasterDaemon::getNumKeysHandler(int socket) const {
404   tcputil::sendValue<int64_t>(socket, tcpStore_.size());
405 }
406 
deleteHandler(int socket)407 void TCPStoreMasterDaemon::deleteHandler(int socket) {
408   std::string key = tcputil::recvString(socket);
409   auto numDeleted = tcpStore_.erase(key);
410   tcputil::sendValue<int64_t>(socket, numDeleted);
411 }
412 
checkHandler(int socket) const413 void TCPStoreMasterDaemon::checkHandler(int socket) const {
414   SizeType nargs = 0;
415   tcputil::recvBytes<SizeType>(socket, &nargs, 1);
416   std::vector<std::string> keys(nargs);
417   for (const auto i : c10::irange(nargs)) {
418     keys[i] = tcputil::recvString(socket);
419   }
420   // Now we have received all the keys
421   if (checkKeys(keys)) {
422     tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
423   } else {
424     tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
425   }
426 }
427 
waitHandler(int socket)428 void TCPStoreMasterDaemon::waitHandler(int socket) {
429   SizeType nargs = 0;
430   tcputil::recvBytes<SizeType>(socket, &nargs, 1);
431   std::vector<std::string> keys(nargs);
432   for (const auto i : c10::irange(nargs)) {
433     keys[i] = tcputil::recvString(socket);
434   }
435   if (checkKeys(keys)) {
436     tcputil::sendValue<WaitResponseType>(
437         socket, WaitResponseType::STOP_WAITING);
438   } else {
439     int numKeysToAwait = 0;
440     for (auto& key : keys) {
441       // Only count keys that have not already been set
442       if (tcpStore_.find(key) == tcpStore_.end()) {
443         waitingSockets_[key].push_back(socket);
444         numKeysToAwait++;
445       }
446     }
447     keysAwaited_[socket] = numKeysToAwait;
448   }
449 }
450 
appendHandler(int socket)451 void TCPStoreMasterDaemon::appendHandler(int socket) {
452   std::string key = tcputil::recvString(socket);
453   std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
454   auto it = tcpStore_.find(key);
455   if (it != tcpStore_.end()) {
456     it->second.insert(it->second.end(), newData.begin(), newData.end());
457   } else {
458     tcpStore_[key] = newData;
459   }
460   // we should not have clients waiting if we're appending, so it's all fine
461   wakeupWaitingClients(key);
462 }
463 
multiGetHandler(int socket)464 void TCPStoreMasterDaemon::multiGetHandler(int socket) {
465   SizeType nargs = 0;
466   tcputil::recvBytes<SizeType>(socket, &nargs, 1);
467   for (const auto i : c10::irange(nargs)) {
468     auto key = tcputil::recvString(socket);
469     auto& data = tcpStore_.at(key);
470     tcputil::sendVector<uint8_t>(socket, data, i < (nargs - 1));
471   }
472 }
473 
multiSetHandler(int socket)474 void TCPStoreMasterDaemon::multiSetHandler(int socket) {
475   SizeType nargs = 0;
476   tcputil::recvBytes<SizeType>(socket, &nargs, 1);
477   for (auto _ : c10::irange(nargs)) {
478     (void)_; // Suppress unused variable warning
479     auto key = tcputil::recvString(socket);
480     auto value = tcputil::recvVector<uint8_t>(socket);
481     doSet(key, value);
482   }
483 }
484 
cancelWaitHandler(int socket)485 void TCPStoreMasterDaemon::cancelWaitHandler(int socket) {
486   clearSocketWaitState(socket);
487 
488   // Send update to TCPStoreWorkerDaemon on client
489   tcputil::sendValue<WaitResponseType>(
490       socket, detail::WaitResponseType::WAIT_CANCELED);
491 }
492 
checkKeys(const std::vector<std::string> & keys) const493 bool TCPStoreMasterDaemon::checkKeys(
494     const std::vector<std::string>& keys) const {
495   return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
496     return tcpStore_.count(s) > 0;
497   });
498 }
499 
addMiscellaneousSocket(int socket)500 void TCPStoreMasterDaemon::addMiscellaneousSocket(int socket) {
501   if (miscellaneousSockets_.find(socket) == miscellaneousSockets_.end()) {
502     miscellaneousSockets_.insert(socket);
503   }
504 }
505 
removeMiscellaneousSocket(int socket)506 void TCPStoreMasterDaemon::removeMiscellaneousSocket(int socket) {
507   auto it = miscellaneousSockets_.find(socket);
508   if (it != miscellaneousSockets_.end()) {
509     miscellaneousSockets_.erase(it);
510   }
511 }
512 
isMiscellaneousSocket(int socket)513 bool TCPStoreMasterDaemon::isMiscellaneousSocket(int socket) {
514   return miscellaneousSockets_.find(socket) != miscellaneousSockets_.end();
515 }
516 
517 #ifdef _WIN32
run()518 void TCPStoreMasterDaemon::run() {
519   std::vector<struct pollfd> fds;
520   tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
521 
522   // receive the queries
523   bool finished = false;
524   while (!finished) {
525     for (const auto i : c10::irange(sockets_.size())) {
526       fds[i].revents = 0;
527     }
528 
529     int res;
530     SYSCHECK_ERR_RETURN_NEG1(
531         res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
532     if (res == 0) {
533       auto rv = WaitForSingleObject(ghStopEvent_, 0);
534       if (rv != WAIT_TIMEOUT) {
535         finished = true;
536         break;
537       }
538       continue;
539     }
540 
541     // TCPStore's listening socket has an event and it should now be able to
542     // accept new connections.
543     if (fds[0].revents != 0) {
544       if (!(fds[0].revents & POLLIN)) {
545         C10_THROW_ERROR(
546             DistStoreError,
547             "Unexpected poll revent on the master's listening socket: " +
548                 std::to_string(fds[0].revents));
549       }
550       Socket socket = storeListenSocket_.accept();
551       int rawSocket = socket.handle();
552       sockets_.emplace_back(std::move(socket));
553       tcputil::addPollfd(fds, rawSocket, POLLIN);
554       addMiscellaneousSocket(rawSocket);
555     }
556     queryFds(fds);
557   }
558 }
559 #else
run()560 void TCPStoreMasterDaemon::run() {
561   try {
562     c10::setThreadName("pt_tcpstore");
563 
564     std::vector<struct pollfd> fds;
565     tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
566     // Although we haven't found any documentation or literature describing
567     // this, we've seen cases that, under certain circumstances, the read end of
568     // the pipe won't receive POLLHUP when the write end is closed. However,
569     // under the same circumstances, writing to the pipe will guarantee POLLIN
570     // to be received on the read end.
571     //
572     // For more reliable termination, the main thread will write a byte to the
573     // pipe before closing it, and the background thread will poll for both
574     // POLLIN and POLLHUP.
575     tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
576 
577     // receive the queries
578     bool finished = false;
579     while (!finished) {
580       for (const auto i : c10::irange(sockets_.size())) {
581         fds[i].revents = 0;
582       }
583 
584       SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
585 
586       // TCPStore's listening socket has an event and it should now be able to
587       // accept new connections.
588       if (fds[0].revents != 0) {
589         if (fds[0].revents ^ POLLIN) {
590           C10_THROW_ERROR(
591               DistStoreError,
592               "Unexpected poll revent on the master's listening socket: " +
593                   std::to_string(fds[0].revents));
594         }
595         Socket socket = storeListenSocket_.accept();
596         int rawSocket = socket.handle();
597         sockets_.emplace_back(std::move(socket));
598         tcputil::addPollfd(fds, rawSocket, POLLIN);
599         // all clients are miscellaneous before getting its validation query
600         addMiscellaneousSocket(rawSocket);
601       }
602 
603       // The pipe receives an event which tells us to shutdown the daemon
604       if (fds[1].revents != 0) {
605         // The main thread will write a byte to the pipe then close it before
606         // joining the background thread
607         if (fds[1].revents & ~(POLLIN | POLLHUP)) {
608           C10_THROW_ERROR(
609               DistStoreError,
610               "Unexpected poll revent on the control pipe's reading fd: " +
611                   std::to_string(fds[1].revents));
612         }
613         finished = true;
614         break;
615       }
616       queryFds(fds);
617     }
618   } catch (const std::exception& ex) {
619     C10D_ERROR(
620         "TCPStoreMasterDaemon::run() failed with exception: ", ex.what());
621     throw;
622   } catch (...) {
623     C10D_ERROR("TCPStoreMasterDaemon::run() failed with unknown exception");
624     throw;
625   }
626 }
627 #endif
628 
create_tcpstore_backend(const TCPStoreOptions & opts)629 std::unique_ptr<BackgroundThread> create_tcpstore_backend(
630     const TCPStoreOptions& opts) {
631   Socket socket = opts.masterListenFd.has_value()
632       ? Socket::listenFromFd(*opts.masterListenFd, opts.port)
633       : Socket::listen(opts.port);
634 
635   return std::make_unique<TCPStoreMasterDaemon>(std::move(socket));
636 }
637 
638 } // namespace c10d::detail
639