#include <c10/util/WaitCounter.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <torch/csrc/distributed/c10d/Backoff.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

#include <fcntl.h>
#include <chrono>
#include <fstream>
#include <random>
#include <thread>
#include <unordered_map>
#include <utility>

#ifdef _WIN32
#include <io.h>
#include <winsock2.h>
#else
#include <poll.h>
#include <unistd.h>
#endif

#ifdef _WIN32
#include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
#else
#include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
#endif

#include <torch/csrc/distributed/c10d/socket.h>

namespace c10d {
namespace detail {

// Manages the lifecycle of a server daemon.
class TCPServer {
 public:
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<BackgroundThread>&& daemon)
      : port_{port}, daemon_{std::move(daemon)} {}

  std::string repr() const {
    return fmt::format("TCPServer(port={})", port_);
  }

 private:
  std::uint16_t port_;
  std::unique_ptr<BackgroundThread> daemon_;

  // We store weak references to all TCPServers for which the caller requested
  // multi-tenancy.
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;

  static std::mutex cache_mutex_;
};

std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_{};

std::mutex TCPServer::cache_mutex_{};

std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
  auto startCore = [&opts]() {
    auto daemon = opts.useLibUV ? create_libuv_tcpstore_backend(opts)
                                : create_tcpstore_backend(opts);
    daemon->start();
    return std::make_shared<TCPServer>(daemon->port(), std::move(daemon));
  };

  std::shared_ptr<TCPServer> server{};

  if (opts.multiTenant) {
    std::lock_guard<std::mutex> guard{cache_mutex_};

    // If the caller is okay with a multi-tenant store, first check if we
    // already have a TCPServer running on the specified port.
    if (opts.port > 0) {
      auto pos = cachedServers_.find(opts.port);
      if (pos != cachedServers_.end()) {
        server = pos->second.lock();
        if (server != nullptr) {
          return server;
        }

        // Looks like the TCPStore has been disposed; make sure that we
        // release the control block.
        cachedServers_.erase(pos);
      }
    }

    server = startCore();

    cachedServers_.emplace(server->port(), server);
  } else {
    server = startCore();
  }

  return server;
}
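
// Usage sketch (illustrative, not part of this file): with multiTenant set,
// two TCPStore instances created in the same process for the same port share
// one TCPServer daemon through the cache above. All option fields shown are
// real TCPStoreOptions members referenced in this file; the port value is
// arbitrary.
//
//   TCPStoreOptions opts;
//   opts.port = 29500;
//   opts.isServer = true;
//   opts.multiTenant = true;
//   TCPStore first{"localhost", opts};
//   TCPStore second{"localhost", opts}; // reuses the cached TCPServer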

class TCPClient {
 public:
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts,
      std::shared_ptr<Backoff> backoff);

  void sendRaw(uint8_t* data, size_t length) {
    try {
      tcputil::sendBytes(socket_.handle(), data, length);
    } catch (const std::exception& e) {
      C10D_WARNING("sendBytes failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  std::vector<std::uint8_t> receiveBits() {
    try {
      return tcputil::recvVector<std::uint8_t>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvVector failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  T receiveValue() {
    try {
      return tcputil::recvValue<T>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvValue failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  bool receiveValueWithTimeout(T& t, std::chrono::milliseconds timeout) {
    if (!socket_.waitForInput(timeout)) {
      return false;
    }
    t = tcputil::recvValue<T>(socket_.handle());
    return true;
  }

  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}

  std::string repr() const {
    return fmt::format("TCPClient({})", socket_.repr());
  }

 private:
  Socket socket_;
};

std::unique_ptr<TCPClient> TCPClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts,
    std::shared_ptr<Backoff> backoff) {
  Socket socket = Socket::connect(
      addr.host,
      addr.port,
      SocketOptions{}
          .connect_timeout(opts.timeout)
          .connect_backoff(std::move(backoff)));

  return std::make_unique<TCPClient>(std::move(socket));
}
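
// Usage sketch (illustrative): retries and per-attempt delays happen inside
// Socket::connect, driven by the supplied Backoff policy, so callers only
// pass the policy in. The SocketAddress initializer below assumes its
// {host, port} field order.
//
//   auto backoff = std::make_shared<ExponentialBackoffWithJitter>();
//   SocketAddress addr{"127.0.0.1", 29500};
//   auto client = TCPClient::connect(addr, opts, backoff);
//   client->setTimeout(std::chrono::seconds(30));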

void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

#ifdef _WIN32
  struct timeval timeoutTV = {
      static_cast<long>(value.count() / 1000),
      static_cast<long>((value.count() % 1000) * 1000)};
#else
  struct timeval timeoutTV = {
      .tv_sec = value.count() / 1000,
      .tv_usec = static_cast<suseconds_t>((value.count() % 1000) * 1000),
  };
#endif
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}
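
// For example, a 2500ms timeout becomes timeval{2, 500000}: SO_RCVTIMEO takes
// whole seconds plus a microsecond remainder, so the leftover milliseconds
// are scaled by 1000.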

class SendBuffer {
  // Ethernet MTU 1500 - 40 (IPv6 header) - 20 (TCP header)
  const size_t FLUSH_WATERMARK = 1440;
  std::vector<uint8_t> buffer;
  detail::TCPClient& client;

  void maybeFlush() {
    if (buffer.size() >= FLUSH_WATERMARK) {
      flush();
    }
  }

 public:
  SendBuffer(detail::TCPClient& client, detail::QueryType cmd)
      : client(client) {
    buffer.reserve(32); // enough for most commands
    buffer.push_back(static_cast<uint8_t>(cmd));
  }

  void appendString(const std::string& str) {
    appendValue<uint64_t>(str.size());
    buffer.insert(buffer.end(), str.begin(), str.end());
    maybeFlush();
  }

  void appendBytes(const std::vector<uint8_t>& vec) {
    appendValue<uint64_t>(vec.size());
    buffer.insert(buffer.end(), vec.begin(), vec.end());
    maybeFlush();
  }

  template <typename T>
  void appendValue(T value) {
    auto* begin = reinterpret_cast<uint8_t*>(&value);
    buffer.insert(buffer.end(), begin, begin + sizeof(T));
    maybeFlush();
  }

  void flush() {
    if (!buffer.empty()) {
      client.sendRaw(buffer.data(), buffer.size());
      buffer.clear();
    }
  }
};
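
// Wire format produced by SendBuffer, as implied by the code above: a single
// command byte followed by each appended field, with strings and byte vectors
// length-prefixed by a uint64_t in host byte order. For example,
// SET("foo", {0xAB}) serializes as:
//
//   [cmd byte][0x03,0,0,0,0,0,0,0]['f','o','o'][0x01,0,0,0,0,0,0,0][0xAB]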

} // namespace detail

using detail::Socket;

// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    std::uint16_t masterPort,
    std::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : TCPStore{
          masterAddr,
          TCPStoreOptions{
              masterPort,
              isServer,
              numWorkers ? std::optional<std::size_t>(*numWorkers)
                         : std::nullopt,
              waitWorkers,
              timeout}} {}

TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
    : Store{opts.timeout},
      addr_{std::move(host)},
      numWorkers_{opts.numWorkers},
      usingLibUv_{opts.useLibUV} {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__init);

  if (opts.useLibUV) {
    TORCH_CHECK(
        ::c10d::detail::is_libuv_tcpstore_backend_available(),
        "use_libuv was requested but PyTorch was built without libuv support");

    if (opts.masterListenFd.has_value()) {
      // TODO(xilunwu): support this init method after testing
      constexpr auto* msg =
          "The libuv TCPStore backend does not support initialization with a listen fd. "
          "Please switch to the legacy TCPStore by setting environment variable USE_LIBUV "
          "to \"0\".";
      C10D_ERROR(msg);
      C10_THROW_ERROR(NotImplementedError, msg);
      return;
    }
  }

  Socket::initialize();

  if (opts.isServer) {
    server_ = detail::TCPServer::start(opts);
    // server successfully started
    C10D_DEBUG("The server has started on port = {}.", server_->port());

    std::ifstream maxconnFile("/proc/sys/net/core/somaxconn");
    if (maxconnFile.good() && numWorkers_.has_value()) {
      try {
        std::string str(
            (std::istreambuf_iterator<char>(maxconnFile)),
            std::istreambuf_iterator<char>());
        std::size_t somaxconn = std::stoll(str);
        if (somaxconn < *numWorkers_) {
          C10D_WARNING(
              "Starting store with {} workers but somaxconn is {}. "
              "This might cause instability during bootstrap; consider increasing it.",
              *numWorkers_,
              somaxconn);
        }
      } catch (const std::logic_error& e) {
        C10D_INFO("failed to parse somaxconn proc file due to {}", e.what());
      }
    }

    addr_.port = server_->port();
  } else {
    addr_.port = opts.port;
  }

  // Try connecting several times -- if the server's listen backlog is full,
  // the first send in validate() may fail.
  auto deadline = std::chrono::steady_clock::now() + opts.timeout;
  auto backoff = std::make_shared<ExponentialBackoffWithJitter>();

  auto retry = 0;
  do {
    try {
      client_ = detail::TCPClient::connect(addr_, opts, backoff);
      // TCP connection established
      C10D_DEBUG("TCP client connected to host {}:{}", addr_.host, addr_.port);

      // client's first query, used for validation
      validate();

      // ping to verify network connectivity
      ping();

      // success
      break;
    } catch (const c10::DistNetworkError& ex) {
      if (deadline < std::chrono::steady_clock::now()) {
        C10D_ERROR(
            "TCP client failed to connect/validate to host {}:{} - timed out (try={}, timeout={}ms): {}",
            addr_.host,
            addr_.port,
            retry,
            opts.timeout.count(),
            ex.what());
        throw;
      }

      auto delayDuration = backoff->nextBackoff();

      C10D_WARNING(
          "TCP client failed to connect/validate to host {}:{} - retrying (try={}, timeout={}ms, delay={}ms): {}",
          addr_.host,
          addr_.port,
          retry,
          opts.timeout.count(),
          delayDuration.count(),
          ex.what());

      std::this_thread::sleep_for(delayDuration);
      retry += 1;
    }
  } while (true);

  if (opts.waitWorkers) {
    waitForWorkers();
  }
}
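
// Usage sketch (illustrative): a typical server-side construction. Every
// field shown is a real TCPStoreOptions member used by the constructor above;
// the port value is arbitrary.
//
//   TCPStoreOptions opts;
//   opts.port = 29500;
//   opts.isServer = true;
//   opts.numWorkers = 4;
//   opts.waitWorkers = true; // blocks in waitForWorkers() until all 4 join
//   TCPStore store{"localhost", opts};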

TCPStore::~TCPStore() = default;

void TCPStore::waitForWorkers() {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers);
  if (numWorkers_ == std::nullopt) {
    return;
  }

  incrementValueBy(initKey_, 1);

  // Let the server block until all workers have completed; this ensures that
  // the server daemon thread is always running until the very end.
  if (server_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      // TODO: Any chance to make this cleaner?
      std::vector<uint8_t> value = doGet(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= static_cast<int>(*numWorkers_)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        C10_THROW_ERROR(
            DistStoreError,
            fmt::format(
                "Timed out after {} seconds waiting for clients. {}/{} clients joined.",
                elapsed.count(),
                numWorkersCompleted,
                *numWorkers_));
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}
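
// Worked example: with numWorkers_ = 4, every rank (the server included)
// bumps the initKey_ counter once via incrementValueBy(). The counter comes
// back as an ASCII decimal string -- hence the doGet() + std::stoi above --
// and the server polls every 10ms until the parsed value reaches 4.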

void TCPStore::validate() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::VALIDATE);
  buffer.appendValue<std::uint32_t>(c10d::detail::validationMagicNumber);
  buffer.flush();
}

void TCPStore::ping() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::PING);

  uint32_t nonce = getpid();
  buffer.appendValue<std::uint32_t>(nonce);
  buffer.flush();

  uint32_t returnedNonce = client_->receiveValue<std::uint32_t>();
  TORCH_INTERNAL_ASSERT(
      nonce == returnedNonce, "Ping failed, invalid nonce returned");
}

// Sends the key and the value in two separate flushes with a pause in
// between, which appears intended to exercise the server's handling of
// partially received SET commands.
void TCPStore::_splitSet(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  buffer.appendBytes(data);
  buffer.flush();
}

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__set);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<uint8_t> TCPStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__compareSet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::COMPARE_SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(expectedValue);
  buffer.appendBytes(desiredValue);
  buffer.flush();

  return client_->receiveBits();
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__get);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return doGet(keyPrefix_ + key);
}

std::vector<uint8_t> TCPStore::doGet(const std::string& key) {
  doWait(key, timeout_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GET);
  buffer.appendString(key);
  buffer.flush();

  return client_->receiveBits();
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__add);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return incrementValueBy(keyPrefix_ + key, value);
}

bool TCPStore::deleteKey(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__delete);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::DELETE_KEY);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();

  auto numDeleted = client_->receiveValue<std::int64_t>();
  return numDeleted == 1;
}

int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) {
  detail::SendBuffer buff(*client_, detail::QueryType::ADD);
  buff.appendString(key);
  buff.appendValue<std::int64_t>(delta);
  buff.flush();

  return client_->receiveValue<std::int64_t>();
}

int64_t TCPStore::getNumKeys() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GETNUMKEYS);
  buffer.flush();

  return client_->receiveValue<std::int64_t>();
}

bool TCPStore::check(const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__check);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::CHECK);
  buffer.appendValue(keys.size());

  for (const std::string& key : keys) {
    buffer.appendString(keyPrefix_ + key);
  }
  buffer.flush();

  auto response = client_->receiveValue<detail::CheckResponseType>();
  if (response == detail::CheckResponseType::READY) {
    return true;
  }
  if (response == detail::CheckResponseType::NOT_READY) {
    return false;
  }
  TORCH_CHECK(false, "ready or not_ready response expected");
}
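
// Usage sketch (illustrative): unlike wait(), check() is non-blocking -- it
// reports whether the keys exist right now instead of parking the client.
//
//   if (store.check({"rank0/ready", "rank1/ready"})) {
//     // both keys are set; safe to proceed without blocking
//   }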

void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}

void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__wait);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys{};
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }

  doWait(prefixedKeys, timeout);
}

void TCPStore::doWait(
    c10::ArrayRef<std::string> keys,
    std::chrono::milliseconds timeout) {
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::WAIT);
    buffer.appendValue(keys.size());
    for (const std::string& key : keys) {
      buffer.appendString(key);
    }
    buffer.flush();
  }

  detail::WaitResponseType response;
  if (client_->receiveValueWithTimeout<detail::WaitResponseType>(
          response, timeout)) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }
    return;
  }

  // The wait timed out; cancel it. From here on we expect the server to
  // respond in a timely fashion.
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::CANCEL_WAIT);
    buffer.flush();
  }

  response = client_->receiveValue<detail::WaitResponseType>();
  // The server may still send STOP_WAITING if it completed the wait before
  // receiving our cancellation; in that case consume the trailing
  // WAIT_CANCELED acknowledgement as well.
  if (response != detail::WaitResponseType::WAIT_CANCELED) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }

    response = client_->receiveValue<detail::WaitResponseType>(); // ignore
    if (response != detail::WaitResponseType::WAIT_CANCELED) {
      TORCH_CHECK(false, "wait_canceled response is expected");
    }
  }
  C10_THROW_ERROR(
      DistStoreError,
      fmt::format(
          "wait timeout after {}ms, keys: {}",
          timeout.count(),
          fmt::join(keys, ", ")));
}
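
// Message sequence for a timed-out wait, as implemented above (client view):
//
//   -> WAIT(keys)                 // park until all keys exist
//   (no reply within timeout)
//   -> CANCEL_WAIT
//   <- WAIT_CANCELED              // normal cancellation, or:
//   <- STOP_WAITING               // server satisfied the wait first...
//   <- WAIT_CANCELED              // ...followed by the cancel ack
//   then C10_THROW_ERROR(DistStoreError, "wait timeout ...") is raised.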

void TCPStore::append(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__append);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::APPEND);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<std::vector<uint8_t>> TCPStore::multiGet(
    const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiGet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }
  doWait(prefixedKeys, timeout_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_GET);
  buffer.appendValue(keys.size());
  for (auto& key : prefixedKeys) {
    buffer.appendString(key);
  }
  buffer.flush();

  std::vector<std::vector<uint8_t>> result;
  result.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    result.emplace_back(client_->receiveBits());
  }
  return result;
}

void TCPStore::multiSet(
    const std::vector<std::string>& keys,
    const std::vector<std::vector<uint8_t>>& values) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiSet);
  TORCH_CHECK(
      keys.size() == values.size(),
      "multiSet keys and values vectors must be of same size");
  const std::lock_guard<std::mutex> lock(activeOpLock_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_SET);
  buffer.appendValue<std::int64_t>(keys.size());
  for (auto i : c10::irange(keys.size())) {
    buffer.appendString(keyPrefix_ + keys[i]);
    buffer.appendBytes(values[i]);
  }
  buffer.flush();
}

bool TCPStore::hasExtendedApi() const {
  return true;
}

std::string TCPStore::repr() const {
  auto clientRepr = client_ ? client_->repr() : "<nullptr>";
  auto serverRepr = server_ ? server_->repr() : "<nullptr>";
  return fmt::format("TCPStore(client={}, server={})", clientRepr, serverRepr);
}
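
// The string returned by repr() above looks like
// "TCPStore(client=TCPClient(...), server=TCPServer(port=29500))", with
// "<nullptr>" for whichever side this process does not hold.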

} // namespace c10d