xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/PrefixStore.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
2 #include <utility>
3 
4 namespace c10d {
5 
PrefixStore(std::string prefix,c10::intrusive_ptr<Store> store)6 PrefixStore::PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store)
7     : prefix_(std::move(prefix)), store_(std::move(store)) {}
8 
joinKey(const std::string & key)9 std::string PrefixStore::joinKey(const std::string& key) {
10   return prefix_ + "/" + key;
11 }
12 
joinKeys(const std::vector<std::string> & keys)13 std::vector<std::string> PrefixStore::joinKeys(
14     const std::vector<std::string>& keys) {
15   std::vector<std::string> joinedKeys;
16   joinedKeys.reserve(keys.size());
17   for (const auto& key : keys) {
18     joinedKeys.emplace_back(joinKey(key));
19   }
20   return joinedKeys;
21 }
22 
set(const std::string & key,const std::vector<uint8_t> & value)23 void PrefixStore::set(
24     const std::string& key,
25     const std::vector<uint8_t>& value) {
26   store_->set(joinKey(key), value);
27 }
28 
compareSet(const std::string & key,const std::vector<uint8_t> & expectedValue,const std::vector<uint8_t> & desiredValue)29 std::vector<uint8_t> PrefixStore::compareSet(
30     const std::string& key,
31     const std::vector<uint8_t>& expectedValue,
32     const std::vector<uint8_t>& desiredValue) {
33   return store_->compareSet(joinKey(key), expectedValue, desiredValue);
34 }
35 
get(const std::string & key)36 std::vector<uint8_t> PrefixStore::get(const std::string& key) {
37   return store_->get(joinKey(key));
38 }
39 
add(const std::string & key,int64_t value)40 int64_t PrefixStore::add(const std::string& key, int64_t value) {
41   return store_->add(joinKey(key), value);
42 }
43 
deleteKey(const std::string & key)44 bool PrefixStore::deleteKey(const std::string& key) {
45   return store_->deleteKey(joinKey(key));
46 }
47 
getNumKeys()48 int64_t PrefixStore::getNumKeys() {
49   return store_->getNumKeys();
50 }
51 
check(const std::vector<std::string> & keys)52 bool PrefixStore::check(const std::vector<std::string>& keys) {
53   auto joinedKeys = joinKeys(keys);
54   return store_->check(joinedKeys);
55 }
56 
wait(const std::vector<std::string> & keys)57 void PrefixStore::wait(const std::vector<std::string>& keys) {
58   auto joinedKeys = joinKeys(keys);
59   store_->wait(joinedKeys);
60 }
61 
wait(const std::vector<std::string> & keys,const std::chrono::milliseconds & timeout)62 void PrefixStore::wait(
63     const std::vector<std::string>& keys,
64     const std::chrono::milliseconds& timeout) {
65   auto joinedKeys = joinKeys(keys);
66   store_->wait(joinedKeys, timeout);
67 }
68 
getTimeout() const69 const std::chrono::milliseconds& PrefixStore::getTimeout() const noexcept {
70   return store_->getTimeout();
71 }
72 
setTimeout(const std::chrono::milliseconds & timeout)73 void PrefixStore::setTimeout(const std::chrono::milliseconds& timeout) {
74   store_->setTimeout(timeout);
75 }
76 
append(const std::string & key,const std::vector<uint8_t> & value)77 void PrefixStore::append(
78     const std::string& key,
79     const std::vector<uint8_t>& value) {
80   store_->append(joinKey(key), value);
81 }
82 
multiGet(const std::vector<std::string> & keys)83 std::vector<std::vector<uint8_t>> PrefixStore::multiGet(
84     const std::vector<std::string>& keys) {
85   std::vector<std::string> prefixed_keys;
86   prefixed_keys.reserve(keys.size());
87   for (auto& key : keys) {
88     prefixed_keys.push_back(joinKey(key));
89   }
90   return store_->multiGet(prefixed_keys);
91 }
92 
multiSet(const std::vector<std::string> & keys,const std::vector<std::vector<uint8_t>> & values)93 void PrefixStore::multiSet(
94     const std::vector<std::string>& keys,
95     const std::vector<std::vector<uint8_t>>& values) {
96   std::vector<std::string> prefixed_keys;
97   prefixed_keys.reserve(keys.size());
98   for (auto& key : keys) {
99     prefixed_keys.push_back(joinKey(key));
100   }
101   store_->multiSet(prefixed_keys, values);
102 }
103 
104 // Returns true if this store support append, multiGet and multiSet
hasExtendedApi() const105 bool PrefixStore::hasExtendedApi() const {
106   return store_->hasExtendedApi();
107 }
108 
getUnderlyingStore()109 c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
110   return store_;
111 }
112 
getUnderlyingNonPrefixStore()113 c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
114   c10::intrusive_ptr<Store> store = store_;
115 
116   while (store) {
117     // Attempt to dynamically cast to PrefixStore
118     PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
119     if (asPrefixStore) {
120       store = asPrefixStore->getUnderlyingStore();
121     } else {
122       break; // We've reached a non-PrefixStore
123     }
124   }
125 
126   TORCH_CHECK(
127       store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
128   return store;
129 }
130 
131 } // namespace c10d
132