1 #pragma once 2 3 #include <cstddef> 4 #include <cstdint> 5 #include <memory> 6 7 #include <torch/csrc/distributed/c10d/Store.hpp> 8 9 namespace c10d { 10 namespace detail { 11 12 class TCPServer; 13 14 class TCPClient; 15 16 struct SocketAddress { 17 std::string host{}; 18 std::uint16_t port{}; 19 }; 20 21 } // namespace detail 22 23 struct TCPStoreOptions { 24 static constexpr std::uint16_t kDefaultPort = 29500; 25 26 std::uint16_t port = kDefaultPort; 27 bool isServer = false; 28 std::optional<std::size_t> numWorkers = std::nullopt; 29 bool waitWorkers = true; 30 std::chrono::milliseconds timeout = Store::kDefaultTimeout; 31 32 // A boolean value indicating whether multiple store instances can be 33 // initialized with the same host:port pair. 34 bool multiTenant = false; 35 36 // If specified, and if isServer is true, the underlying TCPServer will take 37 // over the bound socket associated to this fd. This option is useful to avoid 38 // port assignment races in certain scenarios. 39 std::optional<int> masterListenFd = std::nullopt; 40 41 // A boolean value indicating whether to use the experimental libUV backend. 42 bool useLibUV = true; 43 }; 44 45 class TORCH_API TCPStore : public Store { 46 public: 47 static constexpr std::chrono::milliseconds kConnectRetryDelay{1000}; 48 49 explicit TCPStore(std::string host, const TCPStoreOptions& opts = {}); 50 51 [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( 52 const std::string& masterAddr, 53 std::uint16_t masterPort, 54 std::optional<int> numWorkers = std::nullopt, 55 bool isServer = false, 56 const std::chrono::milliseconds& timeout = kDefaultTimeout, 57 bool waitWorkers = true); 58 59 ~TCPStore() override; 60 61 void set(const std::string& key, const std::vector<uint8_t>& value) override; 62 63 std::vector<uint8_t> compareSet( 64 const std::string& key, 65 const std::vector<uint8_t>& expectedValue, 66 const std::vector<uint8_t>& desiredValue) override; 67 68 std::vector<uint8_t> get(const std::string& key) override; 69 70 int64_t add(const std::string& key, int64_t value) override; 71 72 bool deleteKey(const std::string& key) override; 73 74 bool check(const std::vector<std::string>& keys) override; 75 76 int64_t getNumKeys() override; 77 78 void wait(const std::vector<std::string>& keys) override; 79 80 void wait( 81 const std::vector<std::string>& keys, 82 const std::chrono::milliseconds& timeout) override; 83 84 void append(const std::string& key, const std::vector<uint8_t>& value) 85 override; 86 87 std::vector<std::vector<uint8_t>> multiGet( 88 const std::vector<std::string>& keys) override; 89 90 void multiSet( 91 const std::vector<std::string>& keys, 92 const std::vector<std::vector<uint8_t>>& values) override; 93 94 bool hasExtendedApi() const override; 95 96 // Waits for all workers to join. 97 void waitForWorkers(); 98 99 // Returns the hostname used by the TCPStore. getHost() const100 const std::string& getHost() const noexcept { 101 return addr_.host; 102 } 103 104 // Returns the port used by the TCPStore. getPort() const105 std::uint16_t getPort() const noexcept { 106 return addr_.port; 107 } 108 isLibUvBackend() const109 bool isLibUvBackend() const noexcept { 110 return usingLibUv_; 111 } 112 113 // note(xilunwu): this function is only for internal testing 114 void _splitSet(const std::string& key, const std::vector<uint8_t>& data); 115 116 std::string repr() const; 117 118 private: 119 int64_t incrementValueBy(const std::string& key, int64_t delta); 120 121 void ping(); 122 void validate(); 123 124 std::vector<uint8_t> doGet(const std::string& key); 125 126 void doWait( 127 c10::ArrayRef<std::string> keys, 128 std::chrono::milliseconds timeout); 129 130 detail::SocketAddress addr_; 131 std::shared_ptr<detail::TCPServer> server_; 132 std::unique_ptr<detail::TCPClient> client_; 133 std::optional<std::size_t> numWorkers_; 134 135 const std::string initKey_ = "init/"; 136 const std::string keyPrefix_ = "/"; 137 std::mutex activeOpLock_; 138 bool usingLibUv_ = true; 139 }; 140 141 } // namespace c10d 142