xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/TraceUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/core/ScalarType.h>
3 #include <c10/util/ApproximateClock.h>
4 #include <c10/util/irange.h>
5 #include <c10/util/string_view.h>
6 #include <torch/csrc/distributed/c10d/Store.hpp>
7 #include <torch/csrc/distributed/c10d/Types.hpp>
8 #include <torch/csrc/distributed/c10d/Utils.hpp>
9 #include <torch/csrc/jit/serialization/pickler.h>
10 #include <torch/csrc/profiler/combined_traceback.h>
11 #include <chrono>
12 
13 #include <sys/types.h>
14 #include <cstdlib>
15 #include <fstream>
16 #include <string>
17 #include <system_error>
18 #include <vector>
19 
20 namespace c10d {
21 
// Holds the latest known progress of a process group.
//
// Every field has a default member initializer (or is a std::string), so a
// default-initialized instance never exposes indeterminate values.
struct ProcessGroupStatus {
  // Sequence number of the last collective enqueued into workMetaList_.
  // Useful for identifying a rank that has not joined a collective.
  // Initialized to -1 to indicate that no collective has been enqueued.
  int64_t lastEnqueuedSeq{-1};
  // Sequence number of the last collective started as a kernel.
  // Starts at -1.
  int64_t lastStartedSeq{-1};
  // Sequence number of the last collective marked as completed by the
  // watchdog thread.
  // Initialized to -1 to indicate that no collective has been completed.
  int64_t lastCompletedSeq{-1};

  // Name of the last collective enqueued into workMetaList_.
  std::string lastEnqueuedWorkName;
  // Name of the last collective started as a kernel.
  std::string lastStartedWorkName;
  // Name of the last collective completed.
  std::string lastCompletedWorkName;

  // Sizes of the last work enqueued; 0 until something has been recorded.
  size_t lastEnqueuedNumelIn{0};
  size_t lastEnqueuedNumelOut{0};
  // Sizes of the last work completed; 0 until something has been recorded.
  size_t lastCompletedNumelIn{0};
  size_t lastCompletedNumelOut{0};
};
49 
// Builds the store key under which `rank` of process group `pgName` records
// the start of a collective: "<pgName>_<rank>_trace_start".
inline std::string getTraceStartKey(const std::string& pgName, int rank) {
  std::string key(pgName);
  key += "_";
  key += std::to_string(rank);
  key += "_trace_start";
  return key;
}
53 
// Builds the store key under which `rank` of process group `pgName` records
// the end of a collective: "<pgName>_<rank>_trace_end".
inline std::string getTraceEndKey(const std::string& pgName, int rank) {
  std::string key(pgName);
  key += "_";
  key += std::to_string(rank);
  key += "_trace_end";
  return key;
}
57 
traceUpdate(c10::intrusive_ptr<Store> & store,const std::string & key,uint64_t seq,const std::string & col)58 inline bool traceUpdate(
59     c10::intrusive_ptr<Store>& store,
60     const std::string& key,
61     uint64_t seq,
62     const std::string& col) {
63   std::vector<uint8_t> value(col.size() + sizeof(seq) + 1);
64   memcpy(value.data(), &seq, sizeof(seq));
65   memcpy(value.data() + sizeof(seq), col.data(), col.size());
66   try {
67     store->set(key, value);
68     return true;
69   } catch (...) {
70     LOG(ERROR) << "Store is down while updating #" << seq << " with key "
71                << key;
72     return false;
73   }
74   return true;
75 }
76 
// The kind of trace record last observed for a rank.
enum TraceDebugEvent {
  kEventStart = 0,
  kEventEnd = 1,
};
// Maps: sequence number -> (rank -> (collective name, start/end event)).
using TraceMap =
    std::map<uint64_t, std::map<int, std::pair<std::string, TraceDebugEvent>>>;
84 
// Joins `ranks` into a single string separated by ", " (no brackets).
// An empty vector yields an empty string.
inline std::string ranksToString(const std::vector<int>& ranks) {
  std::string joined;
  const char* separator = "";
  for (const int rank : ranks) {
    joined += separator;
    joined += std::to_string(rank);
    // Every element after the first is preceded by the separator.
    separator = ", ";
  }
  return joined;
}
96 
// Joins the first element (the rank) of every pair in `items` into a single
// string separated by ", ". The second element of each pair is ignored.
inline std::string ranksFromTrace(
    const std::vector<std::pair<int, std::string>>& items) {
  std::string joined;
  for (auto it = items.begin(); it != items.end(); ++it) {
    if (it != items.begin()) {
      joined.append(", ");
    }
    joined.append(std::to_string(it->first));
  }
  return joined;
}
109 
analyzeMissingRanks(const std::vector<int> & missingRanks)110 inline std::string analyzeMissingRanks(const std::vector<int>& missingRanks) {
111   return c10::str(
112       "\n\t - To our best knowledge, ranks [",
113       ranksToString(missingRanks),
114       "] are the lagging ranks that caused this timeout. "
115       "They never joined any collectives");
116 }
117 
analyzeLaggingRanks(const TraceMap & traceMap)118 inline std::string analyzeLaggingRanks(const TraceMap& traceMap) {
119   uint64_t lagSeq = traceMap.begin()->first;
120   std::vector<int> startRanks;
121   std::vector<int> endRanks;
122   for (auto& p : traceMap.begin()->second) {
123     if (p.second.second == kEventStart) {
124       startRanks.push_back(p.first);
125     } else {
126       endRanks.push_back(p.first);
127     }
128   }
129   std::string report =
130       "\n\t - To our best knowledge, the lagging/dead/mismatched ranks "
131       "that caused the desync are:";
132   if (startRanks.size()) {
133     report += c10::str(
134         "\n\t   - [",
135         ranksToString(startRanks),
136         "] joined but didn't finish collective #",
137         lagSeq,
138         " (count from 1)");
139   }
140   if (endRanks.size()) {
141     report += c10::str(
142         "\n\t     [",
143         ranksToString(endRanks),
144         "] finished collective #",
145         lagSeq,
146         ", but didn't join collective #",
147         lagSeq + 1,
148         " (count from 1)");
149   }
150   return report;
151 }
152 
dumpSnapshot(TraceMap & traceMap)153 inline std::string dumpSnapshot(TraceMap& traceMap) {
154   std::string report = "\n\t - Snapshot of ranks' latest states:";
155   for (auto& tracePair : traceMap) {
156     uint64_t seq = tracePair.first;
157     std::map<int, std::pair<std::string, TraceDebugEvent>>& subMap =
158         tracePair.second;
159 
160     std::unordered_map<std::string, std::vector<int>> collectivesStart;
161     std::unordered_map<std::string, std::vector<int>> collectivesEnd;
162     for (auto& p : subMap) {
163       int rank = p.first;
164       const std::string& col = p.second.first;
165       if (p.second.second == kEventStart) {
166         collectivesStart[col].push_back(rank);
167       } else {
168         collectivesEnd[col].push_back(rank);
169       }
170     }
171 
172     if (collectivesStart.size()) {
173       report += c10::str("\n\t   #", seq, " started ranks:");
174       for (auto& mapPair : collectivesStart) {
175         report += c10::str(
176             "\n\t     [",
177             ranksToString(mapPair.second),
178             "] started ",
179             mapPair.first);
180       }
181     }
182     if (collectivesEnd.size()) {
183       report += c10::str("\n\t   #", seq, " finished ranks:");
184       for (auto& mapPair : collectivesEnd) {
185         report += c10::str(
186             "\n\t     [",
187             ranksToString(mapPair.second),
188             "] finished ",
189             mapPair.first);
190       }
191     }
192   }
193   return report;
194 }
195 
parseTraceValue(c10::intrusive_ptr<Store> & store,const std::string & key,uint64_t & seq,std::string & col)196 inline bool parseTraceValue(
197     c10::intrusive_ptr<Store>& store,
198     const std::string& key,
199     uint64_t& seq,
200     std::string& col) {
201   try {
202     std::vector<uint8_t> traceValue = store->get(key);
203     memcpy(&seq, traceValue.data(), sizeof(seq));
204     std::string colName((char*)traceValue.data() + sizeof(seq));
205     col = colName;
206     return true;
207   } catch (...) {
208     LOG(ERROR) << "Store is down while getting key " << key;
209     return false;
210   }
211   return true;
212 }
213 
retrieveDesyncReport(c10::intrusive_ptr<Store> & store,const std::string & pgName,int myRank,int worldSize)214 inline std::string retrieveDesyncReport(
215     c10::intrusive_ptr<Store>& store,
216     const std::string& pgName,
217     int myRank,
218     int worldSize) {
219   std::string report;
220 
221   uint64_t thisSeq;
222   std::string thisCol;
223 
224   std::vector<int> missingRanks;
225   TraceMap traceMap;
226 
227   for (const auto rank : c10::irange(worldSize)) {
228     // Build traceMapStart.
229     uint64_t seqStart;
230     {
231       std::string traceKeyStart = getTraceStartKey(pgName, rank);
232       if (!store->check({traceKeyStart})) {
233         missingRanks.push_back(rank);
234         continue;
235       }
236       std::string col;
237       if (!parseTraceValue(store, traceKeyStart, seqStart, col)) {
238         return report;
239       }
240       traceMap[seqStart].emplace(rank, std::make_pair(col, kEventStart));
241       if (rank == myRank) {
242         thisSeq = seqStart;
243         thisCol = std::move(col);
244       }
245     }
246 
247     // Build traceMapEnd.
248     {
249       std::string traceKeyEnd = getTraceEndKey(pgName, rank);
250       if (!store->check({traceKeyEnd})) {
251         continue;
252       }
253       uint64_t seq;
254       std::string col;
255       if (!parseTraceValue(store, traceKeyEnd, seq, col)) {
256         return report;
257       }
258       if (seq == seqStart) {
259         traceMap[seq][rank].second = kEventEnd;
260       }
261     }
262   }
263 
264   TORCH_INTERNAL_ASSERT(
265       !missingRanks.empty() || !traceMap.empty(),
266       "Trace shouldn't be empty while enabled GLOO_ASYNC_TIMEOUT_DEBUG");
267   TORCH_INTERNAL_ASSERT(
268       !thisCol.empty(),
269       "Timeout rank [",
270       myRank,
271       "] must have collective tracking iteam in c10::Store trace");
272   TORCH_INTERNAL_ASSERT(
273       traceMap[thisSeq][myRank].second == kEventStart,
274       "Timeout rank [",
275       myRank,
276       "] last trace item must be kEventStart. thisSeq = ",
277       thisSeq,
278       ", col = ",
279       thisCol);
280 
281   report += c10::str(
282       "\n\t - [", myRank, "] Timeout at collective: ", thisCol, ", #", thisSeq);
283 
284   if (!missingRanks.empty()) {
285     report += analyzeMissingRanks(missingRanks);
286   } else {
287     report += analyzeLaggingRanks(traceMap);
288     report += dumpSnapshot(traceMap);
289   }
290 
291   return report;
292 }
293 
pickle_str(const c10::IValue & v)294 inline std::string pickle_str(const c10::IValue& v) {
295   std::vector<char> result;
296   {
297     auto writer = [&](const char* data, size_t size) {
298       result.insert(result.end(), data, data + size);
299     };
300     torch::jit::Pickler pickler(
301         writer, nullptr, nullptr, nullptr, nullptr, false);
302     pickler.protocol();
303     pickler.pushIValue(v);
304     pickler.stop();
305   }
306   return std::string(result.begin(), result.end());
307 }
308 
get_python_cpp_trace()309 inline std::string get_python_cpp_trace() {
310   // usage:
311   // LOG(INFO) << "stacktrace: "
312   //           << get_python_cpp_trace();
313   // warn: might be slow in getting cpp traces
314   // because of slow/broken addr2line
315   // in different system libs
316   std::shared_ptr<torch::CapturedTraceback> tb =
317       torch::CapturedTraceback::gather(
318           /*python=*/true, /*script=*/true, /*cpp=*/true);
319   torch::SymbolizedTracebacks s_tbs = torch::symbolize({tb.get()});
320   const auto& s_tb = s_tbs.tracebacks.at(0);
321   std::stringstream oss;
322   for (auto idx : c10::irange(s_tb.size())) {
323     auto frame_id = s_tb[idx];
324     const auto& frame = s_tbs.all_frames.at(frame_id);
325     oss << "#" << idx << " " << frame.funcname << " from " << frame.filename
326         << ":" << frame.lineno << std::endl;
327   }
328   return oss.str();
329 }
330 
new_dict()331 inline c10::Dict<c10::IValue, c10::IValue> new_dict() {
332   return c10::Dict<c10::IValue, c10::IValue>(
333       c10::AnyType::get(), c10::AnyType::get());
334 }
335 
new_list()336 inline c10::List<c10::IValue> new_list() {
337   return c10::List<c10::IValue>(c10::AnyType::get());
338 }
339 
// Formats `ranks` as a bracketed, ", "-separated list, e.g. "[0, 1, 2]".
// An empty vector yields "[]".
inline std::string ranks_str(const std::vector<uint64_t>& ranks) {
  std::string out = "[";
  bool first = true;
  for (const uint64_t rank : ranks) {
    if (!first) {
      out += ", ";
    }
    out += std::to_string(rank);
    first = false;
  }
  out += "]";
  return out;
}
351 
352 } // namespace c10d
353