#include <c10/util/WaitCounter.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <torch/csrc/distributed/c10d/Backoff.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

#include <fcntl.h>
#include <chrono>
#include <fstream>
#include <random>
#include <thread>
#include <unordered_map>
#include <utility>

#ifdef _WIN32
#include <io.h>
#include <winsock2.h>
#else
#include <poll.h>
#include <unistd.h>
#endif

#ifdef _WIN32
#include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
#else
#include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
#endif

#include <torch/csrc/distributed/c10d/socket.h>

namespace c10d {
namespace detail {

// Manages the lifecycle of a server daemon.
class TCPServer {
 public:
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<BackgroundThread>&& daemon)
      : port_{port}, daemon_{std::move(daemon)} {}

  std::string repr() const {
    return fmt::format("TCPServer(port={})", port_);
  }

 private:
  std::uint16_t port_;
  std::unique_ptr<BackgroundThread> daemon_;

  // We store weak references to all TCPServers for which the caller requested
  // multi-tenancy.
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;

  static std::mutex cache_mutex_;
};

std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_{};

std::mutex TCPServer::cache_mutex_{};

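// Starts a new server daemon or, in multi-tenant mode, returns a shared
// handle to a server that is already listening on the requested port.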
std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
  auto startCore = [&opts]() {
    auto daemon = opts.useLibUV ? create_libuv_tcpstore_backend(opts)
                                : create_tcpstore_backend(opts);
    daemon->start();
    return std::make_shared<TCPServer>(daemon->port(), std::move(daemon));
  };

  std::shared_ptr<TCPServer> server{};

  if (opts.multiTenant) {
    std::lock_guard<std::mutex> guard{cache_mutex_};

    // If the caller is okay with a multi-tenant store, first check if we
    // already have a TCPServer running on the specified port.
    if (opts.port > 0) {
      auto pos = cachedServers_.find(opts.port);
      if (pos != cachedServers_.end()) {
        server = pos->second.lock();
        if (server != nullptr) {
          return server;
        }

        // Looks like the TCPStore has been disposed; make sure that we
        // release the control block.
        cachedServers_.erase(pos);
      }
    }

    server = startCore();

    cachedServers_.emplace(server->port(), server);
  } else {
    server = startCore();
  }

  return server;
}

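// Client-side endpoint of a store connection. Wraps the raw socket and
// provides typed send/receive helpers for the query protocol, logging a
// warning before re-throwing any transport error.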
class TCPClient {
 public:
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts,
      std::shared_ptr<Backoff> backoff);

  void sendRaw(uint8_t* data, size_t length) {
    try {
      tcputil::sendBytes(socket_.handle(), data, length);
    } catch (const std::exception& e) {
      C10D_WARNING("sendBytes failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  std::vector<std::uint8_t> receiveBits() {
    try {
      return tcputil::recvVector<std::uint8_t>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvVector failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  T receiveValue() {
    try {
      return tcputil::recvValue<T>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvValue failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  bool receiveValueWithTimeout(T& t, std::chrono::milliseconds timeout) {
    if (!socket_.waitForInput(timeout))
      return false;
    t = tcputil::recvValue<T>(socket_.handle());
    return true;
  }

  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}

  std::string repr() const {
    return fmt::format("TCPClient({})", socket_.repr());
  }

 private:
  Socket socket_;
};

std::unique_ptr<TCPClient> TCPClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts,
    std::shared_ptr<Backoff> backoff) {
  Socket socket = Socket::connect(
      addr.host,
      addr.port,
      SocketOptions{}
          .connect_timeout(opts.timeout)
          .connect_backoff(std::move(backoff)));

  return std::make_unique<TCPClient>(std::move(socket));
}

void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

#ifdef _WIN32
  struct timeval timeoutTV = {
      static_cast<long>(value.count() / 1000),
      static_cast<long>((value.count() % 1000) * 1000)};
#else
  struct timeval timeoutTV = {
      .tv_sec = value.count() / 1000,
      .tv_usec = static_cast<suseconds_t>((value.count() % 1000) * 1000),
  };
#endif
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}

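// Accumulates one outgoing query in a local buffer so that small appends
// coalesce into few sends; flushes eagerly once the buffer exceeds a typical
// TCP payload size.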
class SendBuffer {
  // Ethernet MTU 1500 - 40 (IPv6 header) - 20 (TCP header)
  const size_t FLUSH_WATERMARK = 1440;
  std::vector<uint8_t> buffer;
  detail::TCPClient& client;

  void maybeFlush() {
    if (buffer.size() >= FLUSH_WATERMARK) {
      flush();
    }
  }

 public:
  SendBuffer(detail::TCPClient& client, detail::QueryType cmd)
      : client(client) {
    buffer.reserve(32); // enough for most commands
    buffer.push_back((uint8_t)cmd);
  }

  void appendString(const std::string& str) {
    appendValue<uint64_t>(str.size());
    buffer.insert(buffer.end(), str.begin(), str.end());
    maybeFlush();
  }

  void appendBytes(const std::vector<uint8_t>& vec) {
    appendValue<uint64_t>(vec.size());
    buffer.insert(buffer.end(), vec.begin(), vec.end());
    maybeFlush();
  }

  template <typename T>
  void appendValue(T value) {
    uint8_t* begin = (uint8_t*)&value;
    buffer.insert(buffer.end(), begin, begin + sizeof(T));
    maybeFlush();
  }

  void flush() {
    if (!buffer.empty()) {
      client.sendRaw(buffer.data(), buffer.size());
      buffer.clear();
    }
  }
};

} // namespace detail

using detail::Socket;

// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    std::uint16_t masterPort,
    std::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : TCPStore{
          masterAddr,
          TCPStoreOptions{
              masterPort,
              isServer,
              numWorkers ? std::optional<std::size_t>(*numWorkers)
                         : std::nullopt,
              waitWorkers,
              timeout}} {}

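// A minimal usage sketch (hypothetical values; assumes a single process that
// hosts the server daemon and connects to it as its own client):
//
//   TCPStoreOptions opts;
//   opts.port = 0; // let the daemon pick a free port
//   opts.isServer = true;
//   opts.waitWorkers = false;
//   auto store = c10::make_intrusive<TCPStore>("127.0.0.1", opts);
//   store->set("key", std::vector<uint8_t>{1, 2, 3});
//   auto value = store->get("key"); // blocks until "key" exists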
TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
    : Store{opts.timeout},
      addr_{std::move(host)},
      numWorkers_{opts.numWorkers},
      usingLibUv_{opts.useLibUV} {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__init);

  if (opts.useLibUV) {
    TORCH_CHECK(
        ::c10d::detail::is_libuv_tcpstore_backend_available(),
        "use_libuv was requested but PyTorch was built without libuv support");

    if (opts.masterListenFd.has_value()) {
      // TODO(xilunwu): support this init method after testing
      constexpr auto* msg =
          "The libuv TCPStore backend does not support initialization with a listen fd. "
          "Please switch to the legacy TCPStore by setting environment variable USE_LIBUV "
          "to \"0\".";
      C10D_ERROR(msg);
      C10_THROW_ERROR(NotImplementedError, msg);
      return;
    }
  }

  Socket::initialize();

  if (opts.isServer) {
    server_ = detail::TCPServer::start(opts);
    // server successfully started
    C10D_DEBUG("The server has started on port = {}.", server_->port());

    std::ifstream maxconnFile("/proc/sys/net/core/somaxconn");
    if (maxconnFile.good() && numWorkers_.has_value()) {
      try {
        std::string str(
            (std::istreambuf_iterator<char>(maxconnFile)),
            std::istreambuf_iterator<char>());
        std::size_t somaxconn = std::stoll(str);
        if (somaxconn < *numWorkers_) {
          C10D_WARNING(
              "Starting store with {} workers but somaxconn is {}. "
              "This might cause instability during bootstrap, consider increasing it.",
              *numWorkers_,
              somaxconn);
        }
      } catch (std::logic_error& e) {
        C10D_INFO("failed to parse somaxconn proc file due to {}", e.what());
      }
    }

    addr_.port = server_->port();
  } else {
    addr_.port = opts.port;
  }

  // Try connecting several times -- if the server listen backlog is full it
  // may fail on the first send in validate.
  auto deadline = std::chrono::steady_clock::now() + opts.timeout;
  auto backoff = std::make_shared<ExponentialBackoffWithJitter>();

  auto retry = 0;
  do {
    try {
      client_ = detail::TCPClient::connect(addr_, opts, backoff);
      // TCP connection established
      C10D_DEBUG("TCP client connected to host {}:{}", addr_.host, addr_.port);

      // client's first query for validation
      validate();

      // ping to verify network connectivity
      ping();

      // success
      break;
    } catch (const c10::DistNetworkError& ex) {
      if (deadline < std::chrono::steady_clock::now()) {
        C10D_ERROR(
            "TCP client failed to connect/validate to host {}:{} - timed out (try={}, timeout={}ms): {}",
            addr_.host,
            addr_.port,
            retry,
            opts.timeout.count(),
            ex.what());
        throw;
      }

      auto delayDuration = backoff->nextBackoff();

      C10D_WARNING(
          "TCP client failed to connect/validate to host {}:{} - retrying (try={}, timeout={}ms, delay={}ms): {}",
          addr_.host,
          addr_.port,
          retry,
          opts.timeout.count(),
          delayDuration.count(),
          ex.what());

      std::this_thread::sleep_for(delayDuration);
      retry += 1;
    }
  } while (true);

  if (opts.waitWorkers) {
    waitForWorkers();
  }
}

TCPStore::~TCPStore() = default;

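// Bootstrap barrier: every participant increments the counter under initKey_,
// and the server polls it until all numWorkers_ workers have checked in or
// the store timeout expires.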
void TCPStore::waitForWorkers() {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers);
  if (numWorkers_ == std::nullopt) {
    return;
  }

  incrementValueBy(initKey_, 1);

  // Let the server block until all workers have completed; this ensures that
  // the server daemon thread is always running until the very end.
  if (server_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      // TODO: Any chance to make this cleaner?
      std::vector<uint8_t> value = doGet(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= static_cast<int>(*numWorkers_)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        C10_THROW_ERROR(
            DistStoreError,
            fmt::format(
                "Timed out after {} seconds waiting for clients. {}/{} clients joined.",
                elapsed.count(),
                numWorkersCompleted,
                *numWorkers_));
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}

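// First query on a fresh connection: sends a magic number so the server can
// reject clients speaking an incompatible protocol.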
void TCPStore::validate() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::VALIDATE);
  buffer.appendValue<std::uint32_t>(c10d::detail::validationMagicNumber);
  buffer.flush();
}

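// Round-trip connectivity check: sends a PID-derived nonce and asserts that
// the server echoes it back.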
void TCPStore::ping() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::PING);

  uint32_t nonce = getpid();
  buffer.appendValue<std::uint32_t>(nonce);
  buffer.flush();

  uint32_t returnedNonce = client_->receiveValue<std::uint32_t>();
  TORCH_INTERNAL_ASSERT(
      nonce == returnedNonce, "Ping failed, invalid nonce returned");
}

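// Variant of set() that deliberately splits the query across two sends with a
// one-second pause in between, presumably to exercise the server's handling
// of partial reads.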
void TCPStore::_splitSet(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  buffer.appendBytes(data);
  buffer.flush();
}

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__set);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<uint8_t> TCPStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__compareSet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::COMPARE_SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(expectedValue);
  buffer.appendBytes(desiredValue);
  buffer.flush();

  return client_->receiveBits();
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__get);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return doGet(keyPrefix_ + key);
}

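// Note that a GET blocks until the key exists: doGet() first issues a WAIT
// for the key and only then fetches its value.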
std::vector<uint8_t> TCPStore::doGet(const std::string& key) {
  doWait(key, timeout_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GET);
  buffer.appendString(key);
  buffer.flush();

  return client_->receiveBits();
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__add);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return incrementValueBy(keyPrefix_ + key, value);
}

bool TCPStore::deleteKey(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__delete);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::DELETE_KEY);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();

  auto numDeleted = client_->receiveValue<std::int64_t>();
  return numDeleted == 1;
}

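// Shared helper behind add() and the waitForWorkers() barrier: sends an ADD
// query and returns the post-increment value of the counter stored at key.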
int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) {
  detail::SendBuffer buff(*client_, detail::QueryType::ADD);
  buff.appendString(key);
  buff.appendValue<std::int64_t>(delta);
  buff.flush();

  return client_->receiveValue<std::int64_t>();
}

int64_t TCPStore::getNumKeys() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GETNUMKEYS);
  buffer.flush();

  return client_->receiveValue<std::int64_t>();
}

bool TCPStore::check(const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__check);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::CHECK);
  buffer.appendValue(keys.size());

  for (const std::string& key : keys) {
    buffer.appendString(keyPrefix_ + key);
  }
  buffer.flush();

  auto response = client_->receiveValue<detail::CheckResponseType>();
  if (response == detail::CheckResponseType::READY) {
    return true;
  }
  if (response == detail::CheckResponseType::NOT_READY) {
    return false;
  }
  TORCH_CHECK(false, "ready or not_ready response expected");
}

void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}

void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__wait);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys{};
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }

  doWait(prefixedKeys, timeout);
}

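// Wait protocol: send WAIT and block (up to `timeout`) for STOP_WAITING. On
// timeout, send CANCEL_WAIT and expect WAIT_CANCELED; if a STOP_WAITING was
// already in flight, drain both responses before raising the timeout error.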
void TCPStore::doWait(
    c10::ArrayRef<std::string> keys,
    std::chrono::milliseconds timeout) {
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::WAIT);
    buffer.appendValue(keys.size());
    for (const std::string& key : keys) {
      buffer.appendString(key);
    }
    buffer.flush();
  }

  detail::WaitResponseType response;
  if (client_->receiveValueWithTimeout<detail::WaitResponseType>(
          response, timeout)) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }
    return;
  }

  // The wait timed out, so cancel it; once here we expect the server to
  // respond in a timely fashion.
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::CANCEL_WAIT);
    buffer.flush();
  }

  response = client_->receiveValue<detail::WaitResponseType>();
  // A STOP_WAITING here can happen if the server responds before we cancel;
  // just ignore it and drain the trailing WAIT_CANCELED.
  if (response != detail::WaitResponseType::WAIT_CANCELED) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }

    response = client_->receiveValue<detail::WaitResponseType>(); // ignore
    if (response != detail::WaitResponseType::WAIT_CANCELED) {
      TORCH_CHECK(false, "wait_canceled response is expected");
    }
  }
  C10_THROW_ERROR(
      DistStoreError,
      fmt::format(
          "wait timeout after {}ms, keys: {}",
          timeout.count(),
          fmt::join(keys, ", ")));
}

void TCPStore::append(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__append);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::APPEND);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<std::vector<uint8_t>> TCPStore::multiGet(
    const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiGet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }
  doWait(prefixedKeys, timeout_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_GET);
  buffer.appendValue(keys.size());
  for (auto& key : prefixedKeys) {
    buffer.appendString(key);
  }
  buffer.flush();

  std::vector<std::vector<uint8_t>> result;
  result.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    result.emplace_back(client_->receiveBits());
  }
  return result;
}

void TCPStore::multiSet(
    const std::vector<std::string>& keys,
    const std::vector<std::vector<uint8_t>>& values) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiSet);
  TORCH_CHECK(
      keys.size() == values.size(),
      "multiSet keys and values vectors must be of same size");
  const std::lock_guard<std::mutex> lock(activeOpLock_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_SET);
  buffer.appendValue<std::int64_t>(keys.size());
  for (auto i : c10::irange(keys.size())) {
    buffer.appendString(keyPrefix_ + keys[i]);
    buffer.appendBytes(values[i]);
  }
  buffer.flush();
}

bool TCPStore::hasExtendedApi() const {
  return true;
}

std::string TCPStore::repr() const {
  auto clientRepr = client_ ? client_->repr() : "<nullptr>";
  auto serverRepr = server_ ? server_->repr() : "<nullptr>";
  return fmt::format("TCPStore(client={}, server={})", clientRepr, serverRepr);
}

} // namespace c10d