1 #include <c10/util/irange.h>
2 #include "StoreTestCommon.hpp"
3
4 #include <cstdlib>
5 #include <future>
6 #include <iostream>
7 #include <string>
8 #include <system_error>
9 #include <thread>
10
11 #include <gtest/gtest.h>
12
13 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
14 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
15
// Store timeout (milliseconds) installed right before a lookup that is
// expected to fail, so the failing get() returns quickly.
constexpr int64_t kShortStoreTimeoutMillis{100};
// Store timeout (seconds) used when constructing stores in these tests.
constexpr int defaultTimeout{20};
18
_createServer(bool useLibUV,int numWorkers=1,int timeout=defaultTimeout)19 c10::intrusive_ptr<c10d::TCPStore> _createServer(
20 bool useLibUV,
21 int numWorkers = 1,
22 int timeout = defaultTimeout) {
23 return c10::make_intrusive<c10d::TCPStore>(
24 "127.0.0.1",
25 c10d::TCPStoreOptions{
26 /* port */ 0,
27 /* isServer */ true,
28 numWorkers,
29 /* waitWorkers */ false,
30 /* timeout */ std::chrono::seconds(timeout),
31 /* multiTenant */ false,
32 /* masterListenFd */ std::nullopt,
33 /* useLibUV*/ useLibUV});
34 }
35
36 // Different ports for different tests.
testHelper(bool useLibUV,const std::string & prefix="")37 void testHelper(bool useLibUV, const std::string& prefix = "") {
38 constexpr auto numThreads = 16;
39 constexpr auto numWorkers = numThreads + 1;
40
41 auto serverTCPStore = _createServer(useLibUV, numWorkers);
42
43 auto serverStore =
44 c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
45 // server store
46 auto serverThread = std::thread([&serverStore, &serverTCPStore] {
47 // Wait for all workers to join.
48 serverTCPStore->waitForWorkers();
49
50 // Basic set/get on the server store
51 c10d::test::set(*serverStore, "key0", "value0");
52 c10d::test::set(*serverStore, "key1", "value1");
53 c10d::test::set(*serverStore, "key2", "value2");
54 c10d::test::check(*serverStore, "key0", "value0");
55 c10d::test::check(*serverStore, "key1", "value1");
56 c10d::test::check(*serverStore, "key2", "value2");
57 serverStore->add("counter", 1);
58 auto numKeys = serverStore->getNumKeys();
59 // We expect 5 keys since 3 are added above, 'counter' is added by the
60 // helper thread, and the init key to coordinate workers.
61 EXPECT_EQ(numKeys, 5);
62
63 // Check compareSet, does not check return value
64 c10d::test::compareSet(
65 *serverStore, "key0", "wrongExpectedValue", "newValue");
66 c10d::test::check(*serverStore, "key0", "value0");
67 c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
68 c10d::test::check(*serverStore, "key0", "newValue");
69
70 auto delSuccess = serverStore->deleteKey("key0");
71 // Ensure that the key was successfully deleted
72 EXPECT_TRUE(delSuccess);
73 auto delFailure = serverStore->deleteKey("badKeyName");
74 // The key was not in the store so the delete operation should have failed
75 // and returned false.
76 EXPECT_FALSE(delFailure);
77 numKeys = serverStore->getNumKeys();
78 EXPECT_EQ(numKeys, 4);
79 auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
80 serverStore->setTimeout(timeout);
81 EXPECT_THROW(serverStore->get("key0"), c10::Error);
82 });
83
84 // Hammer on TCPStore
85 std::vector<std::thread> threads;
86 constexpr auto numIterations = 1000;
87 c10d::test::Semaphore sem1, sem2;
88
89 c10d::TCPStoreOptions opts{};
90 opts.port = serverTCPStore->getPort();
91 opts.numWorkers = numWorkers;
92
93 // Each thread will have a client store to send/recv data
94 std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
95 std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
96 for (const auto i : c10::irange(numThreads)) {
97 clientTCPStores.push_back(
98 c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
99 clientStores.push_back(
100 c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
101 }
102
103 std::string expectedCounterRes =
104 std::to_string(numThreads * numIterations + 1);
105
106 for (const auto i : c10::irange(numThreads)) {
107 threads.emplace_back(
108 std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
109 for (C10_UNUSED const auto j : c10::irange(numIterations)) {
110 clientStores[i]->add("counter", 1);
111 }
112 // Let each thread set and get key on its client store
113 std::string key = "thread_" + std::to_string(i);
114 for (const auto j : c10::irange(numIterations)) {
115 std::string val = "thread_val_" + std::to_string(j);
116 c10d::test::set(*clientStores[i], key, val);
117 c10d::test::check(*clientStores[i], key, val);
118 }
119
120 sem1.post();
121 sem2.wait();
122 // Check the counter results
123 c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
124 // Now check other threads' written data
125 for (const auto j : c10::irange(numThreads)) {
126 if (j == i) {
127 continue;
128 }
129 std::string key = "thread_" + std::to_string(i);
130 std::string val = "thread_val_" + std::to_string(numIterations - 1);
131 c10d::test::check(*clientStores[i], key, val);
132 }
133 }));
134 }
135
136 sem1.wait(numThreads);
137 sem2.post(numThreads);
138
139 for (auto& thread : threads) {
140 thread.join();
141 }
142
143 serverThread.join();
144
145 // Clear the store to test that client disconnect won't shutdown the store
146 clientStores.clear();
147 clientTCPStores.clear();
148
149 // Check that the counter has the expected value
150 c10d::test::check(*serverStore, "counter", expectedCounterRes);
151
152 // Check that each threads' written data from the main thread
153 for (const auto i : c10::irange(numThreads)) {
154 std::string key = "thread_" + std::to_string(i);
155 std::string val = "thread_val_" + std::to_string(numIterations - 1);
156 c10d::test::check(*serverStore, key, val);
157 }
158 }
159
TEST(TCPStoreTest,testHelper)160 TEST(TCPStoreTest, testHelper) {
161 testHelper(false);
162 }
163
TEST(TCPStoreTest,testHelperUV)164 TEST(TCPStoreTest, testHelperUV) {
165 testHelper(true);
166 }
167
TEST(TCPStoreTest,testHelperPrefix)168 TEST(TCPStoreTest, testHelperPrefix) {
169 testHelper(false, "testPrefixNoUV");
170 }
171
TEST(TCPStoreTest,testHelperPrefixUV)172 TEST(TCPStoreTest, testHelperPrefixUV) {
173 testHelper(true, "testPrefixUV");
174 }
175
TEST(TCPStoreTest,testCleanShutdown)176 TEST(TCPStoreTest, testCleanShutdown) {
177 int numWorkers = 2;
178
179 auto serverTCPStore = std::make_unique<c10d::TCPStore>(
180 "127.0.0.1",
181 0,
182 numWorkers,
183 true,
184 std::chrono::seconds(defaultTimeout),
185 /* wait */ false);
186 c10d::test::set(*serverTCPStore, "key", "val");
187
188 auto clientTCPStore = c10::make_intrusive<c10d::TCPStore>(
189 "127.0.0.1",
190 c10d::TCPStoreOptions{
191 /* port */ serverTCPStore->getPort(),
192 /* isServer */ false,
193 numWorkers,
194 /* waitWorkers */ false,
195 /* timeout */ std::chrono::seconds(defaultTimeout)});
196 clientTCPStore->get("key");
197
198 auto clientThread = std::thread([&clientTCPStore] {
199 EXPECT_THROW(clientTCPStore->get("invalid_key"), c10::DistNetworkError);
200 });
201
202 // start server shutdown during a client request
203 serverTCPStore = nullptr;
204
205 clientThread.join();
206 }
207
TEST(TCPStoreTest,testLibUVPartialRead)208 TEST(TCPStoreTest, testLibUVPartialRead) {
209 int numWorkers = 2; // thread 0 creates both server and client
210
211 // server part
212 c10d::TCPStoreOptions server_opts{
213 0,
214 true, // is master
215 numWorkers,
216 false, // don't wait otherwise client thread won't spawn
217 std::chrono::seconds(defaultTimeout)};
218 server_opts.useLibUV = true;
219
220 auto serverTCPStore =
221 std::make_unique<c10d::TCPStore>("127.0.0.1", server_opts);
222
223 // client part
224 c10d::TCPStoreOptions client_opts{
225 serverTCPStore->getPort(),
226 false, // is master
227 numWorkers,
228 false, // wait workers
229 std::chrono::seconds(defaultTimeout)};
230 client_opts.useLibUV = true;
231 auto clientTCPStore =
232 c10::make_intrusive<c10d::TCPStore>("127.0.0.1", client_opts);
233 auto clientThread = std::thread([&clientTCPStore] {
234 std::string keyPrefix(
235 "/default_pg/0//b7dc24de75e482ba2ceb9f9ee20732c25c0166d8//cuda//");
236 std::string value("v");
237 std::vector<uint8_t> valueBuf(value.begin(), value.end());
238
239 // split store->set(key, valueBuf) into two requests
240 for (int i = 0; i < 10; ++i) {
241 std::string key = keyPrefix + std::to_string(i);
242 clientTCPStore->_splitSet(key, valueBuf);
243
244 // check the result on server
245 c10d::test::check(*clientTCPStore, key, "v");
246 }
247 });
248
249 clientThread.join();
250 }
251
testMultiTenantStores(bool libUV)252 void testMultiTenantStores(bool libUV) {
253 c10d::TCPStoreOptions opts{};
254 opts.isServer = true;
255 opts.multiTenant = true;
256 opts.useLibUV = libUV;
257
258 // Construct two server stores on the same port.
259 auto store1 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
260 auto store2 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
261
262 // Assert that the two stores share the same server.
263 c10d::test::set(*store1, "key0", "value0");
264 c10d::test::check(*store2, "key0", "value0");
265
266 // Dispose the second instance and assert that the server is still alive.
267 store2.reset();
268
269 c10d::test::set(*store1, "key0", "value0");
270 c10d::test::check(*store1, "key0", "value0");
271 }
272
TEST(TCPStoreTest,testMultiTenantStores)273 TEST(TCPStoreTest, testMultiTenantStores) {
274 testMultiTenantStores(false);
275 }
276
TEST(TCPStoreTest,testMultiTenantStoresUV)277 TEST(TCPStoreTest, testMultiTenantStoresUV) {
278 testMultiTenantStores(true);
279 }
280