xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Exception.h>
2 #include <fmt/format.h>
3 #include <torch/csrc/distributed/c10d/Store.hpp>
4 #include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
5 #include <chrono>
6 #include <exception>
7 #include <vector>
8 
namespace {
// Builds the store key under which a single rank publishes its entry for the
// collective identified by `key`. The result has the form "<key>/<rank>",
// with `rank` rendered in decimal.
std::string getRankKey(const std::string& key, int rank) {
  std::string rankKey = key;
  rankKey += '/';
  rankKey += std::to_string(rank);
  return rankKey;
}
} // namespace
14 
15 namespace c10d {
16 
StoreCollectives(c10::intrusive_ptr<::c10d::Store> store,int rank,int worldSize)17 StoreCollectives::StoreCollectives(
18     c10::intrusive_ptr<::c10d::Store> store,
19     int rank,
20     int worldSize)
21     : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {}
22 
barrier(const std::string & key,std::chrono::milliseconds timeout,bool blocking)23 void StoreCollectives::barrier(
24     const std::string& key,
25     std::chrono::milliseconds timeout,
26     bool blocking) {
27   enforceUnique(key);
28   StoreTimeoutGuard g{*store_, timeout};
29 
30   auto num_members_key = fmt::format("{}/num_members", key);
31   auto last_members_key = fmt::format("{}/last_members", key);
32 
33   auto idx = store_->add(num_members_key, 1);
34   store_->set(getRankKey(key, rank_), "joined");
35 
36   if (idx == worldSize_) {
37     store_->set(last_members_key, "<val_ignored>");
38   } else if (blocking) {
39     try {
40       store_->wait({last_members_key});
41     } catch (const std::exception& e) {
42       std::string msg = "barrier failed -- missing ranks: ";
43       for (int i = 0; i < worldSize_; i++) {
44         if (i == rank_) {
45           continue;
46         }
47         auto rank_key = getRankKey(key, i);
48         if (!store_->check({rank_key})) {
49           msg += fmt::format("{}, ", i);
50         }
51       }
52       throw std::runtime_error(msg + e.what());
53     }
54   }
55 }
56 
broadcastSend(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)57 void StoreCollectives::broadcastSend(
58     const std::string& key,
59     const std::vector<uint8_t>& data,
60     std::chrono::milliseconds timeout) {
61   enforceUnique(key);
62   StoreTimeoutGuard g{*store_, timeout};
63 
64   store_->set(key, data);
65 }
66 
broadcastRecv(const std::string & key,std::chrono::milliseconds timeout)67 std::vector<uint8_t> StoreCollectives::broadcastRecv(
68     const std::string& key,
69     std::chrono::milliseconds timeout) {
70   enforceUnique(key);
71   StoreTimeoutGuard g{*store_, timeout};
72 
73   return store_->get(key);
74 }
75 
gatherSend(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)76 void StoreCollectives::gatherSend(
77     const std::string& key,
78     const std::vector<uint8_t>& data,
79     std::chrono::milliseconds timeout) {
80   enforceUnique(key);
81   StoreTimeoutGuard g{*store_, timeout};
82 
83   auto rank_key = getRankKey(key, rank_);
84   store_->set(rank_key, data);
85 }
86 
gatherRecv(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)87 std::vector<std::vector<uint8_t>> StoreCollectives::gatherRecv(
88     const std::string& key,
89     const std::vector<uint8_t>& data,
90     std::chrono::milliseconds timeout) {
91   enforceUnique(key);
92   StoreTimeoutGuard g{*store_, timeout};
93 
94   std::vector<std::string> keys;
95   keys.reserve(worldSize_);
96 
97   for (int i = 0; i < worldSize_; i++) {
98     if (i == rank_) {
99       continue;
100     }
101     auto rank_key = getRankKey(key, i);
102     keys.emplace_back(rank_key);
103   }
104 
105   std::vector<std::vector<uint8_t>> results;
106   results.reserve(worldSize_);
107 
108   try {
109     results = store_->multiGet(keys);
110   } catch (const std::exception& e) {
111     std::string msg = "gather failed -- missing ranks: ";
112     for (int i = 0; i < worldSize_; i++) {
113       if (i == rank_) {
114         continue;
115       }
116       auto rank_key = getRankKey(key, i);
117       if (!store_->check({rank_key})) {
118         msg += fmt::format("{}, ", i);
119       }
120     }
121     throw std::runtime_error(msg + e.what());
122   }
123 
124   // insert local data
125   results.insert(results.begin() + rank_, data);
126   return results;
127 }
128 
scatterSend(const std::string & key,const std::vector<std::vector<uint8_t>> & data,std::chrono::milliseconds timeout)129 std::vector<uint8_t> StoreCollectives::scatterSend(
130     const std::string& key,
131     const std::vector<std::vector<uint8_t>>& data,
132     std::chrono::milliseconds timeout) {
133   enforceUnique(key);
134   StoreTimeoutGuard g{*store_, timeout};
135 
136   std::vector<std::string> keys;
137   keys.reserve(worldSize_);
138   for (int i = 0; i < worldSize_; i++) {
139     if (i == rank_) {
140       continue;
141     }
142     auto rank_key = getRankKey(key, i);
143     keys.emplace_back(rank_key);
144   }
145   auto local = data.at(rank_);
146 
147   std::vector<std::vector<uint8_t>> toSend{data};
148 
149   toSend.erase(toSend.begin() + rank_);
150 
151   store_->multiSet(keys, toSend);
152 
153   return local;
154 }
155 
scatterRecv(const std::string & key,std::chrono::milliseconds timeout)156 std::vector<uint8_t> StoreCollectives::scatterRecv(
157     const std::string& key,
158     std::chrono::milliseconds timeout) {
159   enforceUnique(key);
160   StoreTimeoutGuard g{*store_, timeout};
161 
162   auto rank_key = getRankKey(key, rank_);
163   return store_->get(rank_key);
164 }
165 
allGather(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)166 std::vector<std::vector<uint8_t>> StoreCollectives::allGather(
167     const std::string& key,
168     const std::vector<uint8_t>& data,
169     std::chrono::milliseconds timeout) {
170   enforceUnique(key);
171   StoreTimeoutGuard g{*store_, timeout};
172 
173   auto localKey = getRankKey(key, rank_);
174   store_->set(localKey, data);
175 
176   std::vector<std::string> keys;
177   keys.reserve(worldSize_);
178 
179   for (int i = 0; i < worldSize_; i++) {
180     auto rank_key = getRankKey(key, i);
181     keys.emplace_back(rank_key);
182   }
183 
184   try {
185     return store_->multiGet(keys);
186   } catch (const std::exception& e) {
187     std::string msg = "all_gather failed -- missing ranks: ";
188     for (int i = 0; i < worldSize_; i++) {
189       if (i == rank_) {
190         continue;
191       }
192       auto rank_key = getRankKey(key, i);
193       if (!store_->check({rank_key})) {
194         msg += fmt::format("{}, ", i);
195       }
196     }
197     throw std::runtime_error(msg + e.what());
198   }
199 }
200 
allSum(const std::string & key,int64_t value,std::chrono::milliseconds timeout)201 int64_t StoreCollectives::allSum(
202     const std::string& key,
203     int64_t value,
204     std::chrono::milliseconds timeout) {
205   enforceUnique(key);
206   StoreTimeoutGuard g{*store_, timeout};
207 
208   store_->add(key, value);
209 
210   barrier(key + "/barrier", timeout);
211 
212   return store_->add(key, 0);
213 }
214 
enforceUnique(const std::string & key)215 void StoreCollectives::enforceUnique(const std::string& key) {
216   auto it = seenKeys_.find(key);
217   TORCH_INTERNAL_ASSERT(
218       it == seenKeys_.end(), "Key ", key, " has already been used.");
219   seenKeys_.emplace(key);
220 }
221 
222 } // namespace c10d
223