1 #include <c10/util/Exception.h>
2 #include <fmt/format.h>
3 #include <torch/csrc/distributed/c10d/Store.hpp>
4 #include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
5 #include <chrono>
6 #include <exception>
7 #include <vector>
8
9 namespace {
getRankKey(const std::string & key,int rank)10 std::string getRankKey(const std::string& key, int rank) {
11 return fmt::format("{}/{}", key, rank);
12 }
13 } // namespace
14
15 namespace c10d {
16
StoreCollectives(c10::intrusive_ptr<::c10d::Store> store,int rank,int worldSize)17 StoreCollectives::StoreCollectives(
18 c10::intrusive_ptr<::c10d::Store> store,
19 int rank,
20 int worldSize)
21 : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {}
22
barrier(const std::string & key,std::chrono::milliseconds timeout,bool blocking)23 void StoreCollectives::barrier(
24 const std::string& key,
25 std::chrono::milliseconds timeout,
26 bool blocking) {
27 enforceUnique(key);
28 StoreTimeoutGuard g{*store_, timeout};
29
30 auto num_members_key = fmt::format("{}/num_members", key);
31 auto last_members_key = fmt::format("{}/last_members", key);
32
33 auto idx = store_->add(num_members_key, 1);
34 store_->set(getRankKey(key, rank_), "joined");
35
36 if (idx == worldSize_) {
37 store_->set(last_members_key, "<val_ignored>");
38 } else if (blocking) {
39 try {
40 store_->wait({last_members_key});
41 } catch (const std::exception& e) {
42 std::string msg = "barrier failed -- missing ranks: ";
43 for (int i = 0; i < worldSize_; i++) {
44 if (i == rank_) {
45 continue;
46 }
47 auto rank_key = getRankKey(key, i);
48 if (!store_->check({rank_key})) {
49 msg += fmt::format("{}, ", i);
50 }
51 }
52 throw std::runtime_error(msg + e.what());
53 }
54 }
55 }
56
broadcastSend(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)57 void StoreCollectives::broadcastSend(
58 const std::string& key,
59 const std::vector<uint8_t>& data,
60 std::chrono::milliseconds timeout) {
61 enforceUnique(key);
62 StoreTimeoutGuard g{*store_, timeout};
63
64 store_->set(key, data);
65 }
66
broadcastRecv(const std::string & key,std::chrono::milliseconds timeout)67 std::vector<uint8_t> StoreCollectives::broadcastRecv(
68 const std::string& key,
69 std::chrono::milliseconds timeout) {
70 enforceUnique(key);
71 StoreTimeoutGuard g{*store_, timeout};
72
73 return store_->get(key);
74 }
75
gatherSend(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)76 void StoreCollectives::gatherSend(
77 const std::string& key,
78 const std::vector<uint8_t>& data,
79 std::chrono::milliseconds timeout) {
80 enforceUnique(key);
81 StoreTimeoutGuard g{*store_, timeout};
82
83 auto rank_key = getRankKey(key, rank_);
84 store_->set(rank_key, data);
85 }
86
gatherRecv(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)87 std::vector<std::vector<uint8_t>> StoreCollectives::gatherRecv(
88 const std::string& key,
89 const std::vector<uint8_t>& data,
90 std::chrono::milliseconds timeout) {
91 enforceUnique(key);
92 StoreTimeoutGuard g{*store_, timeout};
93
94 std::vector<std::string> keys;
95 keys.reserve(worldSize_);
96
97 for (int i = 0; i < worldSize_; i++) {
98 if (i == rank_) {
99 continue;
100 }
101 auto rank_key = getRankKey(key, i);
102 keys.emplace_back(rank_key);
103 }
104
105 std::vector<std::vector<uint8_t>> results;
106 results.reserve(worldSize_);
107
108 try {
109 results = store_->multiGet(keys);
110 } catch (const std::exception& e) {
111 std::string msg = "gather failed -- missing ranks: ";
112 for (int i = 0; i < worldSize_; i++) {
113 if (i == rank_) {
114 continue;
115 }
116 auto rank_key = getRankKey(key, i);
117 if (!store_->check({rank_key})) {
118 msg += fmt::format("{}, ", i);
119 }
120 }
121 throw std::runtime_error(msg + e.what());
122 }
123
124 // insert local data
125 results.insert(results.begin() + rank_, data);
126 return results;
127 }
128
scatterSend(const std::string & key,const std::vector<std::vector<uint8_t>> & data,std::chrono::milliseconds timeout)129 std::vector<uint8_t> StoreCollectives::scatterSend(
130 const std::string& key,
131 const std::vector<std::vector<uint8_t>>& data,
132 std::chrono::milliseconds timeout) {
133 enforceUnique(key);
134 StoreTimeoutGuard g{*store_, timeout};
135
136 std::vector<std::string> keys;
137 keys.reserve(worldSize_);
138 for (int i = 0; i < worldSize_; i++) {
139 if (i == rank_) {
140 continue;
141 }
142 auto rank_key = getRankKey(key, i);
143 keys.emplace_back(rank_key);
144 }
145 auto local = data.at(rank_);
146
147 std::vector<std::vector<uint8_t>> toSend{data};
148
149 toSend.erase(toSend.begin() + rank_);
150
151 store_->multiSet(keys, toSend);
152
153 return local;
154 }
155
scatterRecv(const std::string & key,std::chrono::milliseconds timeout)156 std::vector<uint8_t> StoreCollectives::scatterRecv(
157 const std::string& key,
158 std::chrono::milliseconds timeout) {
159 enforceUnique(key);
160 StoreTimeoutGuard g{*store_, timeout};
161
162 auto rank_key = getRankKey(key, rank_);
163 return store_->get(rank_key);
164 }
165
allGather(const std::string & key,const std::vector<uint8_t> & data,std::chrono::milliseconds timeout)166 std::vector<std::vector<uint8_t>> StoreCollectives::allGather(
167 const std::string& key,
168 const std::vector<uint8_t>& data,
169 std::chrono::milliseconds timeout) {
170 enforceUnique(key);
171 StoreTimeoutGuard g{*store_, timeout};
172
173 auto localKey = getRankKey(key, rank_);
174 store_->set(localKey, data);
175
176 std::vector<std::string> keys;
177 keys.reserve(worldSize_);
178
179 for (int i = 0; i < worldSize_; i++) {
180 auto rank_key = getRankKey(key, i);
181 keys.emplace_back(rank_key);
182 }
183
184 try {
185 return store_->multiGet(keys);
186 } catch (const std::exception& e) {
187 std::string msg = "all_gather failed -- missing ranks: ";
188 for (int i = 0; i < worldSize_; i++) {
189 if (i == rank_) {
190 continue;
191 }
192 auto rank_key = getRankKey(key, i);
193 if (!store_->check({rank_key})) {
194 msg += fmt::format("{}, ", i);
195 }
196 }
197 throw std::runtime_error(msg + e.what());
198 }
199 }
200
allSum(const std::string & key,int64_t value,std::chrono::milliseconds timeout)201 int64_t StoreCollectives::allSum(
202 const std::string& key,
203 int64_t value,
204 std::chrono::milliseconds timeout) {
205 enforceUnique(key);
206 StoreTimeoutGuard g{*store_, timeout};
207
208 store_->add(key, value);
209
210 barrier(key + "/barrier", timeout);
211
212 return store_->add(key, 0);
213 }
214
enforceUnique(const std::string & key)215 void StoreCollectives::enforceUnique(const std::string& key) {
216 auto it = seenKeys_.find(key);
217 TORCH_INTERNAL_ASSERT(
218 it == seenKeys_.end(), "Key ", key, " has already been used.");
219 seenKeys_.emplace(key);
220 }
221
222 } // namespace c10d
223