1 #include <fmt/format.h>
2 #include <torch/csrc/distributed/rpc/agent_utils.h>
3
4 namespace torch::distributed::rpc {
5
collectNames(::c10d::PrefixStore store,const worker_id_t selfId,const std::string & selfName,const int worldSize)6 std::unordered_map<std::string, worker_id_t> collectNames(
7 ::c10d::PrefixStore store,
8 const worker_id_t selfId,
9 const std::string& selfName,
10 const int worldSize) {
11 std::vector<uint8_t> selfNameVector(
12 (uint8_t*)selfName.c_str(),
13 (uint8_t*)selfName.c_str() + selfName.length());
14 store.set(std::to_string(selfId), selfNameVector);
15
16 std::unordered_map<std::string, worker_id_t> nameToId;
17 nameToId.reserve(worldSize);
18 nameToId.emplace(selfName, selfId);
19 for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
20 if (workerId == selfId) {
21 continue;
22 }
23 std::vector<uint8_t> workerNameVector = store.get(std::to_string(workerId));
24 std::string workerName(
25 (char*)workerNameVector.data(), workerNameVector.size());
26
27 TORCH_CHECK(
28 nameToId.find(workerName) == nameToId.end(),
29 "RPC worker name ",
30 workerName,
31 " is not unique. Workers ",
32 nameToId.find(workerName)->second,
33 " and ",
34 workerId,
35 " share the same name.");
36
37 nameToId.emplace(workerName, workerId);
38 }
39 return nameToId;
40 }
41
// Splits `s` at every occurrence of `delim` and returns the pieces in order.
//
// The delimiters themselves are dropped. Adjacent delimiters, or a delimiter
// at either end of `s`, produce empty tokens, so the result always holds
// exactly (number of delimiter matches + 1) entries; an empty `s` yields one
// empty token.
//
// An empty `delim` is treated as "no delimiter": the whole of `s` is returned
// as a single token.
static std::vector<std::string> splitString(
    const std::string& s,
    const std::string& delim) {
  std::vector<std::string> tokens;

  // find("") matches at the current position without consuming any input, so
  // the scan below would never advance and would append tokens forever.
  if (delim.empty()) {
    tokens.emplace_back(s);
    return tokens;
  }

  size_t start = 0;
  for (size_t end = s.find(delim, start); end != std::string::npos;
       end = s.find(delim, start)) {
    tokens.emplace_back(s, start, end - start);
    // Resume right after the delimiter that was just consumed.
    start = end + delim.length();
  }
  // Whatever follows the last delimiter (possibly nothing) is the last token.
  tokens.emplace_back(s, start);
  return tokens;
}
57
// Store key holding the list of currently registered workers, kept as
// comma-terminated "name-id" entries (e.g. "Name1-Rank1,Name2-Rank2,").
const std::string allWorkerInfosKey{"_ALL_WORKER_INFOS"};
59
collectCurrentNames(::c10d::PrefixStore store,const worker_id_t selfId,const std::string & selfName)60 std::unordered_map<std::string, worker_id_t> collectCurrentNames(
61 ::c10d::PrefixStore store,
62 const worker_id_t selfId,
63 const std::string& selfName) {
64 std::vector<uint8_t> selfNameVector(
65 (uint8_t*)selfName.c_str(),
66 (uint8_t*)selfName.c_str() + selfName.length());
67
68 // Check that ID does not already exist and set {ID : NAME}
69 std::vector<uint8_t> resultVector = store.compareSet(
70 std::to_string(selfId), std::vector<uint8_t>(), selfNameVector);
71 TORCH_CHECK(
72 resultVector == selfNameVector,
73 "RPC worker id ",
74 selfId,
75 " is not unique. Worker ",
76 resultVector,
77 " and already has ID and ",
78 selfNameVector,
79 " cannot be added.");
80
81 store.set(std::to_string(selfId), selfNameVector);
82
83 std::unordered_map<std::string, worker_id_t> nameToId;
84 nameToId.emplace(selfName, selfId);
85
86 // Check to see if there is list of worker names in the store
87 bool worker_names_available =
88 store.check(std::vector<std::string>{allWorkerInfosKey});
89 std::string allWorkerInfos;
90 if (worker_names_available) {
91 // Get the current list of workers
92 std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
93 allWorkerInfos = std::string(
94 (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
95 // workerInfos are comma separated with a comma at the end (e.g.
96 // "Name1-Rank1,Name2-Rank2,Name3-Rank2,") parse list of workers.
97 if (!allWorkerInfos.empty()) {
98 for (const std::string& workerInfoString : splitString(
99 allWorkerInfos.substr(0, allWorkerInfos.size() - 1), ",")) {
100 auto workerInfoVec = splitString(workerInfoString, "-");
101 std::string workerName = workerInfoVec.at(0);
102 int workerId = std::stoi(workerInfoVec.at(1));
103
104 TORCH_CHECK(
105 nameToId.find(workerName) == nameToId.end(),
106 "RPC worker name ",
107 workerName,
108 " is not unique. Workers ",
109 nameToId.find(workerName)->second,
110 " and ",
111 workerId,
112 " share the same name.");
113
114 nameToId.emplace(workerName, workerId);
115 }
116 }
117 }
118 // Add own name to worker list
119 allWorkerInfos = fmt::format("{}{}-{},", allWorkerInfos, selfName, selfId);
120 std::vector<uint8_t> allWorkerInfosVector(
121 (uint8_t*)allWorkerInfos.c_str(),
122 (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length());
123 store.set(allWorkerInfosKey, allWorkerInfosVector);
124
125 return nameToId;
126 }
127
removeCurrentName(::c10d::PrefixStore store,const worker_id_t selfId,const std::string & selfName)128 void removeCurrentName(
129 ::c10d::PrefixStore store,
130 const worker_id_t selfId,
131 const std::string& selfName) {
132 // Get current list of names/ranks
133 std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
134 std::string allWorkerInfos = std::string(
135 (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
136
137 // Remove the current name and rank
138 std::string str_to_erase = fmt::format("{}-{},", selfName, selfId);
139 int start_position_to_erase = allWorkerInfos.find(str_to_erase);
140 allWorkerInfos.erase(start_position_to_erase, str_to_erase.length());
141
142 // Set the new data
143 std::vector<uint8_t> newAllWorkerInfosVector(
144 (uint8_t*)allWorkerInfos.c_str(),
145 (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length());
146 store.set(allWorkerInfosKey, newAllWorkerInfosVector);
147 }
148
// Building blocks of the per-barrier store keys assembled by getNextKeyIds():
// every key has the form "<prefix><storeKeyBarrierId><barrier number>".
const std::string storeKeyBarrierId = "_ID_";
// Prefix of the key counting how many processes have reached the barrier.
const std::string storeKeyProcessCount = "PROCESS_COUNT";
// Prefix of the key accumulating the active call counts of all processes.
const std::string storeKeyActiveCallCount = "ACTIVE_CALLS";
// Prefix of the key that is set once every process has reached the barrier.
const std::string storeKeyReady = "READY";
// Process-local barrier counter; getNextKeyIds() increments it for every new
// set of keys, so the first barrier uses number 1.
static std::atomic<int> barrierId(0);
154
getNextKeyIds()155 static std::tuple<std::string, std::string, std::string> getNextKeyIds() {
156 barrierId++;
157 auto newBarrierId = barrierId.load();
158 std::string processCountKey = fmt::format(
159 "{}{}{}", storeKeyProcessCount, storeKeyBarrierId, newBarrierId);
160 std::string activeCallCountKey = fmt::format(
161 "{}{}{}", storeKeyActiveCallCount, storeKeyBarrierId, newBarrierId);
162 std::string barrierKey =
163 fmt::format("{}{}{}", storeKeyReady, storeKeyBarrierId, newBarrierId);
164 return std::make_tuple(
165 std::move(processCountKey),
166 std::move(activeCallCountKey),
167 std::move(barrierKey));
168 }
169
170 // Synchronize process with all other agent processes strictly using store
171 // Block until all ``RpcAgent``s reach this method.
172 // Returns total number of active calls of all RPC agents in the group
syncCallCount(::c10d::PrefixStore store,const int worldSize,int activeCalls)173 int syncCallCount(
174 ::c10d::PrefixStore store,
175 const int worldSize,
176 int activeCalls) {
177 auto [processCountKey, activeCallCountKey, readyKey] = getNextKeyIds();
178
179 // Add to keys which will record the number of processes and active calls
180 store.add(activeCallCountKey, activeCalls);
181 int totalProcessCount = store.add(processCountKey, 1);
182
183 // The last worker will need to set the ready key
184 if (totalProcessCount == worldSize) {
185 store.set(readyKey, std::vector<uint8_t>());
186 }
187
188 // Wait on the ready key to be set
189 store.wait(std::vector<std::string>{readyKey});
190
191 // Read count of active calls which may have changed
192 auto activeCallCountData = store.get(activeCallCountKey);
193 int totalCallCount = std::stoi(
194 std::string(activeCallCountData.begin(), activeCallCountData.end()));
195 return totalCallCount;
196 }
197
198 } // namespace torch::distributed::rpc
199