1 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
2 #include <utility>
3
4 namespace c10d {
5
PrefixStore(std::string prefix,c10::intrusive_ptr<Store> store)6 PrefixStore::PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store)
7 : prefix_(std::move(prefix)), store_(std::move(store)) {}
8
joinKey(const std::string & key)9 std::string PrefixStore::joinKey(const std::string& key) {
10 return prefix_ + "/" + key;
11 }
12
joinKeys(const std::vector<std::string> & keys)13 std::vector<std::string> PrefixStore::joinKeys(
14 const std::vector<std::string>& keys) {
15 std::vector<std::string> joinedKeys;
16 joinedKeys.reserve(keys.size());
17 for (const auto& key : keys) {
18 joinedKeys.emplace_back(joinKey(key));
19 }
20 return joinedKeys;
21 }
22
set(const std::string & key,const std::vector<uint8_t> & value)23 void PrefixStore::set(
24 const std::string& key,
25 const std::vector<uint8_t>& value) {
26 store_->set(joinKey(key), value);
27 }
28
compareSet(const std::string & key,const std::vector<uint8_t> & expectedValue,const std::vector<uint8_t> & desiredValue)29 std::vector<uint8_t> PrefixStore::compareSet(
30 const std::string& key,
31 const std::vector<uint8_t>& expectedValue,
32 const std::vector<uint8_t>& desiredValue) {
33 return store_->compareSet(joinKey(key), expectedValue, desiredValue);
34 }
35
get(const std::string & key)36 std::vector<uint8_t> PrefixStore::get(const std::string& key) {
37 return store_->get(joinKey(key));
38 }
39
add(const std::string & key,int64_t value)40 int64_t PrefixStore::add(const std::string& key, int64_t value) {
41 return store_->add(joinKey(key), value);
42 }
43
deleteKey(const std::string & key)44 bool PrefixStore::deleteKey(const std::string& key) {
45 return store_->deleteKey(joinKey(key));
46 }
47
getNumKeys()48 int64_t PrefixStore::getNumKeys() {
49 return store_->getNumKeys();
50 }
51
check(const std::vector<std::string> & keys)52 bool PrefixStore::check(const std::vector<std::string>& keys) {
53 auto joinedKeys = joinKeys(keys);
54 return store_->check(joinedKeys);
55 }
56
wait(const std::vector<std::string> & keys)57 void PrefixStore::wait(const std::vector<std::string>& keys) {
58 auto joinedKeys = joinKeys(keys);
59 store_->wait(joinedKeys);
60 }
61
wait(const std::vector<std::string> & keys,const std::chrono::milliseconds & timeout)62 void PrefixStore::wait(
63 const std::vector<std::string>& keys,
64 const std::chrono::milliseconds& timeout) {
65 auto joinedKeys = joinKeys(keys);
66 store_->wait(joinedKeys, timeout);
67 }
68
getTimeout() const69 const std::chrono::milliseconds& PrefixStore::getTimeout() const noexcept {
70 return store_->getTimeout();
71 }
72
setTimeout(const std::chrono::milliseconds & timeout)73 void PrefixStore::setTimeout(const std::chrono::milliseconds& timeout) {
74 store_->setTimeout(timeout);
75 }
76
append(const std::string & key,const std::vector<uint8_t> & value)77 void PrefixStore::append(
78 const std::string& key,
79 const std::vector<uint8_t>& value) {
80 store_->append(joinKey(key), value);
81 }
82
multiGet(const std::vector<std::string> & keys)83 std::vector<std::vector<uint8_t>> PrefixStore::multiGet(
84 const std::vector<std::string>& keys) {
85 std::vector<std::string> prefixed_keys;
86 prefixed_keys.reserve(keys.size());
87 for (auto& key : keys) {
88 prefixed_keys.push_back(joinKey(key));
89 }
90 return store_->multiGet(prefixed_keys);
91 }
92
multiSet(const std::vector<std::string> & keys,const std::vector<std::vector<uint8_t>> & values)93 void PrefixStore::multiSet(
94 const std::vector<std::string>& keys,
95 const std::vector<std::vector<uint8_t>>& values) {
96 std::vector<std::string> prefixed_keys;
97 prefixed_keys.reserve(keys.size());
98 for (auto& key : keys) {
99 prefixed_keys.push_back(joinKey(key));
100 }
101 store_->multiSet(prefixed_keys, values);
102 }
103
104 // Returns true if this store support append, multiGet and multiSet
hasExtendedApi() const105 bool PrefixStore::hasExtendedApi() const {
106 return store_->hasExtendedApi();
107 }
108
getUnderlyingStore()109 c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
110 return store_;
111 }
112
getUnderlyingNonPrefixStore()113 c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
114 c10::intrusive_ptr<Store> store = store_;
115
116 while (store) {
117 // Attempt to dynamically cast to PrefixStore
118 PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
119 if (asPrefixStore) {
120 store = asPrefixStore->getUnderlyingStore();
121 } else {
122 break; // We've reached a non-PrefixStore
123 }
124 }
125
126 TORCH_CHECK(
127 store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
128 return store;
129 }
130
131 } // namespace c10d
132