1 #include <algorithm>
2 #include <deque>
3 #include <exception>
4 #include <memory>
5 #include <ostream>
6 #include <unordered_map>
7 #include <unordered_set>
8 #include <utility>
9 #include <vector>
10
11 #include <c10/util/thread_name.h>
12 #include <fmt/format.h>
13 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
14 #include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
15 #include <torch/csrc/distributed/c10d/logging.h>
16
17 #ifdef TORCH_USE_LIBUV
18 #include <uv.h>
19 #endif
20
21 namespace c10d::detail {
22
23 #ifdef TORCH_USE_LIBUV
24
25 /*
26
27 Exception safety:
28
It's OK to throw exceptions during client processing.
Other callbacks don't provide exception safety, so avoid throwing there.
31
32 */
33
// This controls how many un-accepted TCP connections can be waiting in the
// backlog. This should be at least world size to avoid issues on init. We set
// it to -1 to use the host max value which is controlled by `somaxconn`.
37 #define DEFAULT_BACKLOG -1
38 #define MAX_KEY_COUNT (128 * 1024)
39 #define MAX_STRING_LEN (8 * 1024)
40 #define MAX_PAYLOAD_LEN (8 * 1024 * 1024)
41
// This controls the preferred size for buffers.
// Too small and we'll need multiple buffers for one request.
// Too big and we might be taxing malloc.
45 #define ALLOC_BUFFER_SIZE ((size_t)4000)
46 class UvHandle : public c10::intrusive_ptr_target {
47 public:
48 ~UvHandle() override = default;
49
iptr()50 c10::intrusive_ptr<UvHandle> iptr() {
51 return c10::intrusive_ptr<UvHandle>::reclaim_copy(this);
52 }
53
close()54 void close() {
55 if (uv_is_closing(unsafeGetHandle())) {
56 return;
57 }
58 uv_close(unsafeGetHandle(), on_close);
59 }
60
61 virtual uv_handle_t* unsafeGetHandle() = 0;
62
63 protected:
handleReady()64 void handleReady() {
65 /*
66 This method must be called once the handle is ready and registered with the
67 loop.
68
69 Do not call this in the ctor, make_intrusive reset refcounts to one after
70 construction.
71 */
72 uv_handle_set_data(unsafeGetHandle(), this);
73 at::raw::intrusive_ptr::incref(this);
74 }
75
76 virtual void onClose() = 0;
77
78 private:
reclaim(uv_handle_t * handle)79 static c10::intrusive_ptr<UvHandle> reclaim(uv_handle_t* handle) {
80 auto h = (UvHandle*)uv_handle_get_data(handle);
81 return c10::intrusive_ptr<UvHandle>::reclaim(h);
82 }
83
on_close(uv_handle_t * uv_handle)84 static void on_close(uv_handle_t* uv_handle) {
85 auto handle = reclaim(uv_handle);
86 handle->onClose();
87 }
88 };
89
90 class UvTcpSocket : public UvHandle {
91 uv_tcp_t client{};
92
iptr()93 c10::intrusive_ptr<UvTcpSocket> iptr() {
94 return c10::intrusive_ptr<UvTcpSocket>::reclaim_copy(this);
95 }
96
borrow(uv_stream_t * handle)97 static c10::intrusive_ptr<UvTcpSocket> borrow(uv_stream_t* handle) {
98 auto h = (UvTcpSocket*)uv_handle_get_data((uv_handle_t*)handle);
99 return h->iptr();
100 }
101
alloc_buffer(uv_handle_t * handle,size_t suggested_size,uv_buf_t * buf)102 static void alloc_buffer(
103
104 uv_handle_t* handle,
105 size_t suggested_size,
106 uv_buf_t* buf) {
107 suggested_size = std::min(suggested_size, (size_t)ALLOC_BUFFER_SIZE);
108 buf->base = (char*)malloc(suggested_size);
109 buf->len = suggested_size;
110 }
111
read_callback(uv_stream_t * client,ssize_t nread,const uv_buf_t * buf)112 static void read_callback(
113 uv_stream_t* client,
114 ssize_t nread,
115 const uv_buf_t* buf) {
116 auto uv_socket = UvTcpSocket::borrow(client);
117
118 if (nread < 0) {
119 C10D_DEBUG(
120 "Read callback failed. code:{} name:{} desc:{}",
121 nread,
122 uv_err_name(nread),
123 uv_strerror(nread));
124 uv_socket->close();
125 return;
126 }
127 if (nread > 0) {
128 try {
129 uv_socket->processBuf(buf, nread);
130 } catch (std::exception& ex) {
131 C10D_WARNING("Error processing client message: {}", ex.what());
132 uv_socket->close();
133 }
134 }
135 }
136
137 public:
UvTcpSocket(uv_loop_t * loop)138 explicit UvTcpSocket(uv_loop_t* loop) {
139 uv_tcp_init(loop, &client);
140 if (int err = uv_tcp_nodelay(&client, 1)) {
141 C10D_WARNING(
142 "The no-delay option cannot be enabled for the client socket. err={}",
143 err);
144 }
145 }
146
startRead()147 void startRead() {
148 int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback);
149 if (res) {
150 C10D_WARNING(
151 "Failed to setup read callback. client:{} code:{} name:{} desc:{}.",
152 (void*)this,
153 res,
154 uv_err_name(res),
155 uv_strerror(res));
156 close();
157 }
158 }
159
unsafeGetHandle()160 uv_handle_t* unsafeGetHandle() override {
161 return (uv_handle_t*)&client;
162 }
163
164 protected:
unsafeGetStream()165 uv_stream_t* unsafeGetStream() {
166 return (uv_stream_t*)&client;
167 }
168
unsafeGetSocket()169 uv_tcp_t* unsafeGetSocket() {
170 return &client;
171 }
172
processBuf(const uv_buf_t * buf,size_t nread)173 virtual void processBuf(const uv_buf_t* buf, size_t nread) {
174 TORCH_CHECK(
175 false, "Trying to read from a socket subclass that lacks processBuf");
176 }
177
onClose()178 void onClose() override {
179 // TODO use registerClient (and rename it to registerHandle) - this will
180 // significantly simplify things.
181 }
182 };
183
184 class UvTcpServer : public UvTcpSocket {
185 public:
186 typedef std::function<void(int)> OnConnectCallback;
UvTcpServer(uv_loop_t * loop)187 explicit UvTcpServer(uv_loop_t* loop)
188 : UvTcpSocket(loop), onConnectCb(missingOnConnect) {}
189
makeWithSocket(uv_loop_t * loop,int socket)190 static c10::intrusive_ptr<UvTcpServer> makeWithSocket(
191 uv_loop_t* loop,
192 int socket) {
193 auto res = c10::make_intrusive<UvTcpServer>(loop);
194 res->handleReady();
195 try {
196 int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket);
197 TORCH_CHECK(
198 uv_res == 0,
199 "Failed to open existing socket. ",
200 "socket: ",
201 socket,
202 ", code: ",
203 uv_res,
204 ", name: ",
205 uv_err_name(uv_res),
206 ", message: ",
207 uv_strerror(uv_res));
208
209 res->cacheSocketPort();
210 } catch (std::exception& ex) {
211 res->close();
212 throw;
213 }
214
215 return res;
216 }
217
setOnConnectCallback(OnConnectCallback && callback)218 void setOnConnectCallback(OnConnectCallback&& callback) {
219 onConnectCb = std::move(callback);
220 }
221
makeWithPort(uv_loop_t * loop,uint16_t port,bool useIpv6)222 static c10::intrusive_ptr<UvTcpServer> makeWithPort(
223 uv_loop_t* loop,
224 uint16_t port,
225 bool useIpv6) {
226 auto res = c10::make_intrusive<UvTcpServer>(loop);
227 res->handleReady();
228 try {
229 struct sockaddr_storage addr {};
230 int uv_res = 0;
231 if (useIpv6) {
232 uv_res = uv_ip6_addr("::", port, (struct sockaddr_in6*)&addr);
233 } else {
234 uv_res = uv_ip4_addr("0.0.0.0", port, (struct sockaddr_in*)&addr);
235 }
236 TORCH_CHECK(
237 uv_res == 0,
238 "UV Store addr parsing failure. ",
239 "port: ",
240 port,
241 ", useIpv6: ",
242 useIpv6,
243 ", code: ",
244 uv_res,
245 ", name: ",
246 uv_err_name(uv_res),
247 ", message: ",
248 uv_strerror(uv_res));
249
250 uv_res =
251 uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0);
252 TORCH_CHECK(
253 uv_res == 0,
254 "The server socket has failed to bind. ",
255 "port: ",
256 port,
257 ", useIpv6: ",
258 useIpv6,
259 ", code: ",
260 uv_res,
261 ", name: ",
262 uv_err_name(uv_res),
263 ", message: ",
264 uv_strerror(uv_res));
265
266 uv_res =
267 uv_listen(res->unsafeGetStream(), DEFAULT_BACKLOG, on_new_connection);
268 TORCH_CHECK(
269 uv_res == 0,
270 "The server socket has failed to listen on any local network address. ",
271 "port: ",
272 port,
273 ", useIpv6: ",
274 useIpv6,
275 ", code: ",
276 uv_res,
277 ", name: ",
278 uv_err_name(uv_res),
279 ", message: ",
280 uv_strerror(uv_res));
281
282 res->cacheSocketPort();
283 } catch (std::exception& ex) {
284 res->close();
285 throw;
286 }
287
288 return res;
289 }
290
port() const291 uint16_t port() const {
292 return portNum;
293 }
294
accept(const c10::intrusive_ptr<UvTcpSocket> & socket)295 void accept(const c10::intrusive_ptr<UvTcpSocket>& socket) {
296 int res =
297 uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle());
298 TORCH_CHECK(
299 res == 0,
300 "Failed to accept socket. ",
301 "code: ",
302 res,
303 ", name: ",
304 uv_err_name(res),
305 ", message: ",
306 uv_strerror(res));
307 }
308
309 private:
310 OnConnectCallback onConnectCb;
311 uint16_t portNum{};
312
iptr()313 c10::intrusive_ptr<UvTcpServer> iptr() {
314 return c10::intrusive_ptr<UvTcpServer>::reclaim_copy(this);
315 }
316
borrow(uv_stream_t * handle)317 static c10::intrusive_ptr<UvTcpServer> borrow(uv_stream_t* handle) {
318 auto h = (UvTcpServer*)uv_handle_get_data((uv_handle_t*)handle);
319 return h->iptr();
320 }
321
cacheSocketPort()322 void cacheSocketPort() {
323 sockaddr_storage addr_s{};
324
325 int addr_len = sizeof(addr_s);
326
327 if (uv_tcp_getsockname(
328 (uv_tcp_t*)unsafeGetStream(),
329 reinterpret_cast<sockaddr*>(&addr_s),
330 &addr_len) != 0) {
331 throw std::runtime_error(
332 "The port number of the socket cannot be retrieved.");
333 }
334
335 if (addr_s.ss_family == AF_INET) {
336 portNum = ntohs(reinterpret_cast<sockaddr_in*>(&addr_s)->sin_port);
337 } else {
338 portNum = ntohs(reinterpret_cast<sockaddr_in6*>(&addr_s)->sin6_port);
339 }
340 }
341
missingOnConnect(int status)342 static void missingOnConnect(int status) {
343 TORCH_CHECK(false, "Socket accepted byt onConnect callback missing");
344 }
345
on_new_connection(uv_stream_t * server,int status)346 static void on_new_connection(uv_stream_t* server, int status) {
347 borrow(server)->onConnectCb(status);
348 }
349 };
350
351 class WriterPayload : public c10::intrusive_ptr_target {
reclaim(uv_write_t * request)352 static c10::intrusive_ptr<WriterPayload> reclaim(uv_write_t* request) {
353 /* This method returns a intrusive_ptr that does not increase the refcount.
354 */
355 auto h = (WriterPayload*)uv_req_get_data((uv_req_t*)request);
356 return c10::intrusive_ptr<WriterPayload>::reclaim(h);
357 }
358
registeredInLoop()359 void registeredInLoop() {
360 /*
361 This refcount increment must be matched by a reclaim call.
362 Call this method after sucessfully scheduling this handle with a loop.
363 */
364 at::raw::intrusive_ptr::incref(this);
365 }
366
write_done(uv_write_t * req,int status)367 static void write_done(uv_write_t* req, int status) {
368 /* Since we're no longer actively used by the event loop, transfer ownership
369 * to this frame. */
370 auto wp = WriterPayload::reclaim(req);
371 auto handle = wp->handle;
372
373 if (status) {
374 C10D_WARNING(
375 "Write to client failed. code:{} name:{} desc:{}.",
376 status,
377 uv_err_name(status),
378 uv_strerror(status));
379 handle->close();
380 }
381 }
382
383 std::vector<uint8_t> data;
384 uv_write_t req = {};
385 uv_buf_t buf = {};
386 c10::intrusive_ptr<UvHandle> handle;
387
388 public:
WriterPayload(std::vector<uint8_t> && in_data,c10::intrusive_ptr<UvHandle> handle)389 WriterPayload(
390 std::vector<uint8_t>&& in_data,
391 c10::intrusive_ptr<UvHandle> handle)
392 : data(std::move(in_data)), handle(std::move(handle)) {
393 uv_req_set_data((uv_req_t*)&req, this);
394 }
395
396 ~WriterPayload() override = default;
397
send()398 void send() {
399 buf = uv_buf_init((char*)data.data(), data.size());
400 int res = uv_write(
401 &req, (uv_stream_t*)handle->unsafeGetHandle(), &buf, 1, write_done);
402
403 if (res) {
404 C10D_WARNING(
405 "Write setup to client failed. code:{} name:{} desc:{}.",
406 res,
407 uv_err_name(res),
408 uv_strerror(res));
409 handle->close();
410 } else {
411 /* This object was successfully registered with the event loop, so keep it
412 * alive until it's unregistered. */
413 registeredInLoop();
414 }
415 }
416 };
417
418 class StreamWriter {
419 std::vector<uint8_t> data;
420 c10::intrusive_ptr<UvHandle> handle;
421
422 // must be stack allocated
423 void* operator new(size_t);
424
425 public:
StreamWriter(c10::intrusive_ptr<UvHandle> handle)426 StreamWriter(c10::intrusive_ptr<UvHandle> handle)
427 : handle(std::move(handle)) {}
428
write1(uint8_t val)429 void write1(uint8_t val) {
430 data.push_back(val);
431 }
432
433 template <typename T>
write_value(T val)434 void write_value(T val) {
435 uint8_t* val_ptr = (uint8_t*)&val;
436 data.insert(data.end(), val_ptr, val_ptr + sizeof(T));
437 }
438
write_vector(const std::vector<uint8_t> & val)439 void write_vector(const std::vector<uint8_t>& val) {
440 write_value<uint64_t>(val.size());
441 data.insert(data.end(), val.begin(), val.end());
442 }
443
write_string(const std::string & val)444 void write_string(const std::string& val) {
445 write_value<uint64_t>(val.size());
446 data.insert(data.end(), val.data(), val.data() + val.size());
447 }
send()448 void send() {
449 auto wd = c10::make_intrusive<WriterPayload>(std::move(data), handle);
450 wd->send();
451 }
452 };
453
454 class ChunkedStream {
455 std::deque<uv_buf_t> buffers;
456 size_t buff_idx{0};
457 size_t buff_offset{0};
458 size_t capacity{0};
459 size_t buff_offset_commit{0};
460 size_t read_offset{0};
461
462 public:
463 ChunkedStream() = default;
464
buf_count()465 size_t buf_count() {
466 return buffers.size();
467 }
468
append(uv_buf_t buf)469 void append(uv_buf_t buf) {
470 if (buf.len == 0) {
471 free(buf.base);
472 } else {
473 capacity += buf.len;
474 buffers.push_back(buf);
475 }
476 }
read_many(char * dest,size_t size)477 bool read_many(char* dest, size_t size) {
478 if (available() < size) {
479 return false;
480 }
481
482 size_t remaining = size;
483 char* write_base = dest;
484 while (remaining > 0) {
485 auto to_read = std::min(buffers[buff_idx].len - buff_offset, remaining);
486 ::memcpy(write_base, buffers[buff_idx].base + buff_offset, to_read);
487 buff_offset += to_read;
488 remaining -= to_read;
489 write_base += to_read;
490 if (buff_offset >= buffers[buff_idx].len) {
491 buff_offset = 0;
492 ++buff_idx;
493 if (buff_idx >= buffers.size() && remaining > 0) {
494 TORCH_CHECK(
495 false,
496 "Trying to read past end of buffer. ",
497 "buffer_idx: ",
498 buff_idx,
499 ", available: ",
500 buffers.size(),
501 ", remaining: ",
502 remaining);
503 }
504 }
505 }
506 read_offset += size;
507 return true;
508 }
509
read1(uint8_t & byte)510 bool read1(uint8_t& byte) {
511 while (true) {
512 if (buff_idx >= buffers.size())
513 return false;
514 if (buff_offset >= buffers[buff_idx].len) {
515 buff_offset = 0;
516 ++buff_idx;
517 continue;
518 }
519 break;
520 }
521
522 byte = buffers[buff_idx].base[buff_offset];
523 ++buff_offset;
524 ++read_offset;
525 return true;
526 }
527
528 template <typename T>
read_value(T & value)529 bool read_value(T& value) {
530 return read_many((char*)&value, sizeof(T));
531 }
532
read_key(std::string & str)533 bool read_key(std::string& str) {
534 uint64_t size = 0;
535 if (!read_value(size))
536 return false;
537 TORCH_CHECK(
538 size <= MAX_STRING_LEN,
539 "Invalid string size. ",
540 "size: ",
541 size,
542 ", max: ",
543 MAX_STRING_LEN);
544
545 if (available() < size)
546 return false;
547 str.resize(size);
548 return read_many((char*)str.data(), size);
549 }
550
read_payload(std::vector<uint8_t> & data)551 bool read_payload(std::vector<uint8_t>& data) {
552 uint64_t size = 0;
553 if (!read_value(size))
554 return false;
555 auto size_in_bytes = size * sizeof(uint8_t);
556 TORCH_CHECK(
557 size_in_bytes <= MAX_PAYLOAD_LEN,
558 "Invalid payload size. ",
559 "size: ",
560 size_in_bytes,
561 ", max: ",
562 MAX_PAYLOAD_LEN);
563
564 if (available() < size_in_bytes)
565 return false;
566 data.resize(size);
567 return read_many((char*)data.data(), size_in_bytes);
568 }
569
available()570 size_t available() {
571 return capacity - read_offset;
572 }
573
commit()574 void commit() {
575 if (buff_idx >= buffers.size() || buff_offset >= buffers[buff_idx].len) {
576 buff_offset = 0;
577 if (buff_idx < buffers.size())
578 ++buff_idx;
579 }
580
581 for (size_t i = 0; i < buff_idx; ++i) {
582 free(buffers[0].base);
583 capacity -= buffers[0].len;
584 buffers.pop_front();
585 }
586 buff_idx = 0;
587 read_offset = buff_offset_commit = buff_offset;
588 }
589
reset()590 void reset() {
591 buff_idx = 0;
592 read_offset = buff_offset = buff_offset_commit;
593 }
594 };
595
596 class LibUVStoreDaemon : public BackgroundThread {
597 public:
598 explicit LibUVStoreDaemon(int port);
599 ~LibUVStoreDaemon() override;
600
601 uint16_t port() const override;
602
603 void set(const std::string& key, const std::vector<uint8_t>& value);
604 const std::vector<uint8_t>& compareAndSet(
605 const std::string& key,
606 const std::vector<uint8_t>& expectedValue,
607 const std::vector<uint8_t>& newValue);
608 const std::vector<uint8_t>& get(const std::string& key);
609 int64_t add(const std::string& key, int64_t addVal);
610 bool checkKeys(const std::vector<std::string>& keys);
611 bool waitKeys(
612 const std::vector<std::string>& keys,
613 const c10::intrusive_ptr<UvHandle>& client);
614 int64_t size();
615 int64_t deleteKey(const std::string& key);
616 void append(const std::string& key, const std::vector<uint8_t>& value);
617
618 void registerClient(const c10::intrusive_ptr<UvHandle>& client);
619 void unregisterClient(const c10::intrusive_ptr<UvHandle>& client);
620 void clearClientWaitState(const c10::intrusive_ptr<UvHandle>& client);
621 bool isMiscellaneousClient(const c10::intrusive_ptr<UvHandle>& client);
622
623 uint16_t get_socket_port(uv_tcp_t* handle);
624 void init(const TCPStoreOptions& opts);
625
626 protected:
627 void run() override;
628 void stop() override;
629
630 private:
631 uv_loop_t loop{};
632 c10::intrusive_ptr<UvTcpServer> tcpServer;
633
634 uv_async_t exit_handle{};
635 std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
636 // From key -> the list of UvClient waiting on the key
637 std::unordered_map<std::string, std::vector<c10::intrusive_ptr<UvHandle>>>
638 waitingSockets_;
639 // From socket -> number of keys awaited
640 std::unordered_map<c10::intrusive_ptr<UvHandle>, size_t> keysAwaited_;
641 std::unordered_set<c10::intrusive_ptr<UvHandle>> clients_;
642 std::unordered_set<c10::intrusive_ptr<UvHandle>> miscellaneousClients_;
643 int port_;
644
from_uv(uv_handle_t * stream)645 static LibUVStoreDaemon& from_uv(uv_handle_t* stream) {
646 return *(LibUVStoreDaemon*)uv_handle_get_data(stream);
647 }
648
on_new_connection(uv_stream_t * server,int status)649 static void on_new_connection(uv_stream_t* server, int status) {
650 from_uv((uv_handle_t*)server).onConnect(status);
651 }
652
on_exit_request(uv_async_t * handle)653 static void on_exit_request(uv_async_t* handle) {
654 from_uv((uv_handle_t*)handle).onExitRequest();
655 }
656
657 void onConnect(int status);
658 void onExitRequest();
659 void wakeupWaitingClients(const std::string& key);
660 // bool tryListen(bool use_ipv6);
661
662 static void print_active_handles(uv_handle_t* handle, void* arg);
663 };
664
665 class UvClient : public UvTcpSocket {
666 ChunkedStream stream;
667 LibUVStoreDaemon* store;
668
669 protected:
processBuf(const uv_buf_t * buf,size_t nread)670 void processBuf(const uv_buf_t* buf, size_t nread) override {
671 auto tmp = *buf;
672 tmp.len = nread;
673 stream.append(tmp);
674
675 while (true) {
676 stream.reset();
677 uint8_t command = -1;
678 if (!stream.read1(command))
679 break;
680 if (store->isMiscellaneousClient(iptr())) {
681 if ((QueryType)command != QueryType::VALIDATE)
682 return;
683 if (!parse_validate_command())
684 return;
685 } else {
686 switch ((QueryType)command) {
687 case QueryType::PING:
688 if (!parse_ping_command())
689 return;
690 break;
691 case QueryType::SET:
692 if (!parse_set_command())
693 return;
694 break;
695 case QueryType::COMPARE_SET:
696 if (!parse_compare_set_command())
697 return;
698 break;
699 case QueryType::GET:
700 if (!parse_get_command())
701 return;
702 break;
703 case QueryType::ADD:
704 if (!parse_add_command())
705 return;
706 break;
707 case QueryType::CHECK:
708 if (!parse_check_command())
709 return;
710 break;
711 case QueryType::WAIT:
712 if (!parse_wait_command())
713 return;
714 break;
715 case QueryType::GETNUMKEYS:
716 if (!parse_getnumkeys_command())
717 return;
718 break;
719 case QueryType::DELETE_KEY:
720 if (!parse_delete_key_command())
721 return;
722 break;
723 case QueryType::APPEND:
724 if (!parse_append_command())
725 return;
726 break;
727 case QueryType::MULTI_GET:
728 if (!parse_multi_get_command())
729 return;
730 break;
731 case QueryType::MULTI_SET:
732 if (!parse_multi_set_command())
733 return;
734 break;
735 case QueryType::CANCEL_WAIT:
736 if (!parse_cancel_wait_command())
737 return;
738 break;
739 default:
740 C10D_DEBUG(
741 "Client sent invalid command. client:{} command:{}",
742 (void*)this,
743 (int)command);
744 close();
745 return;
746 }
747 }
748 stream.commit();
749 }
750 }
751
parse_validate_command()752 bool parse_validate_command() {
753 uint32_t validateNumber = 0;
754 if (!stream.read_value(validateNumber))
755 return false;
756
757 if (validateNumber != c10d::detail::validationMagicNumber)
758 return false;
759 return true;
760 }
761
parse_ping_command()762 bool parse_ping_command() {
763 uint32_t nonce;
764 if (!stream.read_value(nonce)) {
765 return false;
766 }
767
768 StreamWriter sw(iptr());
769 sw.write_value(nonce);
770 sw.send();
771 return true;
772 }
773
parse_set_command()774 bool parse_set_command() {
775 std::string key;
776 if (!stream.read_key(key))
777 return false;
778
779 std::vector<uint8_t> newData;
780 if (!stream.read_payload(newData))
781 return false;
782
783 store->set(key, newData);
784 return true;
785 }
786
parse_compare_set_command()787 bool parse_compare_set_command() {
788 std::string key;
789 if (!stream.read_key(key))
790 return false;
791
792 std::vector<uint8_t> currentValue;
793 if (!stream.read_payload(currentValue))
794 return false;
795
796 std::vector<uint8_t> newValue;
797 if (!stream.read_payload(newValue))
798 return false;
799
800 auto res = store->compareAndSet(key, currentValue, newValue);
801 StreamWriter sw(iptr());
802 sw.write_vector(res);
803 sw.send();
804
805 return true;
806 }
807
parse_get_command()808 bool parse_get_command() {
809 std::string key;
810 if (!stream.read_key(key))
811 return false;
812
813 const auto& data = store->get(key);
814 StreamWriter sw(iptr());
815 sw.write_vector(data);
816 sw.send();
817 return true;
818 }
819
parse_add_command()820 bool parse_add_command() {
821 std::string key;
822 if (!stream.read_key(key))
823 return false;
824
825 int64_t addVal = 0;
826 if (!stream.read_value(addVal))
827 return false;
828
829 addVal = store->add(key, addVal);
830 StreamWriter sw(iptr());
831 sw.write_value(addVal);
832 sw.send();
833
834 return true;
835 }
836
parse_check_command()837 bool parse_check_command() {
838 uint64_t key_count = 0;
839 if (!stream.read_value(key_count))
840 return false;
841 TORCH_CHECK(
842 key_count <= MAX_KEY_COUNT,
843 "Too many keys being waited. ",
844 "keys: ",
845 key_count,
846 ", max: ",
847 MAX_KEY_COUNT);
848
849 std::vector<std::string> keys(key_count);
850 for (uint64_t i = 0; i < key_count; ++i) {
851 if (!stream.read_key(keys[i]))
852 return false;
853 }
854
855 // Now we have received all the keys
856 StreamWriter sw(iptr());
857 if (store->checkKeys(keys)) {
858 sw.write_value(CheckResponseType::READY);
859 } else {
860 sw.write_value(CheckResponseType::NOT_READY);
861 }
862 sw.send();
863 return true;
864 }
865
parse_wait_command()866 bool parse_wait_command() {
867 uint64_t key_count = 0;
868 if (!stream.read_value(key_count)) {
869 return false;
870 }
871 TORCH_CHECK(
872 key_count <= MAX_KEY_COUNT,
873 "Too many keys being waited. ",
874 "keys: ",
875 key_count,
876 ", max: ",
877 MAX_KEY_COUNT);
878
879 std::vector<std::string> keys(key_count);
880 for (uint64_t i = 0; i < key_count; ++i) {
881 if (!stream.read_key(keys[i]))
882 return false;
883 }
884
885 if (store->waitKeys(keys, iptr())) {
886 StreamWriter sw(iptr());
887 sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
888 sw.send();
889 }
890
891 return true;
892 }
893
parse_getnumkeys_command()894 bool parse_getnumkeys_command() {
895 StreamWriter sw(iptr());
896 sw.write_value<int64_t>(store->size());
897 sw.send();
898
899 return true;
900 }
901
parse_delete_key_command()902 bool parse_delete_key_command() {
903 std::string key;
904 if (!stream.read_key(key))
905 return false;
906
907 auto numDeleted = store->deleteKey(key);
908 StreamWriter sw(iptr());
909 sw.write_value<int64_t>(numDeleted);
910 sw.send();
911
912 return true;
913 }
914
parse_append_command()915 bool parse_append_command() {
916 std::string key;
917 if (!stream.read_key(key)) {
918 return false;
919 }
920
921 std::vector<uint8_t> data;
922 if (!stream.read_payload(data)) {
923 return false;
924 }
925
926 store->append(key, data);
927 return true;
928 }
929
parse_multi_get_command()930 bool parse_multi_get_command() {
931 uint64_t key_count = 0;
932 if (!stream.read_value(key_count)) {
933 return false;
934 }
935 TORCH_CHECK(
936 key_count <= MAX_KEY_COUNT,
937 "Too many keys with multi_get. ",
938 "keys: ",
939 key_count,
940 ", max: ",
941 MAX_KEY_COUNT);
942
943 StreamWriter sw(iptr());
944 for (const auto _ : c10::irange(key_count)) {
945 (void)_; // Suppress unused variable warning
946 std::string key;
947 if (!stream.read_key(key)) {
948 return false;
949 }
950
951 sw.write_vector(store->get(key));
952 }
953 sw.send();
954
955 return true;
956 }
957
parse_multi_set_command()958 bool parse_multi_set_command() {
959 uint64_t key_count = 0;
960 if (!stream.read_value(key_count)) {
961 return false;
962 }
963 TORCH_CHECK(
964 key_count <= MAX_KEY_COUNT,
965 "Too many keys with multi_get. ",
966 "keys: ",
967 key_count,
968 ", max: ",
969 MAX_KEY_COUNT);
970
971 for (const auto _ : c10::irange(key_count)) {
972 (void)_; // Suppress unused variable warning
973
974 std::string key;
975 if (!stream.read_key(key)) {
976 return false;
977 }
978
979 std::vector<uint8_t> newData;
980 if (!stream.read_payload(newData))
981 return false;
982 store->set(key, newData);
983 }
984
985 return true;
986 }
987
parse_cancel_wait_command()988 bool parse_cancel_wait_command() {
989 store->clearClientWaitState(iptr());
990
991 StreamWriter sw(iptr());
992 sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED);
993 sw.send();
994
995 return true;
996 }
997
998 public:
UvClient(uv_loop_t * loop,LibUVStoreDaemon * store)999 explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store)
1000 : UvTcpSocket(loop), store(store) {}
1001
make(uv_loop_t * loop,LibUVStoreDaemon * store)1002 static c10::intrusive_ptr<UvClient> make(
1003 uv_loop_t* loop,
1004 LibUVStoreDaemon* store) {
1005 auto res = c10::make_intrusive<UvClient>(loop, store);
1006 res->handleReady();
1007 return res;
1008 }
1009
iptr()1010 c10::intrusive_ptr<UvClient> iptr() {
1011 return c10::intrusive_ptr<UvClient>::reclaim_copy(this);
1012 }
1013
1014 protected:
onClose()1015 void onClose() override {
1016 store->unregisterClient(iptr());
1017 }
1018 };
1019
onConnect(int status)1020 void LibUVStoreDaemon::onConnect(int status) {
1021 auto client = UvClient::make(&loop, this);
1022 registerClient(client);
1023 try {
1024 tcpServer->accept(client);
1025 client->startRead();
1026 } catch (std::exception& e) {
1027 C10D_WARNING("Failed to accept client due to {}", e.what());
1028 client->close();
1029 }
1030 }
1031
onExitRequest()1032 void LibUVStoreDaemon::onExitRequest() {
1033 C10D_DEBUG("Store exit requested\n");
1034 uv_close((uv_handle_t*)&exit_handle, nullptr);
1035 uv_stop(&loop);
1036 }
1037
init(const TCPStoreOptions & opts)1038 void LibUVStoreDaemon::init(const TCPStoreOptions& opts) {
1039 if (opts.masterListenFd.has_value()) {
1040 tcpServer = UvTcpServer::makeWithSocket(&loop, *opts.masterListenFd);
1041 } else {
1042 try {
1043 tcpServer = UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/true);
1044 } catch (std::exception& ex) {
1045 C10D_INFO(
1046 "Failed to bind to ipv6 address, trying ipv4. Error: {}", ex.what());
1047 tcpServer =
1048 UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/false);
1049 }
1050 }
1051 tcpServer->setOnConnectCallback(
1052 [this](auto status) { this->onConnect(status); });
1053
1054 port_ = tcpServer->port();
1055 TORCH_CHECK(
1056 port_ == opts.port || opts.port == 0, // zero means use any port
1057 "listen fd ",
1058 *opts.masterListenFd,
1059 " is bound to port ",
1060 port_,
1061 ", expected to be bound to port ",
1062 opts.port);
1063 }
1064
LibUVStoreDaemon(int port)1065 LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) {
1066 TORCH_CHECK(uv_loop_init(&loop) == 0, "Failed to init uv loop");
1067 TORCH_CHECK(
1068 uv_async_init(&loop, &exit_handle, LibUVStoreDaemon::on_exit_request) ==
1069 0,
1070 "Failed to init uv async event");
1071 uv_handle_set_data((uv_handle_t*)&exit_handle, this);
1072 }
1073
~LibUVStoreDaemon()1074 LibUVStoreDaemon::~LibUVStoreDaemon() {
1075 if (!is_running()) {
1076 uv_close((uv_handle_t*)&exit_handle, nullptr);
1077 uv_run(&loop, UV_RUN_NOWAIT);
1078 TORCH_CHECK(uv_loop_close(&loop) == 0, "loop cleanup didn't work");
1079 } else {
1080 // the daemon thread cleanup libuv
1081 dispose();
1082 }
1083 }
1084
port() const1085 uint16_t LibUVStoreDaemon::port() const {
1086 return port_;
1087 }
1088
print_active_handles(uv_handle_t * handle,void * arg)1089 void LibUVStoreDaemon::print_active_handles(uv_handle_t* handle, void* arg) {
1090 C10D_DEBUG(
1091 "UV live handle type {} active:{} is-closing:{}",
1092 (int)handle->type,
1093 uv_is_active(handle),
1094 uv_is_closing(handle));
1095 }
1096
run()1097 void LibUVStoreDaemon::run() {
1098 c10::setThreadName("pt_tcpstore_uv");
1099
1100 C10D_DEBUG("Uv main loop running");
1101 int res = uv_run(&loop, UV_RUN_DEFAULT);
1102 if (res) {
1103 C10D_DEBUG("UV main loop done: res:{}", res);
1104 }
1105 bool debug_enabled =
1106 c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug);
1107
1108 if (debug_enabled) {
1109 C10D_DEBUG("Walking live handles prior to closing clients");
1110 uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr);
1111 }
1112
1113 for (const auto& client : clients_) {
1114 client->close();
1115 }
1116 tcpServer->close();
1117
1118 if (debug_enabled) {
1119 C10D_DEBUG("Walking live handles after closing clients");
1120 uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr);
1121 }
1122
1123 while (true) {
1124 res = uv_loop_close(&loop);
1125 if (res == 0) {
1126 break;
1127 }
1128 C10D_INFO(
1129 "uv_loop_close failed with:{} errn:{} desc:{}",
1130 res,
1131 uv_err_name(res),
1132 uv_strerror(res));
1133 res = uv_run(&loop, UV_RUN_NOWAIT);
1134 if (res != 0) {
1135 std::this_thread::sleep_for(std::chrono::milliseconds(500));
1136 }
1137 }
1138 C10D_INFO("uv_loop cleanup finished.");
1139 }
1140
stop()1141 void LibUVStoreDaemon::stop() {
1142 int res = uv_async_send(&exit_handle);
1143 if (res) {
1144 C10D_WARNING(
1145 "uv_async_send failed with:{} errn:{} desc:{}\n",
1146 res,
1147 uv_err_name(res),
1148 uv_strerror(res));
1149 }
1150 }
1151
isMiscellaneousClient(const c10::intrusive_ptr<UvHandle> & client)1152 bool LibUVStoreDaemon::isMiscellaneousClient(
1153 const c10::intrusive_ptr<UvHandle>& client) {
1154 if (miscellaneousClients_.find(client) != miscellaneousClients_.end()) {
1155 miscellaneousClients_.erase(client);
1156 return true;
1157 }
1158 return false;
1159 }
1160
registerClient(const c10::intrusive_ptr<UvHandle> & client)1161 void LibUVStoreDaemon::registerClient(
1162 const c10::intrusive_ptr<UvHandle>& client) {
1163 clients_.insert(client);
1164 miscellaneousClients_.insert(client);
1165 }
1166
unregisterClient(const c10::intrusive_ptr<UvHandle> & client)1167 void LibUVStoreDaemon::unregisterClient(
1168 const c10::intrusive_ptr<UvHandle>& client) {
1169 clients_.erase(client);
1170 if (miscellaneousClients_.find(client) != miscellaneousClients_.end()) {
1171 miscellaneousClients_.erase(client);
1172 }
1173 clearClientWaitState(client);
1174 }
1175
clearClientWaitState(const c10::intrusive_ptr<UvHandle> & client)1176 void LibUVStoreDaemon::clearClientWaitState(
1177 const c10::intrusive_ptr<UvHandle>& client) {
1178 if (keysAwaited_.find(client) == keysAwaited_.end()) {
1179 return;
1180 }
1181 keysAwaited_.erase(client);
1182 for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
1183 for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
1184 if (*vecIt == client) {
1185 vecIt = it->second.erase(vecIt);
1186 } else {
1187 ++vecIt;
1188 }
1189 }
1190 if (it->second.empty()) {
1191 it = waitingSockets_.erase(it);
1192 } else {
1193 ++it;
1194 }
1195 }
1196 }
1197
set(const std::string & key,const std::vector<uint8_t> & value)1198 void LibUVStoreDaemon::set(
1199 const std::string& key,
1200 const std::vector<uint8_t>& value) {
1201 tcpStore_[key] = value;
1202 // On "set", wake up all clients that have been waiting
1203 wakeupWaitingClients(key);
1204 }
1205
compareAndSet(const std::string & key,const std::vector<uint8_t> & expectedValue,const std::vector<uint8_t> & newValue)1206 const std::vector<uint8_t>& LibUVStoreDaemon::compareAndSet(
1207 const std::string& key,
1208 const std::vector<uint8_t>& expectedValue,
1209 const std::vector<uint8_t>& newValue) {
1210 auto pos = tcpStore_.find(key);
1211 if (pos == tcpStore_.end()) {
1212 if (expectedValue.empty()) {
1213 tcpStore_[key] = newValue;
1214 wakeupWaitingClients(key);
1215 return newValue;
1216 } else {
1217 // TODO: This code path is not ideal as we are "lying" to the caller in
1218 // case the key does not exist. We should come up with a working solution.
1219 // It might make more sense to return ""
1220 wakeupWaitingClients(key);
1221 return expectedValue;
1222 }
1223 } else {
1224 if (pos->second == expectedValue) {
1225 pos->second = newValue;
1226 }
1227 wakeupWaitingClients(key);
1228 return pos->second;
1229 }
1230 }
1231
get(const std::string & key)1232 const std::vector<uint8_t>& LibUVStoreDaemon::get(const std::string& key) {
1233 static std::vector<uint8_t> missing_key;
1234 return tcpStore_.count(key) ? tcpStore_.at(key) : missing_key;
1235 }
1236
add(const std::string & key,int64_t addVal)1237 int64_t LibUVStoreDaemon::add(const std::string& key, int64_t addVal) {
1238 std::vector<uint8_t> oldData;
1239 auto it = tcpStore_.find(key);
1240 if (it != tcpStore_.end()) {
1241 oldData = it->second;
1242 auto buf = reinterpret_cast<const char*>(it->second.data());
1243 auto len = it->second.size();
1244 addVal += std::stoll(std::string(buf, len));
1245 }
1246 auto addValStr = std::to_string(addVal);
1247 std::vector<uint8_t> newData =
1248 std::vector<uint8_t>(addValStr.begin(), addValStr.end());
1249 tcpStore_[key] = newData;
1250
1251 // On "add", wake up all clients that have been waiting
1252 wakeupWaitingClients(key);
1253
1254 return addVal;
1255 }
1256
checkKeys(const std::vector<std::string> & keys)1257 bool LibUVStoreDaemon::checkKeys(const std::vector<std::string>& keys) {
1258 return std::all_of(keys.begin(), keys.end(), [&](const std::string& s) {
1259 return tcpStore_.count(s) > 0;
1260 });
1261 }
1262
waitKeys(const std::vector<std::string> & keys,const c10::intrusive_ptr<UvHandle> & client)1263 bool LibUVStoreDaemon::waitKeys(
1264 const std::vector<std::string>& keys,
1265 const c10::intrusive_ptr<UvHandle>& client) {
1266 if (checkKeys(keys)) {
1267 return true;
1268 }
1269 int numKeysToAwait = 0;
1270 for (auto& key : keys) {
1271 // Only count keys that have not already been set
1272 if (tcpStore_.find(key) == tcpStore_.end()) {
1273 waitingSockets_[key].push_back(client);
1274 numKeysToAwait++;
1275 }
1276 }
1277 keysAwaited_[client] = numKeysToAwait;
1278 return false;
1279 }
1280
size()1281 int64_t LibUVStoreDaemon::size() {
1282 return tcpStore_.size();
1283 }
1284
deleteKey(const std::string & key)1285 int64_t LibUVStoreDaemon::deleteKey(const std::string& key) {
1286 return tcpStore_.erase(key);
1287 }
1288
append(const std::string & key,const std::vector<uint8_t> & value)1289 void LibUVStoreDaemon::append(
1290 const std::string& key,
1291 const std::vector<uint8_t>& value) {
1292 std::vector<uint8_t> oldData;
1293 auto it = tcpStore_.find(key);
1294 if (it != tcpStore_.end()) {
1295 it->second.insert(it->second.end(), value.begin(), value.end());
1296 } else {
1297 tcpStore_[key] = value;
1298 }
1299
1300 // we should not have clients waiting if we're appending, so it's all fine
1301 wakeupWaitingClients(key);
1302 }
1303
wakeupWaitingClients(const std::string & key)1304 void LibUVStoreDaemon::wakeupWaitingClients(const std::string& key) {
1305 auto socketsToWait = waitingSockets_.find(key);
1306 if (socketsToWait != waitingSockets_.end()) {
1307 for (const auto& client : socketsToWait->second) {
1308 if (--keysAwaited_[client] == 0) {
1309 StreamWriter sw(client->iptr());
1310 sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
1311 sw.send();
1312 }
1313 }
1314 waitingSockets_.erase(socketsToWait);
1315 }
1316 }
1317
1318 #endif
1319
create_libuv_tcpstore_backend(const TCPStoreOptions & opts)1320 std::unique_ptr<BackgroundThread> create_libuv_tcpstore_backend(
1321 const TCPStoreOptions& opts) {
1322 #ifdef TORCH_USE_LIBUV
1323 auto res = std::make_unique<LibUVStoreDaemon>(opts.port);
1324 res->init(opts);
1325 return res;
1326 #else
1327 TORCH_CHECK(false, "LibUV TCPStore implementation missing");
1328 #endif
1329 }
1330
// Compile-time capability probe: true iff this translation unit was built
// with TORCH_USE_LIBUV defined.
bool is_libuv_tcpstore_backend_available() {
#ifdef TORCH_USE_LIBUV
  constexpr bool kLibuvCompiledIn = true;
#else
  constexpr bool kLibuvCompiledIn = false;
#endif
  return kLibuvCompiledIn;
}
1338
1339 } // namespace c10d::detail
1340