xref: /aosp_15_r20/external/pytorch/test/cpp/c10d/HashStoreTest.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include "StoreTestCommon.hpp"
3 
4 #include <unistd.h>
5 
6 #include <iostream>
7 #include <thread>
8 
9 #include <torch/csrc/distributed/c10d/HashStore.hpp>
10 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
11 
// Store timeout (milliseconds) used by tests that want get() on a missing key
// to fail quickly with an error instead of blocking for a long time.
constexpr int64_t kShortStoreTimeoutMillis{100};
13 
testGetSet(std::string prefix="")14 void testGetSet(std::string prefix = "") {
15   // Basic set/get
16   {
17     auto hashStore = c10::make_intrusive<c10d::HashStore>();
18     c10d::PrefixStore store(prefix, hashStore);
19     c10d::test::set(store, "key0", "value0");
20     c10d::test::set(store, "key1", "value1");
21     c10d::test::set(store, "key2", "value2");
22     c10d::test::check(store, "key0", "value0");
23     c10d::test::check(store, "key1", "value1");
24     c10d::test::check(store, "key2", "value2");
25 
26     // Check compareSet, does not check return value
27     c10d::test::compareSet(store, "key0", "wrongExpectedValue", "newValue");
28     c10d::test::check(store, "key0", "value0");
29     c10d::test::compareSet(store, "key0", "value0", "newValue");
30     c10d::test::check(store, "key0", "newValue");
31 
32     auto numKeys = store.getNumKeys();
33     EXPECT_EQ(numKeys, 3);
34     auto delSuccess = store.deleteKey("key0");
35     EXPECT_TRUE(delSuccess);
36     numKeys = store.getNumKeys();
37     EXPECT_EQ(numKeys, 2);
38     auto delFailure = store.deleteKey("badKeyName");
39     EXPECT_FALSE(delFailure);
40     auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
41     store.setTimeout(timeout);
42     EXPECT_THROW(store.get("key0"), c10::DistStoreError);
43   }
44 
45   // get() waits up to timeout_.
46   {
47     auto hashStore = c10::make_intrusive<c10d::HashStore>();
48     c10d::PrefixStore store(prefix, hashStore);
49     std::thread th([&]() { c10d::test::set(store, "key0", "value0"); });
50     c10d::test::check(store, "key0", "value0");
51     th.join();
52   }
53 }
54 
stressTestStore(std::string prefix="")55 void stressTestStore(std::string prefix = "") {
56   // Hammer on HashStore::add
57   const auto numThreads = 4;
58   const auto numIterations = 100;
59 
60   std::vector<std::thread> threads;
61   c10d::test::Semaphore sem1, sem2;
62   auto hashStore = c10::make_intrusive<c10d::HashStore>();
63   c10d::PrefixStore store(prefix, hashStore);
64 
65   for (C10_UNUSED const auto i : c10::irange(numThreads)) {
66     threads.emplace_back(std::thread([&] {
67       sem1.post();
68       sem2.wait();
69       for (C10_UNUSED const auto j : c10::irange(numIterations)) {
70         store.add("counter", 1);
71       }
72     }));
73   }
74 
75   sem1.wait(numThreads);
76   sem2.post(numThreads);
77 
78   for (auto& thread : threads) {
79     thread.join();
80   }
81   std::string expected = std::to_string(numThreads * numIterations);
82   c10d::test::check(store, "counter", expected);
83 }
84 
// Each scenario is registered twice: once with an empty PrefixStore prefix
// and once with the prefix "testPrefix".
TEST(HashStoreTest, testGetAndSet) {
  testGetSet(std::string());
}

TEST(HashStoreTest, testGetAndSetWithPrefix) {
  testGetSet(std::string("testPrefix"));
}

TEST(HashStoreTest, testStressStore) {
  stressTestStore(std::string());
}

TEST(HashStoreTest, testStressStoreWithPrefix) {
  stressTestStore(std::string("testPrefix"));
}
100