xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/agent_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <fmt/format.h>
2 #include <torch/csrc/distributed/rpc/agent_utils.h>
3 
4 namespace torch::distributed::rpc {
5 
collectNames(::c10d::PrefixStore store,const worker_id_t selfId,const std::string & selfName,const int worldSize)6 std::unordered_map<std::string, worker_id_t> collectNames(
7     ::c10d::PrefixStore store,
8     const worker_id_t selfId,
9     const std::string& selfName,
10     const int worldSize) {
11   std::vector<uint8_t> selfNameVector(
12       (uint8_t*)selfName.c_str(),
13       (uint8_t*)selfName.c_str() + selfName.length());
14   store.set(std::to_string(selfId), selfNameVector);
15 
16   std::unordered_map<std::string, worker_id_t> nameToId;
17   nameToId.reserve(worldSize);
18   nameToId.emplace(selfName, selfId);
19   for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
20     if (workerId == selfId) {
21       continue;
22     }
23     std::vector<uint8_t> workerNameVector = store.get(std::to_string(workerId));
24     std::string workerName(
25         (char*)workerNameVector.data(), workerNameVector.size());
26 
27     TORCH_CHECK(
28         nameToId.find(workerName) == nameToId.end(),
29         "RPC worker name ",
30         workerName,
31         " is not unique. Workers ",
32         nameToId.find(workerName)->second,
33         " and ",
34         workerId,
35         " share the same name.");
36 
37     nameToId.emplace(workerName, workerId);
38   }
39   return nameToId;
40 }
41 
// Splits `s` on every occurrence of `delim` and returns the pieces in order.
// Adjacent delimiters and a leading/trailing delimiter produce empty tokens,
// e.g. splitString("a,b,", ",") == {"a", "b", ""}; the result always holds at
// least one element. An empty `delim` returns {s} instead of looping forever.
static std::vector<std::string> splitString(
    const std::string& s,
    const std::string& delim) {
  std::vector<std::string> tokens;
  if (delim.empty()) {
    // find("") matches at every position, so `start` would never advance.
    tokens.push_back(s);
    return tokens;
  }
  size_t start = 0;
  for (size_t end = s.find(delim); end != std::string::npos;
       end = s.find(delim, start)) {
    tokens.emplace_back(s, start, end - start);
    start = end + delim.length();
  }
  // Whatever follows the last delimiter (possibly empty) is the final token.
  tokens.emplace_back(s, start);
  return tokens;
}
57 
// Store key holding the comma-terminated "name-id," worker list that
// collectCurrentNames appends to and removeCurrentName removes from.
const std::string allWorkerInfosKey{"_ALL_WORKER_INFOS"};
59 
collectCurrentNames(::c10d::PrefixStore store,const worker_id_t selfId,const std::string & selfName)60 std::unordered_map<std::string, worker_id_t> collectCurrentNames(
61     ::c10d::PrefixStore store,
62     const worker_id_t selfId,
63     const std::string& selfName) {
64   std::vector<uint8_t> selfNameVector(
65       (uint8_t*)selfName.c_str(),
66       (uint8_t*)selfName.c_str() + selfName.length());
67 
68   // Check that ID does not already exist and set {ID : NAME}
69   std::vector<uint8_t> resultVector = store.compareSet(
70       std::to_string(selfId), std::vector<uint8_t>(), selfNameVector);
71   TORCH_CHECK(
72       resultVector == selfNameVector,
73       "RPC worker id ",
74       selfId,
75       " is not unique. Worker ",
76       resultVector,
77       " and already has ID and ",
78       selfNameVector,
79       " cannot be added.");
80 
81   store.set(std::to_string(selfId), selfNameVector);
82 
83   std::unordered_map<std::string, worker_id_t> nameToId;
84   nameToId.emplace(selfName, selfId);
85 
86   // Check to see if there is list of worker names in the store
87   bool worker_names_available =
88       store.check(std::vector<std::string>{allWorkerInfosKey});
89   std::string allWorkerInfos;
90   if (worker_names_available) {
91     // Get the current list of workers
92     std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
93     allWorkerInfos = std::string(
94         (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
95     // workerInfos are comma separated with a comma at the end (e.g.
96     // "Name1-Rank1,Name2-Rank2,Name3-Rank2,") parse list of workers.
97     if (!allWorkerInfos.empty()) {
98       for (const std::string& workerInfoString : splitString(
99                allWorkerInfos.substr(0, allWorkerInfos.size() - 1), ",")) {
100         auto workerInfoVec = splitString(workerInfoString, "-");
101         std::string workerName = workerInfoVec.at(0);
102         int workerId = std::stoi(workerInfoVec.at(1));
103 
104         TORCH_CHECK(
105             nameToId.find(workerName) == nameToId.end(),
106             "RPC worker name ",
107             workerName,
108             " is not unique. Workers ",
109             nameToId.find(workerName)->second,
110             " and ",
111             workerId,
112             " share the same name.");
113 
114         nameToId.emplace(workerName, workerId);
115       }
116     }
117   }
118   // Add own name to worker list
119   allWorkerInfos = fmt::format("{}{}-{},", allWorkerInfos, selfName, selfId);
120   std::vector<uint8_t> allWorkerInfosVector(
121       (uint8_t*)allWorkerInfos.c_str(),
122       (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length());
123   store.set(allWorkerInfosKey, allWorkerInfosVector);
124 
125   return nameToId;
126 }
127 
removeCurrentName(::c10d::PrefixStore store,const worker_id_t selfId,const std::string & selfName)128 void removeCurrentName(
129     ::c10d::PrefixStore store,
130     const worker_id_t selfId,
131     const std::string& selfName) {
132   // Get current list of names/ranks
133   std::vector<uint8_t> allWorkerInfosKeyVector = store.get(allWorkerInfosKey);
134   std::string allWorkerInfos = std::string(
135       (char*)allWorkerInfosKeyVector.data(), allWorkerInfosKeyVector.size());
136 
137   // Remove the current name and rank
138   std::string str_to_erase = fmt::format("{}-{},", selfName, selfId);
139   int start_position_to_erase = allWorkerInfos.find(str_to_erase);
140   allWorkerInfos.erase(start_position_to_erase, str_to_erase.length());
141 
142   // Set the new data
143   std::vector<uint8_t> newAllWorkerInfosVector(
144       (uint8_t*)allWorkerInfos.c_str(),
145       (uint8_t*)allWorkerInfos.c_str() + allWorkerInfos.length());
146   store.set(allWorkerInfosKey, newAllWorkerInfosVector);
147 }
148 
// Building blocks of the per-round store keys produced by getNextKeyIds():
// "<prefix>" + storeKeyBarrierId + "<round number>". Spelled std::string
// explicitly instead of relying on a using-declaration from elsewhere.
const std::string storeKeyBarrierId = "_ID_";
const std::string storeKeyProcessCount = "PROCESS_COUNT";
const std::string storeKeyActiveCallCount = "ACTIVE_CALLS";
const std::string storeKeyReady = "READY";
// Process-local round counter, advanced once per getNextKeyIds() call.
static std::atomic<int> barrierId(0);
154 
getNextKeyIds()155 static std::tuple<std::string, std::string, std::string> getNextKeyIds() {
156   barrierId++;
157   auto newBarrierId = barrierId.load();
158   std::string processCountKey = fmt::format(
159       "{}{}{}", storeKeyProcessCount, storeKeyBarrierId, newBarrierId);
160   std::string activeCallCountKey = fmt::format(
161       "{}{}{}", storeKeyActiveCallCount, storeKeyBarrierId, newBarrierId);
162   std::string barrierKey =
163       fmt::format("{}{}{}", storeKeyReady, storeKeyBarrierId, newBarrierId);
164   return std::make_tuple(
165       std::move(processCountKey),
166       std::move(activeCallCountKey),
167       std::move(barrierKey));
168 }
169 
170 // Synchronize process with all other agent processes strictly using store
171 // Block until all ``RpcAgent``s reach this method.
172 // Returns total number of active calls of all RPC agents in the group
syncCallCount(::c10d::PrefixStore store,const int worldSize,int activeCalls)173 int syncCallCount(
174     ::c10d::PrefixStore store,
175     const int worldSize,
176     int activeCalls) {
177   auto [processCountKey, activeCallCountKey, readyKey] = getNextKeyIds();
178 
179   // Add to keys which will record the number of processes and active calls
180   store.add(activeCallCountKey, activeCalls);
181   int totalProcessCount = store.add(processCountKey, 1);
182 
183   // The last worker will need to set the ready key
184   if (totalProcessCount == worldSize) {
185     store.set(readyKey, std::vector<uint8_t>());
186   }
187 
188   // Wait on the ready key to be set
189   store.wait(std::vector<std::string>{readyKey});
190 
191   // Read count of active calls which may have changed
192   auto activeCallCountData = store.get(activeCallCountKey);
193   int totalCallCount = std::stoi(
194       std::string(activeCallCountData.begin(), activeCallCountData.end()));
195   return totalCallCount;
196 }
197 
198 } // namespace torch::distributed::rpc
199