xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/NCCLUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

#include <c10/util/CallOnce.h>
#include <c10/util/env.h>
#include <algorithm>
#include <cstdlib>
#include <stdexcept>

9 #ifdef USE_C10D_NCCL
10 #include <vector>
11 
12 #include <cuda_runtime.h>
13 #include <mutex>
14 
15 #include <nlohmann/json.hpp>
16 
namespace {
// Polling interval, in milliseconds, used while busy-waiting for a
// nonblocking NCCL communicator to finish initialization.
constexpr int64_t kCommInitBusyWaitMillis = 10;
} // namespace
20 
21 namespace c10d {
22 
getNcclComm()23 ncclComm_t NCCLComm::getNcclComm() {
24   std::unique_lock<std::mutex> lock(mutex_);
25   if (aborted_) {
26     auto commFailureMsg = commFailureReason_ != std::nullopt
27         ? c10::str(" Original reason for failure was: ", *commFailureReason_)
28         : "";
29     TORCH_CHECK_WITH(
30         DistBackendError,
31         false,
32         c10::str(
33             "NCCL communicator was aborted on rank ",
34             rank_,
35             ". ",
36             commFailureMsg));
37   }
38   // only wait for initialization if nonblocking mode is enabled
39   if (!initialized_ && nccl_use_nonblocking()) {
40     waitUntilInitialized(nccl_nonblocking_timeout());
41   }
42 
43   return ncclComm_;
44 }
45 
waitUntilInitialized(int timeoutSecs)46 void NCCLComm::waitUntilInitialized(int timeoutSecs) {
47   auto startTimepoint = std::chrono::steady_clock::now();
48   while (!initialized_) {
49     if (ncclComm_) {
50       ncclResult_t result;
51       ncclCommGetAsyncError(ncclComm_, &result);
52       if (result == ncclSuccess) {
53         LOG(INFO) << "Rank " << rank_ << ": NCCL communicator is initialized.";
54         initialized_ = true;
55         break;
56       }
57     }
58     auto currentTimepoint = std::chrono::steady_clock::now();
59     auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
60                            currentTimepoint - startTimepoint)
61                            .count();
62     if (timeElapsed > timeoutSecs) {
63       std::string err = "NCCL timeout in communicator initialization.";
64       TORCH_CHECK_WITH(DistBackendError, false, err);
65     }
66     std::this_thread::sleep_for(
67         std::chrono::milliseconds(kCommInitBusyWaitMillis));
68   }
69 }
70 
#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2)
// Creates a child communicator from `source` via ncclCommSplit.
// The trailing ranks argument exists only so multiple implementations of
// split() can share one signature; it is unused by this one.
std::shared_ptr<NCCLComm> NCCLComm::split(
    NCCLComm* source,
    int color_id,
    int rank,
    ncclConfig_t& config,
    std::vector<uint64_t>& ranks_ull) {
  auto newComm = std::make_shared<NCCLComm>();
  C10D_NCCL_CHECK(
      ncclCommSplit(
          source->ncclComm_, color_id, rank, &(newComm->ncclComm_), &config),
      std::nullopt);
  ++source->ncclCommSplitCounter_;
  newComm->rank_ = rank;
  // In blocking mode the split is complete on return; in nonblocking mode
  // initialized_ stays false until a later successful async-error query.
  if (!nccl_use_nonblocking()) {
    newComm->initialized_ = true;
  }
  return newComm;
}
#endif
93 
getNcclVersion()94 std::string getNcclVersion() {
95   static c10::once_flag ncclGetVersionFlag;
96   static std::string versionString;
97 
98   c10::call_once(ncclGetVersionFlag, []() {
99     int version;
100     ncclResult_t status = ncclGetVersion(&version);
101     // can't compute the version if call did not return successfully or version
102     // code < 100 (corresponding to 0.1.0)
103     if (status != ncclSuccess || version < 100) {
104       versionString = "Unknown NCCL version";
105     } else {
106       // NCCL changed version coding starting 2.9
107       const int majorBase = version < 2900 ? 1000 : 10000;
108       const int minorBase = 100;
109       auto ncclMajor = version / majorBase;
110       auto ncclMinor = (version % majorBase) / minorBase;
111       auto ncclPatch =
112           version % (ncclMajor * majorBase + ncclMinor * minorBase);
113       versionString = std::to_string(ncclMajor) + "." +
114           std::to_string(ncclMinor) + "." + std::to_string(ncclPatch);
115 #ifdef NCCL_SUFFIX
116       const auto ncclSuffix = std::string(NCCL_SUFFIX);
117       if (ncclSuffix.length()) {
118         versionString += "." + ncclSuffix;
119       }
120 #endif
121     }
122   });
123 
124   return versionString;
125 }
126 
#ifdef USE_C10D_NCCL
// Computes a hash over the raw storage bytes of the given tensors.
//
// Each storage is copied from device to host with cudaMemcpy; the copy also
// acts as a device synchronization, so a collective that launched on GPU is
// finished before its output is hashed.
size_t hashTensors(const std::vector<at::Tensor>& tensors) {
  size_t hash = 0;
  for (const auto& tensor : tensors) {
    if (tensor.numel() > 0 && tensor.storage()) {
      const size_t data_size = tensor.storage().nbytes();
      if (data_size > 0 && tensor.storage().data_ptr()) {
        auto src = static_cast<const char*>(tensor.storage().data_ptr().get());
        // RAII host buffer. The original used unchecked calloc/free, which
        // dereferenced nullptr on allocation failure and leaked if
        // hash_combine threw.
        std::vector<char> dst(data_size);
        cudaMemcpy(dst.data(), src, data_size, cudaMemcpyDeviceToHost);
        for (size_t i = 0; i < data_size; ++i) {
          // Fold each byte of the tensor into the running hash.
          hash = c10::hash_combine(hash, c10::get_hash(dst[i], data_size));
        }
      }
    }
  }
  return hash;
}
#endif
151 
nccl_use_nonblocking()152 bool nccl_use_nonblocking() {
153   static bool nccl_use_nonblocking_ =
154       c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
155   if (nccl_use_nonblocking_) {
156     TORCH_WARN_ONCE("Using experimental non-blocking NCCL communicator.");
157   }
158   return nccl_use_nonblocking_;
159 }
160 
_parse_nccl_nonblocking_timeout()161 int _parse_nccl_nonblocking_timeout() {
162   const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
163   int timeout = -1;
164   if (val) {
165     const std::string config(val);
166     timeout = std::stoi(config);
167     if (!nccl_use_nonblocking() && timeout > 0) {
168       TORCH_WARN(
169           "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
170       timeout = -1;
171     }
172   }
173   return timeout;
174 }
175 
nccl_nonblocking_timeout()176 int nccl_nonblocking_timeout() {
177   static int timeout = _parse_nccl_nonblocking_timeout();
178   return timeout;
179 }
180 
ncclGetErrorWithVersion(ncclResult_t error)181 std::string ncclGetErrorWithVersion(ncclResult_t error) {
182   return std::string(ncclGetErrorString(error)) + ", NCCL version " +
183       getNcclVersion();
184 }
185 
186 // Provides additional detail into NCCL error codes based on when these are
187 // thrown in the NCCL codebase.
getNcclErrorDetailStr(ncclResult_t error,std::optional<std::string> processGroupFailureReason)188 std::string getNcclErrorDetailStr(
189     ncclResult_t error,
190     std::optional<std::string> processGroupFailureReason /* = std::nullopt */
191 ) {
192   // Prioritize failure reason provided by PG NCCL first, as it can abort
193   // communicators when it encounters collective timeouts, etc.
194   if (processGroupFailureReason != std::nullopt) {
195     return *processGroupFailureReason;
196   }
197   std::string interpret;
198   std::string err;
199 #ifdef ENABLE_NCCL_GET_LAST_ERROR
200   auto ret = ncclGetLastError(NULL);
201   if (ret) {
202     err = "\nLast error:\n" + std::string(ret);
203   } else {
204     err = "\nLast error: Unknown NCCL Error\n";
205   }
206 #endif
207   switch (error) {
208     case ncclUnhandledCudaError:
209       interpret = "ncclUnhandledCudaError: Call to CUDA function failed.";
210       break;
211     case ncclSystemError:
212       interpret =
213           "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. ";
214 #ifndef NCCL_REMOTE_ERROR
215       // Before ncclRemoteError was created, unexpected remote disconnect was
216       // categorized as ncclSystemError
217       interpret += "It can be also caused by unexpected exit of a remote peer.";
218 #endif
219       break;
220     case ncclInternalError:
221       interpret = "ncclInternalError: Internal check failed.";
222       break;
223     case ncclInvalidArgument:
224       interpret = "ncclInvalidArgument: Invalid value for an argument.";
225       break;
226     case ncclInvalidUsage:
227       interpret =
228           "ncclInvalidUsage: This usually reflects invalid usage of NCCL library.";
229       break;
230 #ifdef NCCL_REMOTE_ERROR
231     case ncclRemoteError:
232       interpret =
233           "ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely.";
234       break;
235 #endif
236     default:
237       interpret = "Unknown NCCL error!";
238   }
239   return interpret + err;
240 }
241 
242 control_plane::RegisterHandler dumpHandler{
243     "dump_nccl_trace_pickle",
__anon77ccabb30302() 244     [](const control_plane::Request& req, control_plane::Response& res) {
245       const auto params = req.params();
246       size_t validParamCount = 0;
247 
248       // valid params
249       const std::string includeCollectivesStr = "includecollectives";
250       const std::string includeStackTracesStr = "includestacktraces";
251       const std::string onlyActiveStr = "onlyactive";
252 
253       std::unordered_map<std::string, bool> processedParams = {
254           {includeCollectivesStr, true},
255           {includeStackTracesStr, true},
256           {onlyActiveStr, false}};
257 
258       for (const auto& [paramName, paramValue] : params) {
259         auto it = processedParams.find(paramName);
260         if (it != processedParams.end()) {
261           validParamCount++;
262           if (paramValue == "true") {
263             it->second = true;
264           } else if (paramValue == "false") {
265             it->second = false;
266           } else {
267             res.setStatus(400);
268             res.setContent(
269                 "Invalid value for " + paramName +
270                     " valid values are true or false",
271                 "text/plain");
272             return;
273           }
274         }
275       }
276       if (validParamCount < params.size()) {
277         res.setStatus(400);
278         res.setContent(
279             "Invalid parameters - unexpected param passed in", "text/plain");
280         return;
281       }
282       res.setContent(
283           dump_nccl_trace(
284               processedParams[includeCollectivesStr],
285               processedParams[includeStackTracesStr],
286               processedParams[onlyActiveStr]),
287           "application/octet-stream");
288     }};
289 
290 control_plane::RegisterHandler jsonDumpHandler{
291     "dump_nccl_trace_json",
__anon77ccabb30402() 292     [](const control_plane::Request& req, control_plane::Response& res) {
293       const auto params = req.params();
294       size_t validParamCount = 0;
295 
296       // valid params
297       const std::string includeCollectivesStr = "includecollectives";
298       const std::string onlyActiveStr = "onlyactive";
299 
300       std::unordered_map<std::string, bool> processedParams = {
301           {includeCollectivesStr, true}, {onlyActiveStr, false}};
302 
303       for (const auto& [paramName, paramValue] : params) {
304         auto it = processedParams.find(paramName);
305         if (it != processedParams.end()) {
306           validParamCount++;
307           if (paramValue == "true") {
308             it->second = true;
309           } else if (paramValue == "false") {
310             it->second = false;
311           } else {
312             res.setStatus(400);
313             res.setContent(
314                 "Invalid value for " + paramName +
315                     " valid values are true or false",
316                 "text/plain");
317             return;
318           }
319         }
320       }
321       if (validParamCount < params.size()) {
322         res.setStatus(400);
323         res.setContent(
324             "Invalid parameters - unexpected param passed in", "text/plain");
325         return;
326       }
327       res.setStatus(200);
328       res.setContent(
329           dump_nccl_trace_json(
330               processedParams[includeCollectivesStr],
331               processedParams[onlyActiveStr]),
332           "application/json");
333     }};
334 
write(const std::string & ncclTrace)335 void DebugInfoWriter::write(const std::string& ncclTrace) {
336   // Open a file for writing. The ios::binary flag is used to write data as
337   // binary.
338   std::ofstream file(filename_, std::ios::binary);
339 
340   // Check if the file was opened successfully.
341   if (!file.is_open()) {
342     LOG(ERROR) << "Error opening file for writing NCCLPG debug info: "
343                << filename_;
344     return;
345   }
346 
347   file.write(ncclTrace.data(), ncclTrace.size());
348   LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
349 }
350 
getWriter(int rank)351 DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
352   if (writer_ == nullptr) {
353     std::string fileNamePrefix = getCvarString(
354         {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
355     // Using std::unique_ptr here to auto-delete the writer object
356     // when the pointer itself is destroyed.
357     std::unique_ptr<DebugInfoWriter> writerPtr(
358         new DebugInfoWriter(fileNamePrefix, rank));
359     DebugInfoWriter::registerWriter(std::move(writerPtr));
360   }
361   return *writer_;
362 }
363 
registerWriter(std::unique_ptr<DebugInfoWriter> writer)364 void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
365   TORCH_CHECK_WITH(
366       DistBackendError,
367       hasWriterRegistered_.load() == false,
368       "debugInfoWriter already registered");
369   hasWriterRegistered_.store(true);
370   writer_ = std::move(writer);
371 }
372 
record(size_t pg_id,const std::tuple<std::string,std::string> & pg_name,size_t collective_seq_id,size_t p2p_seq_id,size_t op_id,std::string profiling_name,const std::vector<at::Tensor> & inputs,const std::vector<at::Tensor> & outputs,Event * start,Event * end,std::chrono::milliseconds timeout_ms,std::shared_ptr<ProcessGroupStatus> pg_status,bool isP2P)373 std::optional<size_t> NCCLTraceBuffer::record(
374     size_t pg_id,
375     const std::tuple<std::string, std::string>& pg_name,
376     size_t collective_seq_id,
377     size_t p2p_seq_id,
378     size_t op_id,
379     std::string profiling_name,
380     const std::vector<at::Tensor>& inputs,
381     const std::vector<at::Tensor>& outputs,
382     Event* start,
383     Event* end,
384     std::chrono::milliseconds timeout_ms,
385     std::shared_ptr<ProcessGroupStatus> pg_status,
386     bool isP2P) {
387   if (!enabled_) {
388     return std::nullopt;
389   }
390   if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
391     // Current pg_status is not in FR.
392     all_pg_status_[pg_id] = pg_status;
393   }
394   auto traceback =
395       torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
396   std::lock_guard<std::mutex> guard(mutex_);
397 
398   auto te = Entry{
399       id_,
400       pg_id,
401       pg_name,
402       collective_seq_id,
403       p2p_seq_id,
404       op_id,
405       std::move(profiling_name),
406       std::move(traceback),
407       std::move(start),
408       std::move(end),
409       c10::getTime(),
410       timeout_ms.count(),
411       isP2P,
412       std::nullopt,
413       std::nullopt,
414       std::nullopt,
415       {},
416       {},
417       {},
418       {},
419       {},
420       false};
421 
422   for (const auto& input : inputs) {
423     c10::IntArrayRef sizes = input.sizes();
424     te.input_dtypes_.push_back(input.dtype().toScalarType());
425     te.input_dims_.push_back(sizes.size());
426     te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
427   }
428 
429   for (const auto& output : outputs) {
430     c10::IntArrayRef sizes = output.sizes();
431     te.output_dtypes_.push_back(output.dtype().toScalarType());
432     te.output_dims_.push_back(sizes.size());
433     te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
434   }
435 
436   if (entries_.size() < max_entries_) {
437     entries_.emplace_back(std::move(te));
438   } else {
439     entries_[next_++] = std::move(te);
440     if (next_ == max_entries_) {
441       next_ = 0;
442     }
443   }
444   return id_++;
445 }
446 
record_pg_ranks(const std::tuple<std::string,std::string> & pg_name,std::vector<uint64_t> ranks)447 void NCCLTraceBuffer::record_pg_ranks(
448     const std::tuple<std::string, std::string>& pg_name,
449     std::vector<uint64_t> ranks) {
450   if (!enabled_) {
451     return;
452   }
453   std::lock_guard<std::mutex> guard(mutex_);
454   pg_name_to_ranks_[pg_name] = ranks;
455 }
456 
update_state(Entry & r)457 void NCCLTraceBuffer::update_state(Entry& r) {
458   if (r.start_ != nullptr) {
459     bool started = r.start_->query();
460     if (started && !r.time_discovered_started_) {
461       r.time_discovered_started_ = c10::getTime();
462     }
463   }
464   if (r.end_ != nullptr) {
465     bool completed = r.end_->query();
466     if (completed && !r.time_discovered_completed_) {
467       r.time_discovered_completed_ = c10::getTime();
468     }
469   }
470 }
471 
dump_entries()472 std::vector<NCCLTraceBuffer::Entry> NCCLTraceBuffer::dump_entries() {
473   std::lock_guard<std::mutex> guard(mutex_);
474   std::vector<Entry> result;
475   result.reserve(entries_.size());
476   result.insert(result.end(), entries_.begin() + next_, entries_.end());
477   result.insert(result.end(), entries_.begin(), entries_.begin() + next_);
478   // query any remaining events
479   for (auto& r : result) {
480     update_state(r);
481     r.start_ = r.end_ = nullptr;
482   }
483   return result;
484 }
485 
retire_id(std::optional<size_t> id,bool compute_duration)486 void NCCLTraceBuffer::retire_id(
487     std::optional<size_t> id,
488     bool compute_duration) {
489   if (!enabled_ || !id) {
490     return;
491   }
492 
493   bool can_compute_duration = false;
494   Event* startEvent = nullptr;
495   Event* endEvent = nullptr;
496   std::optional<float> duration = std::nullopt;
497 
498   std::unique_lock<std::mutex> guard(mutex_);
499 
500   Entry* entry = &entries_.at(*id % max_entries_);
501   if (entry->id_ == *id) {
502     update_state(*entry);
503 
504     if (compute_duration) {
505       can_compute_duration = entry->time_discovered_completed_.has_value() &&
506           entry->start_ && entry->end_;
507       startEvent = entry->start_;
508       endEvent = entry->end_;
509     }
510     entry->retired_ = true;
511     entry->start_ = entry->end_ = nullptr;
512   }
513 
514   if (can_compute_duration) {
515     // Compute duration without without holding the lock, because
516     // cudaEventDuration() can hang, and we need to acquire the lock before we
517     // can dump(), which we never want to block.
518     guard.unlock();
519     duration = getDurationFromEvent(*startEvent, *endEvent);
520     guard.lock();
521 
522     // Refresh the entry pointer, see if the entry has been overwritten
523     entry = &entries_.at(*id % max_entries_);
524     if (entry->id_ != *id) {
525       LOG(INFO) << "retire_id abandoned for id " << *id
526                 << ", event was overwritten while waiting to compute duration.";
527       return;
528     }
529     if (duration.has_value()) {
530       entry->duration_ = duration.value();
531     }
532   }
533 }
534 
getCollectiveTrace(bool includeStacktraces,bool onlyActive)535 const c10::List<c10::IValue> NCCLTraceBuffer::getCollectiveTrace(
536     bool includeStacktraces,
537     bool onlyActive) {
538   auto entries = new_list();
539   // Entries are returned in the order they were recorded
540   auto result = dump_entries();
541   std::vector<torch::CapturedTraceback*> tracebacks;
542   torch::SymbolizedTracebacks stracebacks;
543   std::vector<c10::IValue> all_frames;
544   if (includeStacktraces) {
545     for (auto& e : result) {
546       tracebacks.push_back(e.traceback_.get());
547     }
548     stracebacks = torch::symbolize(tracebacks);
549     for (const auto& f : stracebacks.all_frames) {
550       auto d = new_dict();
551       d.insert(name_key, f.funcname);
552       d.insert(filename_key, f.filename);
553       d.insert(line_key, int64_t(f.lineno));
554       all_frames.emplace_back(std::move(d));
555     }
556   }
557   for (auto i : c10::irange(result.size())) {
558     auto dict = new_dict();
559     auto& e = result.at(i);
560     // Skip completed events
561     if (onlyActive && e.time_discovered_completed_.has_value()) {
562       continue;
563     }
564     if (includeStacktraces) {
565       auto& tb = stracebacks.tracebacks.at(i);
566       auto frames = new_list();
567       for (int64_t frame : tb) {
568         frames.push_back(all_frames.at(frame));
569       }
570       dict.insert(frames_key, frames);
571     }
572 
573     dict.insert(record_id_key, int64_t(e.id_));
574     dict.insert(pg_id_key, int64_t(e.pg_id_));
575     dict.insert(pg_name_key, e.pg_name_);
576     dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_));
577     dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_));
578     dict.insert(op_id_key, int64_t(e.op_id_));
579     dict.insert(profiling_name_key, e.profiling_name_);
580     dict.insert(time_created_key, int64_t(e.time_created_));
581     if (e.duration_) {
582       dict.insert(duration_key, *e.duration_);
583     }
584 
585     auto it = e.sizes_.begin();
586     auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
587       auto sizes = new_list();
588       for (auto dim : dims) {
589         auto arg_sizes = new_list();
590         for (C10_UNUSED auto i : c10::irange(dim)) {
591           arg_sizes.push_back(*it++);
592         }
593         sizes.push_back(arg_sizes);
594       }
595       return sizes;
596     };
597 
598     dict.insert(input_sizes_key, read_sizes(e.input_dims_));
599     std::vector<std::string> input_dtypes_strs;
600     input_dtypes_strs.reserve(e.input_dtypes_.size());
601     for (const auto& input_dtype : e.input_dtypes_) {
602       input_dtypes_strs.push_back(c10::toString(input_dtype));
603     }
604     dict.insert(input_dtypes_key, input_dtypes_strs);
605     dict.insert(output_sizes_key, read_sizes(e.output_dims_));
606     std::vector<std::string> output_dtypes_strs;
607     output_dtypes_strs.reserve(e.output_dtypes_.size());
608     for (const auto& output_dtype : e.output_dtypes_) {
609       output_dtypes_strs.push_back(c10::toString(output_dtype));
610     }
611     dict.insert(output_dtypes_key, output_dtypes_strs);
612     if (e.time_discovered_completed_.has_value()) {
613       dict.insert(state_key, completed_state);
614     } else if (e.time_discovered_started_.has_value()) {
615       dict.insert(state_key, started_state);
616     } else {
617       dict.insert(state_key, scheduled_state);
618     }
619 
620     dict.insert(
621         time_discovered_started_key,
622         e.time_discovered_started_.has_value()
623             ? int64_t(*e.time_discovered_started_)
624             : c10::IValue());
625     dict.insert(
626         time_discovered_completed_key,
627         e.time_discovered_completed_.has_value()
628             ? int64_t(*e.time_discovered_completed_)
629             : c10::IValue());
630     dict.insert(retired_key, e.retired_);
631     dict.insert(timeout_key, e.timeout_ms_);
632     dict.insert(is_p2p_key, e.isP2P_);
633 
634     entries.push_back(dict);
635   }
636   return entries;
637 }
638 
getPgConfig()639 const c10::Dict<c10::IValue, c10::IValue> NCCLTraceBuffer::getPgConfig() {
640   auto pg_config = new_dict();
641   for (const auto& [pg_name, ranks] : pg_name_to_ranks_) {
642     auto pg_info = new_dict();
643     pg_info.insert("name", std::get<0>(pg_name));
644     pg_info.insert("desc", std::get<1>(pg_name));
645     pg_info.insert("ranks", ranks_str(ranks));
646     pg_config.insert(std::get<0>(pg_name), pg_info);
647   }
648   return pg_config;
649 }
650 
651 const std::map<std::string, std::map<std::string, std::string>> NCCLTraceBuffer::
getPgConfigJson()652     getPgConfigJson() {
653   std::map<std::string, std::map<std::string, std::string>> result;
654   for (const auto& [pg_name, ranks] : pg_name_to_ranks_) {
655     auto pg_info = std::map<std::string, std::string>();
656     pg_info["name"] = std::get<0>(pg_name);
657     pg_info["desc"] = std::get<1>(pg_name);
658     pg_info["ranks"] = ranks_str(ranks);
659     result.emplace(std::get<0>(pg_name), pg_info);
660   }
661   return result;
662 }
663 
getPgStatus()664 const c10::Dict<c10::IValue, c10::IValue> NCCLTraceBuffer::getPgStatus() {
665   auto all_pg_status = new_dict();
666   for (const auto& [pg_id, status] : all_pg_status_) {
667     auto pg_status = new_dict();
668     pg_status.insert("last_enqueued_collective", status->lastEnqueuedSeq);
669     pg_status.insert("last_started_collective", status->lastStartedSeq);
670     pg_status.insert("last_completed_collective", status->lastCompletedSeq);
671     all_pg_status.insert(std::to_string(pg_id), pg_status);
672   }
673   return all_pg_status;
674 }
675 
676 const std::map<std::string, std::map<std::string, std::string>> NCCLTraceBuffer::
getPgStatusJson()677     getPgStatusJson() {
678   std::map<std::string, std::map<std::string, std::string>> result;
679   for (const auto& [pg_id, status] : all_pg_status_) {
680     auto pg_status = std::map<std::string, std::string>();
681     pg_status["last_enqueued_collective"] =
682         std::to_string(status->lastEnqueuedSeq);
683     pg_status["last_started_collective"] =
684         std::to_string(status->lastStartedSeq);
685     pg_status["last_completed_collective"] =
686         std::to_string(status->lastCompletedSeq);
687     result[std::to_string(pg_id)] = pg_status;
688   }
689   return result;
690 }
691 
dump_json(const std::optional<std::unordered_map<std::string,std::unordered_map<std::string,std::string>>> & ncclDumpMap,bool includeCollectives,bool onlyActive)692 std::string NCCLTraceBuffer::dump_json(
693     const std::optional<std::unordered_map<
694         std::string,
695         std::unordered_map<std::string, std::string>>>& ncclDumpMap,
696     bool includeCollectives,
697     bool onlyActive) {
698   using json = nlohmann::json;
699   json result;
700   result[version_key_str] = version_val_str;
701   result[pg_config_key_str] = getPgConfigJson();
702   result[pg_status_key_str] = getPgStatusJson();
703 
704   // collective trace
705   if (includeCollectives) {
706     std::list<json> entries;
707     for (auto& e : dump_entries()) {
708       json j;
709       if (onlyActive && e.time_discovered_completed_.has_value()) {
710         continue;
711       }
712       j[record_id_key_str] = int64_t(e.id_);
713       j[pg_id_key_str] = int64_t(e.pg_id_);
714       j[pg_name_key_str] = e.pg_name_;
715       j[collective_seq_id_key_str] = int64_t(e.collective_seq_id_);
716       j[p2p_seq_id_key_str] = int64_t(e.p2p_seq_id_);
717       j[op_id_key_str] = int64_t(e.op_id_);
718       j[profiling_name_key_str] = e.profiling_name_;
719       j[time_created_key_str] = int64_t(e.time_created_);
720       if (e.duration_) {
721         j[duration_key_str] = *e.duration_;
722       }
723       auto it = e.sizes_.begin();
724       auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
725         auto sizes = std::list<std::list<int>>();
726         for (auto dim : dims) {
727           auto arg_sizes = std::list<int>();
728           for (auto i : c10::irange(dim)) {
729             (void)i;
730             arg_sizes.push_back(*it++);
731           }
732           sizes.push_back(arg_sizes);
733         }
734         return sizes;
735       };
736       j[input_sizes_key_str] = read_sizes(e.input_dims_);
737       std::vector<std::string> input_dtypes_strs;
738       input_dtypes_strs.reserve(e.input_dtypes_.size());
739       for (const auto& input_dtype : e.input_dtypes_) {
740         input_dtypes_strs.push_back(c10::toString(input_dtype));
741       }
742       j[input_dtypes_key_str] = input_dtypes_strs;
743       j[output_sizes_key_str] = read_sizes(e.output_dims_);
744       std::vector<std::string> output_dtypes_strs;
745       output_dtypes_strs.reserve(e.output_dtypes_.size());
746       for (const auto& output_dtype : e.output_dtypes_) {
747         output_dtypes_strs.push_back(c10::toString(output_dtype));
748       }
749       j[output_dtypes_key_str] = output_dtypes_strs;
750       if (e.time_discovered_completed_.has_value()) {
751         j[state_key_str] = completed_state_str;
752       } else if (e.time_discovered_started_.has_value()) {
753         j[state_key_str] = started_state_str;
754       } else {
755         j[state_key_str] = scheduled_state_str;
756       }
757       j[time_discovered_started_key_str] =
758           e.time_discovered_started_.has_value()
759           ? int64_t(*e.time_discovered_started_)
760           : 0;
761       j[time_discovered_completed_key_str] =
762           e.time_discovered_completed_.has_value()
763           ? int64_t(*e.time_discovered_completed_)
764           : 0;
765       j[retired_key_str] = e.retired_;
766       j[timeout_key_str] = e.timeout_ms_;
767       j[is_p2p_key_str] = e.isP2P_;
768       entries.emplace_back(j);
769     }
770 
771     if (entries.size() > 0) {
772       result[entries_key_str] = entries;
773     }
774   }
775 
776   if (ncclDumpMap.has_value()) {
777     result[nccl_comm_key_str] = ncclDumpMap.value();
778   }
779 
780   return result.dump();
781 }
782 
dump(const std::optional<std::unordered_map<std::string,std::unordered_map<std::string,std::string>>> & ncclDumpMap,bool includeCollectives,bool includeStackTraces,bool onlyActive)783 std::string NCCLTraceBuffer::dump(
784     const std::optional<std::unordered_map<
785         std::string,
786         std::unordered_map<std::string, std::string>>>& ncclDumpMap,
787     bool includeCollectives,
788     bool includeStackTraces,
789     bool onlyActive) {
790   auto result = new_dict();
791   // common values
792   result.insert(version_key, version_val);
793   result.insert(pg_config_key, getPgConfig());
794   result.insert(pg_status_key, getPgStatus());
795 
796   // collective trace
797   if (includeCollectives) {
798     result.insert(
799         entries_key, getCollectiveTrace(includeStackTraces, onlyActive));
800   }
801   // convert ncclDumpMap into a dictionary
802   auto per_comm_dict = new_dict();
803   if (ncclDumpMap.has_value()) {
804     for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) {
805       auto inner_dict = new_dict();
806       for (const auto& [key, value] : ncclDump) {
807         inner_dict.insert(key, value);
808       }
809       per_comm_dict.insert(ncclId, inner_dict);
810     }
811   }
812   if (per_comm_dict.size() > 0) {
813     result.insert(nccl_comm_key, per_comm_dict);
814   }
815   return pickle_str(result);
816 }
817 
818 std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
819 std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);
820 
getDurationFromEvent(at::cuda::CUDAEvent & ncclStartEvent,at::cuda::CUDAEvent & ncclEndEvent)821 float getDurationFromEvent(
822     at::cuda::CUDAEvent& ncclStartEvent,
823     at::cuda::CUDAEvent& ncclEndEvent) {
824   TORCH_CHECK(
825       ncclEndEvent.query(),
826       "getDuration can only be called after work is succeeded.")
827   return ncclStartEvent.elapsed_time(ncclEndEvent);
828 }
829 
830 } // namespace c10d
831 
832 #endif // USE_C10D_NCCL
833