xref: /aosp_15_r20/external/pytorch/test/cpp/c10d/TCPStoreTest.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include "StoreTestCommon.hpp"
3 
4 #include <cstdlib>
5 #include <future>
6 #include <iostream>
7 #include <string>
8 #include <system_error>
9 #include <thread>
10 
11 #include <gtest/gtest.h>
12 
13 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
14 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
15 
// Short store timeout, in milliseconds, applied right before a get() that is
// expected to fail.
constexpr int64_t kShortStoreTimeoutMillis{100};
// Default store timeout, in seconds, used when constructing stores below.
constexpr int defaultTimeout{20};
18 
_createServer(bool useLibUV,int numWorkers=1,int timeout=defaultTimeout)19 c10::intrusive_ptr<c10d::TCPStore> _createServer(
20     bool useLibUV,
21     int numWorkers = 1,
22     int timeout = defaultTimeout) {
23   return c10::make_intrusive<c10d::TCPStore>(
24       "127.0.0.1",
25       c10d::TCPStoreOptions{
26           /* port */ 0,
27           /* isServer */ true,
28           numWorkers,
29           /* waitWorkers */ false,
30           /* timeout */ std::chrono::seconds(timeout),
31           /* multiTenant */ false,
32           /* masterListenFd */ std::nullopt,
33           /* useLibUV*/ useLibUV});
34 }
35 
// Each test creates its own server on port 0 and reads the actual port back
// with getPort(), so different tests end up on different ports.
testHelper(bool useLibUV,const std::string & prefix="")37 void testHelper(bool useLibUV, const std::string& prefix = "") {
38   constexpr auto numThreads = 16;
39   constexpr auto numWorkers = numThreads + 1;
40 
41   auto serverTCPStore = _createServer(useLibUV, numWorkers);
42 
43   auto serverStore =
44       c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
45   // server store
46   auto serverThread = std::thread([&serverStore, &serverTCPStore] {
47     // Wait for all workers to join.
48     serverTCPStore->waitForWorkers();
49 
50     // Basic set/get on the server store
51     c10d::test::set(*serverStore, "key0", "value0");
52     c10d::test::set(*serverStore, "key1", "value1");
53     c10d::test::set(*serverStore, "key2", "value2");
54     c10d::test::check(*serverStore, "key0", "value0");
55     c10d::test::check(*serverStore, "key1", "value1");
56     c10d::test::check(*serverStore, "key2", "value2");
57     serverStore->add("counter", 1);
58     auto numKeys = serverStore->getNumKeys();
59     // We expect 5 keys since 3 are added above, 'counter' is added by the
60     // helper thread, and the init key to coordinate workers.
61     EXPECT_EQ(numKeys, 5);
62 
63     // Check compareSet, does not check return value
64     c10d::test::compareSet(
65         *serverStore, "key0", "wrongExpectedValue", "newValue");
66     c10d::test::check(*serverStore, "key0", "value0");
67     c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
68     c10d::test::check(*serverStore, "key0", "newValue");
69 
70     auto delSuccess = serverStore->deleteKey("key0");
71     // Ensure that the key was successfully deleted
72     EXPECT_TRUE(delSuccess);
73     auto delFailure = serverStore->deleteKey("badKeyName");
74     // The key was not in the store so the delete operation should have failed
75     // and returned false.
76     EXPECT_FALSE(delFailure);
77     numKeys = serverStore->getNumKeys();
78     EXPECT_EQ(numKeys, 4);
79     auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
80     serverStore->setTimeout(timeout);
81     EXPECT_THROW(serverStore->get("key0"), c10::Error);
82   });
83 
84   // Hammer on TCPStore
85   std::vector<std::thread> threads;
86   constexpr auto numIterations = 1000;
87   c10d::test::Semaphore sem1, sem2;
88 
89   c10d::TCPStoreOptions opts{};
90   opts.port = serverTCPStore->getPort();
91   opts.numWorkers = numWorkers;
92 
93   // Each thread will have a client store to send/recv data
94   std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
95   std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
96   for (const auto i : c10::irange(numThreads)) {
97     clientTCPStores.push_back(
98         c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
99     clientStores.push_back(
100         c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
101   }
102 
103   std::string expectedCounterRes =
104       std::to_string(numThreads * numIterations + 1);
105 
106   for (const auto i : c10::irange(numThreads)) {
107     threads.emplace_back(
108         std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
109           for (C10_UNUSED const auto j : c10::irange(numIterations)) {
110             clientStores[i]->add("counter", 1);
111           }
112           // Let each thread set and get key on its client store
113           std::string key = "thread_" + std::to_string(i);
114           for (const auto j : c10::irange(numIterations)) {
115             std::string val = "thread_val_" + std::to_string(j);
116             c10d::test::set(*clientStores[i], key, val);
117             c10d::test::check(*clientStores[i], key, val);
118           }
119 
120           sem1.post();
121           sem2.wait();
122           // Check the counter results
123           c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
124           // Now check other threads' written data
125           for (const auto j : c10::irange(numThreads)) {
126             if (j == i) {
127               continue;
128             }
129             std::string key = "thread_" + std::to_string(i);
130             std::string val = "thread_val_" + std::to_string(numIterations - 1);
131             c10d::test::check(*clientStores[i], key, val);
132           }
133         }));
134   }
135 
136   sem1.wait(numThreads);
137   sem2.post(numThreads);
138 
139   for (auto& thread : threads) {
140     thread.join();
141   }
142 
143   serverThread.join();
144 
145   // Clear the store to test that client disconnect won't shutdown the store
146   clientStores.clear();
147   clientTCPStores.clear();
148 
149   // Check that the counter has the expected value
150   c10d::test::check(*serverStore, "counter", expectedCounterRes);
151 
152   // Check that each threads' written data from the main thread
153   for (const auto i : c10::irange(numThreads)) {
154     std::string key = "thread_" + std::to_string(i);
155     std::string val = "thread_val_" + std::to_string(numIterations - 1);
156     c10d::test::check(*serverStore, key, val);
157   }
158 }
159 
// Runs the shared helper with useLibUV disabled and the default (empty) prefix.
TEST(TCPStoreTest, testHelper) {
  testHelper(/*useLibUV=*/false);
}
163 
// Runs the shared helper with useLibUV enabled and the default (empty) prefix.
TEST(TCPStoreTest, testHelperUV) {
  testHelper(/*useLibUV=*/true);
}
167 
// Runs the shared helper with useLibUV disabled and a non-empty key prefix.
TEST(TCPStoreTest, testHelperPrefix) {
  testHelper(/*useLibUV=*/false, /*prefix=*/"testPrefixNoUV");
}
171 
// Runs the shared helper with useLibUV enabled and a non-empty key prefix.
TEST(TCPStoreTest, testHelperPrefixUV) {
  testHelper(/*useLibUV=*/true, /*prefix=*/"testPrefixUV");
}
175 
// Destroys the server store while a client has a get() outstanding on a key
// that was never set, and expects that get() to fail with DistNetworkError.
TEST(TCPStoreTest, testCleanShutdown) {
  int numWorkers = 2;

  // Server side: owned through unique_ptr so it can be torn down explicitly.
  auto serverTCPStore = std::make_unique<c10d::TCPStore>(
      "127.0.0.1",
      0,
      numWorkers,
      true,
      std::chrono::seconds(defaultTimeout),
      /* wait */ false);
  c10d::test::set(*serverTCPStore, "key", "val");

  // Client side: connect to whichever port the server ended up on.
  const c10d::TCPStoreOptions clientOpts{
      /* port */ serverTCPStore->getPort(),
      /* isServer */ false,
      numWorkers,
      /* waitWorkers */ false,
      /* timeout */ std::chrono::seconds(defaultTimeout)};
  auto clientTCPStore =
      c10::make_intrusive<c10d::TCPStore>("127.0.0.1", clientOpts);
  // A successful read of the key set above shows the connection works.
  clientTCPStore->get("key");

  std::thread pendingGet([&clientTCPStore] {
    EXPECT_THROW(clientTCPStore->get("invalid_key"), c10::DistNetworkError);
  });

  // start server shutdown during a client request
  serverTCPStore.reset();

  pendingGet.join();
}
207 
// With useLibUV set on both ends, writes ten keys through _splitSet (a set
// issued as two separate requests) and reads each one back via the client.
TEST(TCPStoreTest, testLibUVPartialRead) {
  int numWorkers = 2; // thread 0 creates both server and client

  // server part
  c10d::TCPStoreOptions server_opts{};
  server_opts.port = 0;
  server_opts.isServer = true; // is master
  server_opts.numWorkers = numWorkers;
  // don't wait otherwise client thread won't spawn
  server_opts.waitWorkers = false;
  server_opts.timeout = std::chrono::seconds(defaultTimeout);
  server_opts.useLibUV = true;

  auto serverTCPStore =
      std::make_unique<c10d::TCPStore>("127.0.0.1", server_opts);

  // client part
  c10d::TCPStoreOptions client_opts{};
  client_opts.port = serverTCPStore->getPort();
  client_opts.isServer = false; // not the master
  client_opts.numWorkers = numWorkers;
  client_opts.waitWorkers = false;
  client_opts.timeout = std::chrono::seconds(defaultTimeout);
  client_opts.useLibUV = true;

  auto clientTCPStore =
      c10::make_intrusive<c10d::TCPStore>("127.0.0.1", client_opts);

  std::thread writer([&clientTCPStore] {
    const std::string keyPrefix(
        "/default_pg/0//b7dc24de75e482ba2ceb9f9ee20732c25c0166d8//cuda//");
    const std::string payload("v");
    const std::vector<uint8_t> payloadBytes(payload.begin(), payload.end());

    // split store->set(key, valueBuf) into two requests
    for (const auto idx : c10::irange(10)) {
      const std::string key = keyPrefix + std::to_string(idx);
      clientTCPStore->_splitSet(key, payloadBytes);

      // read the value back through the client store
      c10d::test::check(*clientTCPStore, key, "v");
    }
  });

  writer.join();
}
251 
testMultiTenantStores(bool libUV)252 void testMultiTenantStores(bool libUV) {
253   c10d::TCPStoreOptions opts{};
254   opts.isServer = true;
255   opts.multiTenant = true;
256   opts.useLibUV = libUV;
257 
258   // Construct two server stores on the same port.
259   auto store1 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
260   auto store2 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
261 
262   // Assert that the two stores share the same server.
263   c10d::test::set(*store1, "key0", "value0");
264   c10d::test::check(*store2, "key0", "value0");
265 
266   // Dispose the second instance and assert that the server is still alive.
267   store2.reset();
268 
269   c10d::test::set(*store1, "key0", "value0");
270   c10d::test::check(*store1, "key0", "value0");
271 }
272 
// Multi-tenant sharing check with useLibUV disabled.
TEST(TCPStoreTest, testMultiTenantStores) {
  testMultiTenantStores(/*libUV=*/false);
}
276 
// Multi-tenant sharing check with useLibUV enabled.
TEST(TCPStoreTest, testMultiTenantStoresUV) {
  testMultiTenantStores(/*libUV=*/true);
}
280