xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <algorithm>
2 #include <deque>
3 #include <exception>
4 #include <memory>
5 #include <ostream>
6 #include <unordered_map>
7 #include <unordered_set>
8 #include <utility>
9 #include <vector>
10 
11 #include <c10/util/thread_name.h>
12 #include <fmt/format.h>
13 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
14 #include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
15 #include <torch/csrc/distributed/c10d/logging.h>
16 
17 #ifdef TORCH_USE_LIBUV
18 #include <uv.h>
19 #endif
20 
21 namespace c10d::detail {
22 
23 #ifdef TORCH_USE_LIBUV
24 
/*

Exception safety:

It's ok to throw exceptions while processing client requests: read_callback
catches them and closes the offending client.
The other libuv callbacks don't provide exception safety, so avoid throwing
from them.

*/
33 
// Maximum number of un-accepted TCP connections that can wait in the listen
// backlog. This should be at least the world size to avoid issues on init.
// We set it to -1 to use the host max value, which is controlled by
// `somaxconn`.
constexpr int DEFAULT_BACKLOG = -1;

// Upper bounds enforced on peer-supplied sizes while parsing client requests.
constexpr size_t MAX_KEY_COUNT = 128 * 1024;
constexpr size_t MAX_STRING_LEN = 8 * 1024;
constexpr size_t MAX_PAYLOAD_LEN = 8 * 1024 * 1024;

// Preferred size for read buffers.
// Too small and one request needs several buffers; too big and we might be
// taxing malloc.
constexpr size_t ALLOC_BUFFER_SIZE = 4000;
46 class UvHandle : public c10::intrusive_ptr_target {
47  public:
48   ~UvHandle() override = default;
49 
iptr()50   c10::intrusive_ptr<UvHandle> iptr() {
51     return c10::intrusive_ptr<UvHandle>::reclaim_copy(this);
52   }
53 
close()54   void close() {
55     if (uv_is_closing(unsafeGetHandle())) {
56       return;
57     }
58     uv_close(unsafeGetHandle(), on_close);
59   }
60 
61   virtual uv_handle_t* unsafeGetHandle() = 0;
62 
63  protected:
handleReady()64   void handleReady() {
65     /*
66     This method must be called once the handle is ready and registered with the
67     loop.
68 
69     Do not call this in the ctor, make_intrusive reset refcounts to one after
70     construction.
71     */
72     uv_handle_set_data(unsafeGetHandle(), this);
73     at::raw::intrusive_ptr::incref(this);
74   }
75 
76   virtual void onClose() = 0;
77 
78  private:
reclaim(uv_handle_t * handle)79   static c10::intrusive_ptr<UvHandle> reclaim(uv_handle_t* handle) {
80     auto h = (UvHandle*)uv_handle_get_data(handle);
81     return c10::intrusive_ptr<UvHandle>::reclaim(h);
82   }
83 
on_close(uv_handle_t * uv_handle)84   static void on_close(uv_handle_t* uv_handle) {
85     auto handle = reclaim(uv_handle);
86     handle->onClose();
87   }
88 };
89 
90 class UvTcpSocket : public UvHandle {
91   uv_tcp_t client{};
92 
iptr()93   c10::intrusive_ptr<UvTcpSocket> iptr() {
94     return c10::intrusive_ptr<UvTcpSocket>::reclaim_copy(this);
95   }
96 
borrow(uv_stream_t * handle)97   static c10::intrusive_ptr<UvTcpSocket> borrow(uv_stream_t* handle) {
98     auto h = (UvTcpSocket*)uv_handle_get_data((uv_handle_t*)handle);
99     return h->iptr();
100   }
101 
alloc_buffer(uv_handle_t * handle,size_t suggested_size,uv_buf_t * buf)102   static void alloc_buffer(
103 
104       uv_handle_t* handle,
105       size_t suggested_size,
106       uv_buf_t* buf) {
107     suggested_size = std::min(suggested_size, (size_t)ALLOC_BUFFER_SIZE);
108     buf->base = (char*)malloc(suggested_size);
109     buf->len = suggested_size;
110   }
111 
read_callback(uv_stream_t * client,ssize_t nread,const uv_buf_t * buf)112   static void read_callback(
113       uv_stream_t* client,
114       ssize_t nread,
115       const uv_buf_t* buf) {
116     auto uv_socket = UvTcpSocket::borrow(client);
117 
118     if (nread < 0) {
119       C10D_DEBUG(
120           "Read callback failed. code:{} name:{} desc:{}",
121           nread,
122           uv_err_name(nread),
123           uv_strerror(nread));
124       uv_socket->close();
125       return;
126     }
127     if (nread > 0) {
128       try {
129         uv_socket->processBuf(buf, nread);
130       } catch (std::exception& ex) {
131         C10D_WARNING("Error processing client message: {}", ex.what());
132         uv_socket->close();
133       }
134     }
135   }
136 
137  public:
UvTcpSocket(uv_loop_t * loop)138   explicit UvTcpSocket(uv_loop_t* loop) {
139     uv_tcp_init(loop, &client);
140     if (int err = uv_tcp_nodelay(&client, 1)) {
141       C10D_WARNING(
142           "The no-delay option cannot be enabled for the client socket. err={}",
143           err);
144     }
145   }
146 
startRead()147   void startRead() {
148     int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback);
149     if (res) {
150       C10D_WARNING(
151           "Failed to setup read callback. client:{} code:{} name:{} desc:{}.",
152           (void*)this,
153           res,
154           uv_err_name(res),
155           uv_strerror(res));
156       close();
157     }
158   }
159 
unsafeGetHandle()160   uv_handle_t* unsafeGetHandle() override {
161     return (uv_handle_t*)&client;
162   }
163 
164  protected:
unsafeGetStream()165   uv_stream_t* unsafeGetStream() {
166     return (uv_stream_t*)&client;
167   }
168 
unsafeGetSocket()169   uv_tcp_t* unsafeGetSocket() {
170     return &client;
171   }
172 
processBuf(const uv_buf_t * buf,size_t nread)173   virtual void processBuf(const uv_buf_t* buf, size_t nread) {
174     TORCH_CHECK(
175         false, "Trying to read from a socket subclass that lacks processBuf");
176   }
177 
onClose()178   void onClose() override {
179     // TODO use registerClient (and rename it to registerHandle) - this will
180     // significantly simplify things.
181   }
182 };
183 
184 class UvTcpServer : public UvTcpSocket {
185  public:
186   typedef std::function<void(int)> OnConnectCallback;
UvTcpServer(uv_loop_t * loop)187   explicit UvTcpServer(uv_loop_t* loop)
188       : UvTcpSocket(loop), onConnectCb(missingOnConnect) {}
189 
makeWithSocket(uv_loop_t * loop,int socket)190   static c10::intrusive_ptr<UvTcpServer> makeWithSocket(
191       uv_loop_t* loop,
192       int socket) {
193     auto res = c10::make_intrusive<UvTcpServer>(loop);
194     res->handleReady();
195     try {
196       int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket);
197       TORCH_CHECK(
198           uv_res == 0,
199           "Failed to open existing socket. ",
200           "socket: ",
201           socket,
202           ", code: ",
203           uv_res,
204           ", name: ",
205           uv_err_name(uv_res),
206           ", message: ",
207           uv_strerror(uv_res));
208 
209       res->cacheSocketPort();
210     } catch (std::exception& ex) {
211       res->close();
212       throw;
213     }
214 
215     return res;
216   }
217 
setOnConnectCallback(OnConnectCallback && callback)218   void setOnConnectCallback(OnConnectCallback&& callback) {
219     onConnectCb = std::move(callback);
220   }
221 
makeWithPort(uv_loop_t * loop,uint16_t port,bool useIpv6)222   static c10::intrusive_ptr<UvTcpServer> makeWithPort(
223       uv_loop_t* loop,
224       uint16_t port,
225       bool useIpv6) {
226     auto res = c10::make_intrusive<UvTcpServer>(loop);
227     res->handleReady();
228     try {
229       struct sockaddr_storage addr {};
230       int uv_res = 0;
231       if (useIpv6) {
232         uv_res = uv_ip6_addr("::", port, (struct sockaddr_in6*)&addr);
233       } else {
234         uv_res = uv_ip4_addr("0.0.0.0", port, (struct sockaddr_in*)&addr);
235       }
236       TORCH_CHECK(
237           uv_res == 0,
238           "UV Store addr parsing failure. ",
239           "port: ",
240           port,
241           ", useIpv6: ",
242           useIpv6,
243           ", code: ",
244           uv_res,
245           ", name: ",
246           uv_err_name(uv_res),
247           ", message: ",
248           uv_strerror(uv_res));
249 
250       uv_res =
251           uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0);
252       TORCH_CHECK(
253           uv_res == 0,
254           "The server socket has failed to bind. ",
255           "port: ",
256           port,
257           ", useIpv6: ",
258           useIpv6,
259           ", code: ",
260           uv_res,
261           ", name: ",
262           uv_err_name(uv_res),
263           ", message: ",
264           uv_strerror(uv_res));
265 
266       uv_res =
267           uv_listen(res->unsafeGetStream(), DEFAULT_BACKLOG, on_new_connection);
268       TORCH_CHECK(
269           uv_res == 0,
270           "The server socket has failed to listen on any local network address. ",
271           "port: ",
272           port,
273           ", useIpv6: ",
274           useIpv6,
275           ", code: ",
276           uv_res,
277           ", name: ",
278           uv_err_name(uv_res),
279           ", message: ",
280           uv_strerror(uv_res));
281 
282       res->cacheSocketPort();
283     } catch (std::exception& ex) {
284       res->close();
285       throw;
286     }
287 
288     return res;
289   }
290 
port() const291   uint16_t port() const {
292     return portNum;
293   }
294 
accept(const c10::intrusive_ptr<UvTcpSocket> & socket)295   void accept(const c10::intrusive_ptr<UvTcpSocket>& socket) {
296     int res =
297         uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle());
298     TORCH_CHECK(
299         res == 0,
300         "Failed to accept socket. ",
301         "code: ",
302         res,
303         ", name: ",
304         uv_err_name(res),
305         ", message: ",
306         uv_strerror(res));
307   }
308 
309  private:
310   OnConnectCallback onConnectCb;
311   uint16_t portNum{};
312 
iptr()313   c10::intrusive_ptr<UvTcpServer> iptr() {
314     return c10::intrusive_ptr<UvTcpServer>::reclaim_copy(this);
315   }
316 
borrow(uv_stream_t * handle)317   static c10::intrusive_ptr<UvTcpServer> borrow(uv_stream_t* handle) {
318     auto h = (UvTcpServer*)uv_handle_get_data((uv_handle_t*)handle);
319     return h->iptr();
320   }
321 
cacheSocketPort()322   void cacheSocketPort() {
323     sockaddr_storage addr_s{};
324 
325     int addr_len = sizeof(addr_s);
326 
327     if (uv_tcp_getsockname(
328             (uv_tcp_t*)unsafeGetStream(),
329             reinterpret_cast<sockaddr*>(&addr_s),
330             &addr_len) != 0) {
331       throw std::runtime_error(
332           "The port number of the socket cannot be retrieved.");
333     }
334 
335     if (addr_s.ss_family == AF_INET) {
336       portNum = ntohs(reinterpret_cast<sockaddr_in*>(&addr_s)->sin_port);
337     } else {
338       portNum = ntohs(reinterpret_cast<sockaddr_in6*>(&addr_s)->sin6_port);
339     }
340   }
341 
missingOnConnect(int status)342   static void missingOnConnect(int status) {
343     TORCH_CHECK(false, "Socket accepted byt onConnect callback missing");
344   }
345 
on_new_connection(uv_stream_t * server,int status)346   static void on_new_connection(uv_stream_t* server, int status) {
347     borrow(server)->onConnectCb(status);
348   }
349 };
350 
351 class WriterPayload : public c10::intrusive_ptr_target {
reclaim(uv_write_t * request)352   static c10::intrusive_ptr<WriterPayload> reclaim(uv_write_t* request) {
353     /* This method returns a intrusive_ptr that does not increase the refcount.
354      */
355     auto h = (WriterPayload*)uv_req_get_data((uv_req_t*)request);
356     return c10::intrusive_ptr<WriterPayload>::reclaim(h);
357   }
358 
registeredInLoop()359   void registeredInLoop() {
360     /*
361     This refcount increment must be matched by a reclaim call.
362     Call this method after sucessfully scheduling this handle with a loop.
363     */
364     at::raw::intrusive_ptr::incref(this);
365   }
366 
write_done(uv_write_t * req,int status)367   static void write_done(uv_write_t* req, int status) {
368     /* Since we're no longer actively used by the event loop, transfer ownership
369      * to this frame. */
370     auto wp = WriterPayload::reclaim(req);
371     auto handle = wp->handle;
372 
373     if (status) {
374       C10D_WARNING(
375           "Write to client failed. code:{} name:{} desc:{}.",
376           status,
377           uv_err_name(status),
378           uv_strerror(status));
379       handle->close();
380     }
381   }
382 
383   std::vector<uint8_t> data;
384   uv_write_t req = {};
385   uv_buf_t buf = {};
386   c10::intrusive_ptr<UvHandle> handle;
387 
388  public:
WriterPayload(std::vector<uint8_t> && in_data,c10::intrusive_ptr<UvHandle> handle)389   WriterPayload(
390       std::vector<uint8_t>&& in_data,
391       c10::intrusive_ptr<UvHandle> handle)
392       : data(std::move(in_data)), handle(std::move(handle)) {
393     uv_req_set_data((uv_req_t*)&req, this);
394   }
395 
396   ~WriterPayload() override = default;
397 
send()398   void send() {
399     buf = uv_buf_init((char*)data.data(), data.size());
400     int res = uv_write(
401         &req, (uv_stream_t*)handle->unsafeGetHandle(), &buf, 1, write_done);
402 
403     if (res) {
404       C10D_WARNING(
405           "Write setup to client failed. code:{} name:{} desc:{}.",
406           res,
407           uv_err_name(res),
408           uv_strerror(res));
409       handle->close();
410     } else {
411       /* This object was successfully registered with the event loop, so keep it
412        * alive until it's unregistered. */
413       registeredInLoop();
414     }
415   }
416 };
417 
418 class StreamWriter {
419   std::vector<uint8_t> data;
420   c10::intrusive_ptr<UvHandle> handle;
421 
422   // must be stack allocated
423   void* operator new(size_t);
424 
425  public:
StreamWriter(c10::intrusive_ptr<UvHandle> handle)426   StreamWriter(c10::intrusive_ptr<UvHandle> handle)
427       : handle(std::move(handle)) {}
428 
write1(uint8_t val)429   void write1(uint8_t val) {
430     data.push_back(val);
431   }
432 
433   template <typename T>
write_value(T val)434   void write_value(T val) {
435     uint8_t* val_ptr = (uint8_t*)&val;
436     data.insert(data.end(), val_ptr, val_ptr + sizeof(T));
437   }
438 
write_vector(const std::vector<uint8_t> & val)439   void write_vector(const std::vector<uint8_t>& val) {
440     write_value<uint64_t>(val.size());
441     data.insert(data.end(), val.begin(), val.end());
442   }
443 
write_string(const std::string & val)444   void write_string(const std::string& val) {
445     write_value<uint64_t>(val.size());
446     data.insert(data.end(), val.data(), val.data() + val.size());
447   }
send()448   void send() {
449     auto wd = c10::make_intrusive<WriterPayload>(std::move(data), handle);
450     wd->send();
451   }
452 };
453 
454 class ChunkedStream {
455   std::deque<uv_buf_t> buffers;
456   size_t buff_idx{0};
457   size_t buff_offset{0};
458   size_t capacity{0};
459   size_t buff_offset_commit{0};
460   size_t read_offset{0};
461 
462  public:
463   ChunkedStream() = default;
464 
buf_count()465   size_t buf_count() {
466     return buffers.size();
467   }
468 
append(uv_buf_t buf)469   void append(uv_buf_t buf) {
470     if (buf.len == 0) {
471       free(buf.base);
472     } else {
473       capacity += buf.len;
474       buffers.push_back(buf);
475     }
476   }
read_many(char * dest,size_t size)477   bool read_many(char* dest, size_t size) {
478     if (available() < size) {
479       return false;
480     }
481 
482     size_t remaining = size;
483     char* write_base = dest;
484     while (remaining > 0) {
485       auto to_read = std::min(buffers[buff_idx].len - buff_offset, remaining);
486       ::memcpy(write_base, buffers[buff_idx].base + buff_offset, to_read);
487       buff_offset += to_read;
488       remaining -= to_read;
489       write_base += to_read;
490       if (buff_offset >= buffers[buff_idx].len) {
491         buff_offset = 0;
492         ++buff_idx;
493         if (buff_idx >= buffers.size() && remaining > 0) {
494           TORCH_CHECK(
495               false,
496               "Trying to read past end of buffer. ",
497               "buffer_idx: ",
498               buff_idx,
499               ", available: ",
500               buffers.size(),
501               ", remaining: ",
502               remaining);
503         }
504       }
505     }
506     read_offset += size;
507     return true;
508   }
509 
read1(uint8_t & byte)510   bool read1(uint8_t& byte) {
511     while (true) {
512       if (buff_idx >= buffers.size())
513         return false;
514       if (buff_offset >= buffers[buff_idx].len) {
515         buff_offset = 0;
516         ++buff_idx;
517         continue;
518       }
519       break;
520     }
521 
522     byte = buffers[buff_idx].base[buff_offset];
523     ++buff_offset;
524     ++read_offset;
525     return true;
526   }
527 
528   template <typename T>
read_value(T & value)529   bool read_value(T& value) {
530     return read_many((char*)&value, sizeof(T));
531   }
532 
read_key(std::string & str)533   bool read_key(std::string& str) {
534     uint64_t size = 0;
535     if (!read_value(size))
536       return false;
537     TORCH_CHECK(
538         size <= MAX_STRING_LEN,
539         "Invalid string size. ",
540         "size: ",
541         size,
542         ", max: ",
543         MAX_STRING_LEN);
544 
545     if (available() < size)
546       return false;
547     str.resize(size);
548     return read_many((char*)str.data(), size);
549   }
550 
read_payload(std::vector<uint8_t> & data)551   bool read_payload(std::vector<uint8_t>& data) {
552     uint64_t size = 0;
553     if (!read_value(size))
554       return false;
555     auto size_in_bytes = size * sizeof(uint8_t);
556     TORCH_CHECK(
557         size_in_bytes <= MAX_PAYLOAD_LEN,
558         "Invalid payload size. ",
559         "size: ",
560         size_in_bytes,
561         ", max: ",
562         MAX_PAYLOAD_LEN);
563 
564     if (available() < size_in_bytes)
565       return false;
566     data.resize(size);
567     return read_many((char*)data.data(), size_in_bytes);
568   }
569 
available()570   size_t available() {
571     return capacity - read_offset;
572   }
573 
commit()574   void commit() {
575     if (buff_idx >= buffers.size() || buff_offset >= buffers[buff_idx].len) {
576       buff_offset = 0;
577       if (buff_idx < buffers.size())
578         ++buff_idx;
579     }
580 
581     for (size_t i = 0; i < buff_idx; ++i) {
582       free(buffers[0].base);
583       capacity -= buffers[0].len;
584       buffers.pop_front();
585     }
586     buff_idx = 0;
587     read_offset = buff_offset_commit = buff_offset;
588   }
589 
reset()590   void reset() {
591     buff_idx = 0;
592     read_offset = buff_offset = buff_offset_commit;
593   }
594 };
595 
596 class LibUVStoreDaemon : public BackgroundThread {
597  public:
598   explicit LibUVStoreDaemon(int port);
599   ~LibUVStoreDaemon() override;
600 
601   uint16_t port() const override;
602 
603   void set(const std::string& key, const std::vector<uint8_t>& value);
604   const std::vector<uint8_t>& compareAndSet(
605       const std::string& key,
606       const std::vector<uint8_t>& expectedValue,
607       const std::vector<uint8_t>& newValue);
608   const std::vector<uint8_t>& get(const std::string& key);
609   int64_t add(const std::string& key, int64_t addVal);
610   bool checkKeys(const std::vector<std::string>& keys);
611   bool waitKeys(
612       const std::vector<std::string>& keys,
613       const c10::intrusive_ptr<UvHandle>& client);
614   int64_t size();
615   int64_t deleteKey(const std::string& key);
616   void append(const std::string& key, const std::vector<uint8_t>& value);
617 
618   void registerClient(const c10::intrusive_ptr<UvHandle>& client);
619   void unregisterClient(const c10::intrusive_ptr<UvHandle>& client);
620   void clearClientWaitState(const c10::intrusive_ptr<UvHandle>& client);
621   bool isMiscellaneousClient(const c10::intrusive_ptr<UvHandle>& client);
622 
623   uint16_t get_socket_port(uv_tcp_t* handle);
624   void init(const TCPStoreOptions& opts);
625 
626  protected:
627   void run() override;
628   void stop() override;
629 
630  private:
631   uv_loop_t loop{};
632   c10::intrusive_ptr<UvTcpServer> tcpServer;
633 
634   uv_async_t exit_handle{};
635   std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
636   // From key -> the list of UvClient waiting on the key
637   std::unordered_map<std::string, std::vector<c10::intrusive_ptr<UvHandle>>>
638       waitingSockets_;
639   // From socket -> number of keys awaited
640   std::unordered_map<c10::intrusive_ptr<UvHandle>, size_t> keysAwaited_;
641   std::unordered_set<c10::intrusive_ptr<UvHandle>> clients_;
642   std::unordered_set<c10::intrusive_ptr<UvHandle>> miscellaneousClients_;
643   int port_;
644 
from_uv(uv_handle_t * stream)645   static LibUVStoreDaemon& from_uv(uv_handle_t* stream) {
646     return *(LibUVStoreDaemon*)uv_handle_get_data(stream);
647   }
648 
on_new_connection(uv_stream_t * server,int status)649   static void on_new_connection(uv_stream_t* server, int status) {
650     from_uv((uv_handle_t*)server).onConnect(status);
651   }
652 
on_exit_request(uv_async_t * handle)653   static void on_exit_request(uv_async_t* handle) {
654     from_uv((uv_handle_t*)handle).onExitRequest();
655   }
656 
657   void onConnect(int status);
658   void onExitRequest();
659   void wakeupWaitingClients(const std::string& key);
660   // bool tryListen(bool use_ipv6);
661 
662   static void print_active_handles(uv_handle_t* handle, void* arg);
663 };
664 
665 class UvClient : public UvTcpSocket {
666   ChunkedStream stream;
667   LibUVStoreDaemon* store;
668 
669  protected:
processBuf(const uv_buf_t * buf,size_t nread)670   void processBuf(const uv_buf_t* buf, size_t nread) override {
671     auto tmp = *buf;
672     tmp.len = nread;
673     stream.append(tmp);
674 
675     while (true) {
676       stream.reset();
677       uint8_t command = -1;
678       if (!stream.read1(command))
679         break;
680       if (store->isMiscellaneousClient(iptr())) {
681         if ((QueryType)command != QueryType::VALIDATE)
682           return;
683         if (!parse_validate_command())
684           return;
685       } else {
686         switch ((QueryType)command) {
687           case QueryType::PING:
688             if (!parse_ping_command())
689               return;
690             break;
691           case QueryType::SET:
692             if (!parse_set_command())
693               return;
694             break;
695           case QueryType::COMPARE_SET:
696             if (!parse_compare_set_command())
697               return;
698             break;
699           case QueryType::GET:
700             if (!parse_get_command())
701               return;
702             break;
703           case QueryType::ADD:
704             if (!parse_add_command())
705               return;
706             break;
707           case QueryType::CHECK:
708             if (!parse_check_command())
709               return;
710             break;
711           case QueryType::WAIT:
712             if (!parse_wait_command())
713               return;
714             break;
715           case QueryType::GETNUMKEYS:
716             if (!parse_getnumkeys_command())
717               return;
718             break;
719           case QueryType::DELETE_KEY:
720             if (!parse_delete_key_command())
721               return;
722             break;
723           case QueryType::APPEND:
724             if (!parse_append_command())
725               return;
726             break;
727           case QueryType::MULTI_GET:
728             if (!parse_multi_get_command())
729               return;
730             break;
731           case QueryType::MULTI_SET:
732             if (!parse_multi_set_command())
733               return;
734             break;
735           case QueryType::CANCEL_WAIT:
736             if (!parse_cancel_wait_command())
737               return;
738             break;
739           default:
740             C10D_DEBUG(
741                 "Client sent invalid command. client:{} command:{}",
742                 (void*)this,
743                 (int)command);
744             close();
745             return;
746         }
747       }
748       stream.commit();
749     }
750   }
751 
parse_validate_command()752   bool parse_validate_command() {
753     uint32_t validateNumber = 0;
754     if (!stream.read_value(validateNumber))
755       return false;
756 
757     if (validateNumber != c10d::detail::validationMagicNumber)
758       return false;
759     return true;
760   }
761 
parse_ping_command()762   bool parse_ping_command() {
763     uint32_t nonce;
764     if (!stream.read_value(nonce)) {
765       return false;
766     }
767 
768     StreamWriter sw(iptr());
769     sw.write_value(nonce);
770     sw.send();
771     return true;
772   }
773 
parse_set_command()774   bool parse_set_command() {
775     std::string key;
776     if (!stream.read_key(key))
777       return false;
778 
779     std::vector<uint8_t> newData;
780     if (!stream.read_payload(newData))
781       return false;
782 
783     store->set(key, newData);
784     return true;
785   }
786 
parse_compare_set_command()787   bool parse_compare_set_command() {
788     std::string key;
789     if (!stream.read_key(key))
790       return false;
791 
792     std::vector<uint8_t> currentValue;
793     if (!stream.read_payload(currentValue))
794       return false;
795 
796     std::vector<uint8_t> newValue;
797     if (!stream.read_payload(newValue))
798       return false;
799 
800     auto res = store->compareAndSet(key, currentValue, newValue);
801     StreamWriter sw(iptr());
802     sw.write_vector(res);
803     sw.send();
804 
805     return true;
806   }
807 
parse_get_command()808   bool parse_get_command() {
809     std::string key;
810     if (!stream.read_key(key))
811       return false;
812 
813     const auto& data = store->get(key);
814     StreamWriter sw(iptr());
815     sw.write_vector(data);
816     sw.send();
817     return true;
818   }
819 
parse_add_command()820   bool parse_add_command() {
821     std::string key;
822     if (!stream.read_key(key))
823       return false;
824 
825     int64_t addVal = 0;
826     if (!stream.read_value(addVal))
827       return false;
828 
829     addVal = store->add(key, addVal);
830     StreamWriter sw(iptr());
831     sw.write_value(addVal);
832     sw.send();
833 
834     return true;
835   }
836 
parse_check_command()837   bool parse_check_command() {
838     uint64_t key_count = 0;
839     if (!stream.read_value(key_count))
840       return false;
841     TORCH_CHECK(
842         key_count <= MAX_KEY_COUNT,
843         "Too many keys being waited. ",
844         "keys: ",
845         key_count,
846         ", max: ",
847         MAX_KEY_COUNT);
848 
849     std::vector<std::string> keys(key_count);
850     for (uint64_t i = 0; i < key_count; ++i) {
851       if (!stream.read_key(keys[i]))
852         return false;
853     }
854 
855     // Now we have received all the keys
856     StreamWriter sw(iptr());
857     if (store->checkKeys(keys)) {
858       sw.write_value(CheckResponseType::READY);
859     } else {
860       sw.write_value(CheckResponseType::NOT_READY);
861     }
862     sw.send();
863     return true;
864   }
865 
parse_wait_command()866   bool parse_wait_command() {
867     uint64_t key_count = 0;
868     if (!stream.read_value(key_count)) {
869       return false;
870     }
871     TORCH_CHECK(
872         key_count <= MAX_KEY_COUNT,
873         "Too many keys being waited. ",
874         "keys: ",
875         key_count,
876         ", max: ",
877         MAX_KEY_COUNT);
878 
879     std::vector<std::string> keys(key_count);
880     for (uint64_t i = 0; i < key_count; ++i) {
881       if (!stream.read_key(keys[i]))
882         return false;
883     }
884 
885     if (store->waitKeys(keys, iptr())) {
886       StreamWriter sw(iptr());
887       sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
888       sw.send();
889     }
890 
891     return true;
892   }
893 
parse_getnumkeys_command()894   bool parse_getnumkeys_command() {
895     StreamWriter sw(iptr());
896     sw.write_value<int64_t>(store->size());
897     sw.send();
898 
899     return true;
900   }
901 
parse_delete_key_command()902   bool parse_delete_key_command() {
903     std::string key;
904     if (!stream.read_key(key))
905       return false;
906 
907     auto numDeleted = store->deleteKey(key);
908     StreamWriter sw(iptr());
909     sw.write_value<int64_t>(numDeleted);
910     sw.send();
911 
912     return true;
913   }
914 
parse_append_command()915   bool parse_append_command() {
916     std::string key;
917     if (!stream.read_key(key)) {
918       return false;
919     }
920 
921     std::vector<uint8_t> data;
922     if (!stream.read_payload(data)) {
923       return false;
924     }
925 
926     store->append(key, data);
927     return true;
928   }
929 
parse_multi_get_command()930   bool parse_multi_get_command() {
931     uint64_t key_count = 0;
932     if (!stream.read_value(key_count)) {
933       return false;
934     }
935     TORCH_CHECK(
936         key_count <= MAX_KEY_COUNT,
937         "Too many keys with multi_get. ",
938         "keys: ",
939         key_count,
940         ", max: ",
941         MAX_KEY_COUNT);
942 
943     StreamWriter sw(iptr());
944     for (const auto _ : c10::irange(key_count)) {
945       (void)_; // Suppress unused variable warning
946       std::string key;
947       if (!stream.read_key(key)) {
948         return false;
949       }
950 
951       sw.write_vector(store->get(key));
952     }
953     sw.send();
954 
955     return true;
956   }
957 
parse_multi_set_command()958   bool parse_multi_set_command() {
959     uint64_t key_count = 0;
960     if (!stream.read_value(key_count)) {
961       return false;
962     }
963     TORCH_CHECK(
964         key_count <= MAX_KEY_COUNT,
965         "Too many keys with multi_get. ",
966         "keys: ",
967         key_count,
968         ", max: ",
969         MAX_KEY_COUNT);
970 
971     for (const auto _ : c10::irange(key_count)) {
972       (void)_; // Suppress unused variable warning
973 
974       std::string key;
975       if (!stream.read_key(key)) {
976         return false;
977       }
978 
979       std::vector<uint8_t> newData;
980       if (!stream.read_payload(newData))
981         return false;
982       store->set(key, newData);
983     }
984 
985     return true;
986   }
987 
parse_cancel_wait_command()988   bool parse_cancel_wait_command() {
989     store->clearClientWaitState(iptr());
990 
991     StreamWriter sw(iptr());
992     sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED);
993     sw.send();
994 
995     return true;
996   }
997 
998  public:
UvClient(uv_loop_t * loop,LibUVStoreDaemon * store)999   explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store)
1000       : UvTcpSocket(loop), store(store) {}
1001 
make(uv_loop_t * loop,LibUVStoreDaemon * store)1002   static c10::intrusive_ptr<UvClient> make(
1003       uv_loop_t* loop,
1004       LibUVStoreDaemon* store) {
1005     auto res = c10::make_intrusive<UvClient>(loop, store);
1006     res->handleReady();
1007     return res;
1008   }
1009 
iptr()1010   c10::intrusive_ptr<UvClient> iptr() {
1011     return c10::intrusive_ptr<UvClient>::reclaim_copy(this);
1012   }
1013 
1014  protected:
onClose()1015   void onClose() override {
1016     store->unregisterClient(iptr());
1017   }
1018 };
1019 
onConnect(int status)1020 void LibUVStoreDaemon::onConnect(int status) {
1021   auto client = UvClient::make(&loop, this);
1022   registerClient(client);
1023   try {
1024     tcpServer->accept(client);
1025     client->startRead();
1026   } catch (std::exception& e) {
1027     C10D_WARNING("Failed to accept client due to {}", e.what());
1028     client->close();
1029   }
1030 }
1031 
onExitRequest()1032 void LibUVStoreDaemon::onExitRequest() {
1033   C10D_DEBUG("Store exit requested\n");
1034   uv_close((uv_handle_t*)&exit_handle, nullptr);
1035   uv_stop(&loop);
1036 }
1037 
init(const TCPStoreOptions & opts)1038 void LibUVStoreDaemon::init(const TCPStoreOptions& opts) {
1039   if (opts.masterListenFd.has_value()) {
1040     tcpServer = UvTcpServer::makeWithSocket(&loop, *opts.masterListenFd);
1041   } else {
1042     try {
1043       tcpServer = UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/true);
1044     } catch (std::exception& ex) {
1045       C10D_INFO(
1046           "Failed to bind to ipv6 address, trying ipv4. Error: {}", ex.what());
1047       tcpServer =
1048           UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/false);
1049     }
1050   }
1051   tcpServer->setOnConnectCallback(
1052       [this](auto status) { this->onConnect(status); });
1053 
1054   port_ = tcpServer->port();
1055   TORCH_CHECK(
1056       port_ == opts.port || opts.port == 0, // zero means use any port
1057       "listen fd ",
1058       *opts.masterListenFd,
1059       " is bound to port ",
1060       port_,
1061       ", expected to be bound to port ",
1062       opts.port);
1063 }
1064 
LibUVStoreDaemon(int port)1065 LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) {
1066   TORCH_CHECK(uv_loop_init(&loop) == 0, "Failed to init uv loop");
1067   TORCH_CHECK(
1068       uv_async_init(&loop, &exit_handle, LibUVStoreDaemon::on_exit_request) ==
1069           0,
1070       "Failed to init uv async event");
1071   uv_handle_set_data((uv_handle_t*)&exit_handle, this);
1072 }
1073 
~LibUVStoreDaemon()1074 LibUVStoreDaemon::~LibUVStoreDaemon() {
1075   if (!is_running()) {
1076     uv_close((uv_handle_t*)&exit_handle, nullptr);
1077     uv_run(&loop, UV_RUN_NOWAIT);
1078     TORCH_CHECK(uv_loop_close(&loop) == 0, "loop cleanup didn't work");
1079   } else {
1080     // the daemon thread cleanup libuv
1081     dispose();
1082   }
1083 }
1084 
port() const1085 uint16_t LibUVStoreDaemon::port() const {
1086   return port_;
1087 }
1088 
print_active_handles(uv_handle_t * handle,void * arg)1089 void LibUVStoreDaemon::print_active_handles(uv_handle_t* handle, void* arg) {
1090   C10D_DEBUG(
1091       "UV live handle type {} active:{} is-closing:{}",
1092       (int)handle->type,
1093       uv_is_active(handle),
1094       uv_is_closing(handle));
1095 }
1096 
run()1097 void LibUVStoreDaemon::run() {
1098   c10::setThreadName("pt_tcpstore_uv");
1099 
1100   C10D_DEBUG("Uv main loop running");
1101   int res = uv_run(&loop, UV_RUN_DEFAULT);
1102   if (res) {
1103     C10D_DEBUG("UV main loop done: res:{}", res);
1104   }
1105   bool debug_enabled =
1106       c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug);
1107 
1108   if (debug_enabled) {
1109     C10D_DEBUG("Walking live handles prior to closing clients");
1110     uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr);
1111   }
1112 
1113   for (const auto& client : clients_) {
1114     client->close();
1115   }
1116   tcpServer->close();
1117 
1118   if (debug_enabled) {
1119     C10D_DEBUG("Walking live handles after closing clients");
1120     uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr);
1121   }
1122 
1123   while (true) {
1124     res = uv_loop_close(&loop);
1125     if (res == 0) {
1126       break;
1127     }
1128     C10D_INFO(
1129         "uv_loop_close failed with:{} errn:{} desc:{}",
1130         res,
1131         uv_err_name(res),
1132         uv_strerror(res));
1133     res = uv_run(&loop, UV_RUN_NOWAIT);
1134     if (res != 0) {
1135       std::this_thread::sleep_for(std::chrono::milliseconds(500));
1136     }
1137   }
1138   C10D_INFO("uv_loop cleanup finished.");
1139 }
1140 
stop()1141 void LibUVStoreDaemon::stop() {
1142   int res = uv_async_send(&exit_handle);
1143   if (res) {
1144     C10D_WARNING(
1145         "uv_async_send failed with:{} errn:{} desc:{}\n",
1146         res,
1147         uv_err_name(res),
1148         uv_strerror(res));
1149   }
1150 }
1151 
isMiscellaneousClient(const c10::intrusive_ptr<UvHandle> & client)1152 bool LibUVStoreDaemon::isMiscellaneousClient(
1153     const c10::intrusive_ptr<UvHandle>& client) {
1154   if (miscellaneousClients_.find(client) != miscellaneousClients_.end()) {
1155     miscellaneousClients_.erase(client);
1156     return true;
1157   }
1158   return false;
1159 }
1160 
registerClient(const c10::intrusive_ptr<UvHandle> & client)1161 void LibUVStoreDaemon::registerClient(
1162     const c10::intrusive_ptr<UvHandle>& client) {
1163   clients_.insert(client);
1164   miscellaneousClients_.insert(client);
1165 }
1166 
unregisterClient(const c10::intrusive_ptr<UvHandle> & client)1167 void LibUVStoreDaemon::unregisterClient(
1168     const c10::intrusive_ptr<UvHandle>& client) {
1169   clients_.erase(client);
1170   if (miscellaneousClients_.find(client) != miscellaneousClients_.end()) {
1171     miscellaneousClients_.erase(client);
1172   }
1173   clearClientWaitState(client);
1174 }
1175 
clearClientWaitState(const c10::intrusive_ptr<UvHandle> & client)1176 void LibUVStoreDaemon::clearClientWaitState(
1177     const c10::intrusive_ptr<UvHandle>& client) {
1178   if (keysAwaited_.find(client) == keysAwaited_.end()) {
1179     return;
1180   }
1181   keysAwaited_.erase(client);
1182   for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
1183     for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
1184       if (*vecIt == client) {
1185         vecIt = it->second.erase(vecIt);
1186       } else {
1187         ++vecIt;
1188       }
1189     }
1190     if (it->second.empty()) {
1191       it = waitingSockets_.erase(it);
1192     } else {
1193       ++it;
1194     }
1195   }
1196 }
1197 
set(const std::string & key,const std::vector<uint8_t> & value)1198 void LibUVStoreDaemon::set(
1199     const std::string& key,
1200     const std::vector<uint8_t>& value) {
1201   tcpStore_[key] = value;
1202   // On "set", wake up all clients that have been waiting
1203   wakeupWaitingClients(key);
1204 }
1205 
compareAndSet(const std::string & key,const std::vector<uint8_t> & expectedValue,const std::vector<uint8_t> & newValue)1206 const std::vector<uint8_t>& LibUVStoreDaemon::compareAndSet(
1207     const std::string& key,
1208     const std::vector<uint8_t>& expectedValue,
1209     const std::vector<uint8_t>& newValue) {
1210   auto pos = tcpStore_.find(key);
1211   if (pos == tcpStore_.end()) {
1212     if (expectedValue.empty()) {
1213       tcpStore_[key] = newValue;
1214       wakeupWaitingClients(key);
1215       return newValue;
1216     } else {
1217       // TODO: This code path is not ideal as we are "lying" to the caller in
1218       // case the key does not exist. We should come up with a working solution.
1219       // It might make more sense to return ""
1220       wakeupWaitingClients(key);
1221       return expectedValue;
1222     }
1223   } else {
1224     if (pos->second == expectedValue) {
1225       pos->second = newValue;
1226     }
1227     wakeupWaitingClients(key);
1228     return pos->second;
1229   }
1230 }
1231 
get(const std::string & key)1232 const std::vector<uint8_t>& LibUVStoreDaemon::get(const std::string& key) {
1233   static std::vector<uint8_t> missing_key;
1234   return tcpStore_.count(key) ? tcpStore_.at(key) : missing_key;
1235 }
1236 
add(const std::string & key,int64_t addVal)1237 int64_t LibUVStoreDaemon::add(const std::string& key, int64_t addVal) {
1238   std::vector<uint8_t> oldData;
1239   auto it = tcpStore_.find(key);
1240   if (it != tcpStore_.end()) {
1241     oldData = it->second;
1242     auto buf = reinterpret_cast<const char*>(it->second.data());
1243     auto len = it->second.size();
1244     addVal += std::stoll(std::string(buf, len));
1245   }
1246   auto addValStr = std::to_string(addVal);
1247   std::vector<uint8_t> newData =
1248       std::vector<uint8_t>(addValStr.begin(), addValStr.end());
1249   tcpStore_[key] = newData;
1250 
1251   // On "add", wake up all clients that have been waiting
1252   wakeupWaitingClients(key);
1253 
1254   return addVal;
1255 }
1256 
checkKeys(const std::vector<std::string> & keys)1257 bool LibUVStoreDaemon::checkKeys(const std::vector<std::string>& keys) {
1258   return std::all_of(keys.begin(), keys.end(), [&](const std::string& s) {
1259     return tcpStore_.count(s) > 0;
1260   });
1261 }
1262 
waitKeys(const std::vector<std::string> & keys,const c10::intrusive_ptr<UvHandle> & client)1263 bool LibUVStoreDaemon::waitKeys(
1264     const std::vector<std::string>& keys,
1265     const c10::intrusive_ptr<UvHandle>& client) {
1266   if (checkKeys(keys)) {
1267     return true;
1268   }
1269   int numKeysToAwait = 0;
1270   for (auto& key : keys) {
1271     // Only count keys that have not already been set
1272     if (tcpStore_.find(key) == tcpStore_.end()) {
1273       waitingSockets_[key].push_back(client);
1274       numKeysToAwait++;
1275     }
1276   }
1277   keysAwaited_[client] = numKeysToAwait;
1278   return false;
1279 }
1280 
size()1281 int64_t LibUVStoreDaemon::size() {
1282   return tcpStore_.size();
1283 }
1284 
deleteKey(const std::string & key)1285 int64_t LibUVStoreDaemon::deleteKey(const std::string& key) {
1286   return tcpStore_.erase(key);
1287 }
1288 
append(const std::string & key,const std::vector<uint8_t> & value)1289 void LibUVStoreDaemon::append(
1290     const std::string& key,
1291     const std::vector<uint8_t>& value) {
1292   std::vector<uint8_t> oldData;
1293   auto it = tcpStore_.find(key);
1294   if (it != tcpStore_.end()) {
1295     it->second.insert(it->second.end(), value.begin(), value.end());
1296   } else {
1297     tcpStore_[key] = value;
1298   }
1299 
1300   // we should not have clients waiting if we're appending, so it's all fine
1301   wakeupWaitingClients(key);
1302 }
1303 
wakeupWaitingClients(const std::string & key)1304 void LibUVStoreDaemon::wakeupWaitingClients(const std::string& key) {
1305   auto socketsToWait = waitingSockets_.find(key);
1306   if (socketsToWait != waitingSockets_.end()) {
1307     for (const auto& client : socketsToWait->second) {
1308       if (--keysAwaited_[client] == 0) {
1309         StreamWriter sw(client->iptr());
1310         sw.write1((uint8_t)WaitResponseType::STOP_WAITING);
1311         sw.send();
1312       }
1313     }
1314     waitingSockets_.erase(socketsToWait);
1315   }
1316 }
1317 
1318 #endif
1319 
create_libuv_tcpstore_backend(const TCPStoreOptions & opts)1320 std::unique_ptr<BackgroundThread> create_libuv_tcpstore_backend(
1321     const TCPStoreOptions& opts) {
1322 #ifdef TORCH_USE_LIBUV
1323   auto res = std::make_unique<LibUVStoreDaemon>(opts.port);
1324   res->init(opts);
1325   return res;
1326 #else
1327   TORCH_CHECK(false, "LibUV TCPStore implementation missing");
1328 #endif
1329 }
1330 
// Reports whether this build was compiled with the libuv TCPStore backend
// (i.e. whether TORCH_USE_LIBUV was defined).
bool is_libuv_tcpstore_backend_available() {
#ifdef TORCH_USE_LIBUV
  constexpr bool kAvailable = true;
#else
  constexpr bool kAvailable = false;
#endif
  return kAvailable;
}
1338 
1339 } // namespace c10d::detail
1340