xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/TCPStore.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <vector>

#include <torch/csrc/distributed/c10d/Store.hpp>
8 
9 namespace c10d {
namespace detail {

// Forward declarations only: this header refers to the client/server
// implementation types exclusively through smart pointers.
class TCPClient;
class TCPServer;

/// A plain host/port pair. A value-initialized instance holds an empty host
/// and port 0.
struct SocketAddress {
  std::string host{};
  std::uint16_t port = 0;
};

} // namespace detail
22 
23 struct TCPStoreOptions {
24   static constexpr std::uint16_t kDefaultPort = 29500;
25 
26   std::uint16_t port = kDefaultPort;
27   bool isServer = false;
28   std::optional<std::size_t> numWorkers = std::nullopt;
29   bool waitWorkers = true;
30   std::chrono::milliseconds timeout = Store::kDefaultTimeout;
31 
32   // A boolean value indicating whether multiple store instances can be
33   // initialized with the same host:port pair.
34   bool multiTenant = false;
35 
36   // If specified, and if isServer is true, the underlying TCPServer will take
37   // over the bound socket associated to this fd. This option is useful to avoid
38   // port assignment races in certain scenarios.
39   std::optional<int> masterListenFd = std::nullopt;
40 
41   // A boolean value indicating whether to use the experimental libUV backend.
42   bool useLibUV = true;
43 };
44 
45 class TORCH_API TCPStore : public Store {
46  public:
47   static constexpr std::chrono::milliseconds kConnectRetryDelay{1000};
48 
49   explicit TCPStore(std::string host, const TCPStoreOptions& opts = {});
50 
51   [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore(
52       const std::string& masterAddr,
53       std::uint16_t masterPort,
54       std::optional<int> numWorkers = std::nullopt,
55       bool isServer = false,
56       const std::chrono::milliseconds& timeout = kDefaultTimeout,
57       bool waitWorkers = true);
58 
59   ~TCPStore() override;
60 
61   void set(const std::string& key, const std::vector<uint8_t>& value) override;
62 
63   std::vector<uint8_t> compareSet(
64       const std::string& key,
65       const std::vector<uint8_t>& expectedValue,
66       const std::vector<uint8_t>& desiredValue) override;
67 
68   std::vector<uint8_t> get(const std::string& key) override;
69 
70   int64_t add(const std::string& key, int64_t value) override;
71 
72   bool deleteKey(const std::string& key) override;
73 
74   bool check(const std::vector<std::string>& keys) override;
75 
76   int64_t getNumKeys() override;
77 
78   void wait(const std::vector<std::string>& keys) override;
79 
80   void wait(
81       const std::vector<std::string>& keys,
82       const std::chrono::milliseconds& timeout) override;
83 
84   void append(const std::string& key, const std::vector<uint8_t>& value)
85       override;
86 
87   std::vector<std::vector<uint8_t>> multiGet(
88       const std::vector<std::string>& keys) override;
89 
90   void multiSet(
91       const std::vector<std::string>& keys,
92       const std::vector<std::vector<uint8_t>>& values) override;
93 
94   bool hasExtendedApi() const override;
95 
96   // Waits for all workers to join.
97   void waitForWorkers();
98 
99   // Returns the hostname used by the TCPStore.
getHost() const100   const std::string& getHost() const noexcept {
101     return addr_.host;
102   }
103 
104   // Returns the port used by the TCPStore.
getPort() const105   std::uint16_t getPort() const noexcept {
106     return addr_.port;
107   }
108 
isLibUvBackend() const109   bool isLibUvBackend() const noexcept {
110     return usingLibUv_;
111   }
112 
113   // note(xilunwu): this function is only for internal testing
114   void _splitSet(const std::string& key, const std::vector<uint8_t>& data);
115 
116   std::string repr() const;
117 
118  private:
119   int64_t incrementValueBy(const std::string& key, int64_t delta);
120 
121   void ping();
122   void validate();
123 
124   std::vector<uint8_t> doGet(const std::string& key);
125 
126   void doWait(
127       c10::ArrayRef<std::string> keys,
128       std::chrono::milliseconds timeout);
129 
130   detail::SocketAddress addr_;
131   std::shared_ptr<detail::TCPServer> server_;
132   std::unique_ptr<detail::TCPClient> client_;
133   std::optional<std::size_t> numWorkers_;
134 
135   const std::string initKey_ = "init/";
136   const std::string keyPrefix_ = "/";
137   std::mutex activeOpLock_;
138   bool usingLibUv_ = true;
139 };
140 
141 } // namespace c10d
142