#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

#include <c10/util/CallOnce.h>
#include <c10/util/env.h>
#include <algorithm>

#ifdef USE_C10D_NCCL
#include <vector>

#include <cuda_runtime.h>
#include <mutex>

#include <nlohmann/json.hpp>

namespace {
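// Interval between polls of a nonblocking communicator's async state while
// waiting for initialization to finish (see NCCLComm::waitUntilInitialized).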
constexpr int64_t kCommInitBusyWaitMillis = 10;
} // namespace

namespace c10d {

ncclComm_t NCCLComm::getNcclComm() {
  std::unique_lock<std::mutex> lock(mutex_);
  if (aborted_) {
    auto commFailureMsg = commFailureReason_ != std::nullopt
        ? c10::str(" Original reason for failure was: ", *commFailureReason_)
        : "";
    TORCH_CHECK_WITH(
        DistBackendError,
        false,
        c10::str(
            "NCCL communicator was aborted on rank ",
            rank_,
            ". ",
            commFailureMsg));
  }
  // Only wait for initialization if nonblocking mode is enabled.
  if (!initialized_ && nccl_use_nonblocking()) {
    waitUntilInitialized(nccl_nonblocking_timeout());
  }

  return ncclComm_;
}
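
// Polls the communicator's async state until NCCL reports ncclSuccess,
// sleeping kCommInitBusyWaitMillis between probes; throws DistBackendError
// once timeoutSecs elapses without successful initialization.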
void NCCLComm::waitUntilInitialized(int timeoutSecs) {
  auto startTimepoint = std::chrono::steady_clock::now();
  while (!initialized_) {
    if (ncclComm_) {
      ncclResult_t result;
      ncclCommGetAsyncError(ncclComm_, &result);
      if (result == ncclSuccess) {
        LOG(INFO) << "Rank " << rank_ << ": NCCL communicator is initialized.";
        initialized_ = true;
        break;
      }
    }
    auto currentTimepoint = std::chrono::steady_clock::now();
    auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                           currentTimepoint - startTimepoint)
                           .count();
    if (timeElapsed > timeoutSecs) {
      std::string err = "NCCL timeout in communicator initialization.";
      TORCH_CHECK_WITH(DistBackendError, false, err);
    }
    std::this_thread::sleep_for(
        std::chrono::milliseconds(kCommInitBusyWaitMillis));
  }
}

#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2)
// The last argument (ranks_ull) is not used here; it exists so that the
// split() signature can be shared across multiple implementations.
std::shared_ptr<NCCLComm> NCCLComm::split(
    NCCLComm* source,
    int color_id,
    int rank,
    ncclConfig_t& config,
    std::vector<uint64_t>& ranks_ull) {
  auto comm = std::make_shared<NCCLComm>();
  C10D_NCCL_CHECK(
      ncclCommSplit(
          source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config),
      std::nullopt);
  ++source->ncclCommSplitCounter_;
  comm->rank_ = rank;
  if (!nccl_use_nonblocking()) {
    comm->initialized_ = true;
  }
  return comm;
}
#endif

std::string getNcclVersion() {
  static c10::once_flag ncclGetVersionFlag;
  static std::string versionString;

  c10::call_once(ncclGetVersionFlag, []() {
    int version;
    ncclResult_t status = ncclGetVersion(&version);
    // We can't decode the version if the call did not return successfully or
    // the version code is < 100 (corresponding to 0.1.0).
    if (status != ncclSuccess || version < 100) {
      versionString = "Unknown NCCL version";
    } else {
      // NCCL changed its version encoding starting with 2.9.
      const int majorBase = version < 2900 ? 1000 : 10000;
      const int minorBase = 100;
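      // Worked example (illustrative values): version code 21703 decodes as
      // major = 21703 / 10000 = 2, minor = (21703 % 10000) / 100 = 17,
      // patch = 21703 % (2 * 10000 + 17 * 100) = 21703 % 21700 = 3 -> "2.17.3".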
      auto ncclMajor = version / majorBase;
      auto ncclMinor = (version % majorBase) / minorBase;
      auto ncclPatch =
          version % (ncclMajor * majorBase + ncclMinor * minorBase);
      versionString = std::to_string(ncclMajor) + "." +
          std::to_string(ncclMinor) + "." + std::to_string(ncclPatch);
#ifdef NCCL_SUFFIX
      const auto ncclSuffix = std::string(NCCL_SUFFIX);
      if (ncclSuffix.length()) {
        versionString += "." + ncclSuffix;
      }
#endif
    }
  });

  return versionString;
}

#ifdef USE_C10D_NCCL
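// Computes an order-sensitive fingerprint of the tensors' raw storage bytes.
// Each storage is copied device-to-host first; the synchronous cudaMemcpy
// also acts as a synchronization point, so a collective that was launched on
// the GPU finishes before its output is hashed.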
size_t hashTensors(const std::vector<at::Tensor>& tensors) {
  size_t hash = 0;
  for (auto& tensor : tensors) {
    if (tensor.numel() > 0 && tensor.storage()) {
      size_t data_size = tensor.storage().nbytes();
      if (data_size > 0 && tensor.storage().data_ptr()) {
        auto src = static_cast<const char*>(tensor.storage().data_ptr().get());
        char* dst = static_cast<char*>(std::calloc(data_size, sizeof(char)));
        // Copy device memory to the host. This triggers a device
        // synchronization, so a collective launched on the GPU completes
        // before we hash its output.
        cudaMemcpy(dst, src, data_size, cudaMemcpyDeviceToHost);
        for (size_t i = 0; i < data_size; ++i) {
          // Fold each byte of the tensor into the running hash.
          hash = c10::hash_combine(hash, c10::get_hash(dst[i], data_size));
        }
        std::free(dst);
      }
    }
  }
  return hash;
}
#endif
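
// Both knobs below are read from the environment once and cached. An
// illustrative invocation (hypothetical values) enabling nonblocking
// communicators with a 30-second init timeout:
//   TORCH_NCCL_USE_COMM_NONBLOCKING=1 TORCH_NCCL_NONBLOCKING_TIMEOUT=30 <launcher> ...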
bool nccl_use_nonblocking() {
  static bool nccl_use_nonblocking_ =
      c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
  if (nccl_use_nonblocking_) {
    TORCH_WARN_ONCE("Using experimental non-blocking NCCL communicator.");
  }
  return nccl_use_nonblocking_;
}

int _parse_nccl_nonblocking_timeout() {
  const char* val = std::getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
  int timeout = -1;
  if (val) {
    const std::string config(val);
    timeout = std::stoi(config);
    if (!nccl_use_nonblocking() && timeout > 0) {
      TORCH_WARN(
          "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
      timeout = -1;
    }
  }
  return timeout;
}

int nccl_nonblocking_timeout() {
  static int timeout = _parse_nccl_nonblocking_timeout();
  return timeout;
}

std::string ncclGetErrorWithVersion(ncclResult_t error) {
  return std::string(ncclGetErrorString(error)) + ", NCCL version " +
      getNcclVersion();
}

// Provides additional detail into NCCL error codes based on when these are
// thrown in the NCCL codebase.
std::string getNcclErrorDetailStr(
    ncclResult_t error,
    std::optional<std::string> processGroupFailureReason /* = std::nullopt */
) {
  // Prioritize the failure reason provided by PG NCCL first, as it can abort
  // communicators when it encounters collective timeouts, etc.
  if (processGroupFailureReason != std::nullopt) {
    return *processGroupFailureReason;
  }
  std::string interpret;
  std::string err;
#ifdef ENABLE_NCCL_GET_LAST_ERROR
  auto ret = ncclGetLastError(nullptr);
  if (ret) {
    err = "\nLast error:\n" + std::string(ret);
  } else {
    err = "\nLast error: Unknown NCCL Error\n";
  }
#endif
  switch (error) {
    case ncclUnhandledCudaError:
      interpret = "ncclUnhandledCudaError: Call to CUDA function failed.";
      break;
    case ncclSystemError:
      interpret =
          "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. ";
#ifndef NCCL_REMOTE_ERROR
      // Before ncclRemoteError was created, an unexpected remote disconnect
      // was categorized as ncclSystemError.
      interpret += "It can also be caused by unexpected exit of a remote peer.";
#endif
      break;
    case ncclInternalError:
      interpret = "ncclInternalError: Internal check failed.";
      break;
    case ncclInvalidArgument:
      interpret = "ncclInvalidArgument: Invalid value for an argument.";
      break;
    case ncclInvalidUsage:
      interpret =
          "ncclInvalidUsage: This usually reflects invalid usage of the NCCL library.";
      break;
#ifdef NCCL_REMOTE_ERROR
    case ncclRemoteError:
      interpret =
          "ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely.";
      break;
#endif
    default:
      interpret = "Unknown NCCL error!";
  }
  return interpret + err;
}
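
// Control-plane endpoint that returns the NCCL trace as a pickled blob.
// Recognized query parameters (all boolean, "true"/"false"):
//   includecollectives (default true), includestacktraces (default true),
//   onlyactive (default false).
// Any unrecognized parameter name yields a 400 response.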
control_plane::RegisterHandler dumpHandler{
    "dump_nccl_trace_pickle",
    [](const control_plane::Request& req, control_plane::Response& res) {
      const auto params = req.params();
      size_t validParamCount = 0;

      // valid params
      const std::string includeCollectivesStr = "includecollectives";
      const std::string includeStackTracesStr = "includestacktraces";
      const std::string onlyActiveStr = "onlyactive";

      std::unordered_map<std::string, bool> processedParams = {
          {includeCollectivesStr, true},
          {includeStackTracesStr, true},
          {onlyActiveStr, false}};

      for (const auto& [paramName, paramValue] : params) {
        auto it = processedParams.find(paramName);
        if (it != processedParams.end()) {
          validParamCount++;
          if (paramValue == "true") {
            it->second = true;
          } else if (paramValue == "false") {
            it->second = false;
          } else {
            res.setStatus(400);
            res.setContent(
                "Invalid value for " + paramName +
                    "; valid values are true or false",
                "text/plain");
            return;
          }
        }
      }
      if (validParamCount < params.size()) {
        res.setStatus(400);
        res.setContent(
            "Invalid parameters - unexpected param passed in", "text/plain");
        return;
      }
      res.setContent(
          dump_nccl_trace(
              processedParams[includeCollectivesStr],
              processedParams[includeStackTracesStr],
              processedParams[onlyActiveStr]),
          "application/octet-stream");
    }};
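
// Same parameter handling as the pickle handler above, but emits JSON, does
// not support stack traces, and sets an explicit 200 status on success.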
control_plane::RegisterHandler jsonDumpHandler{
    "dump_nccl_trace_json",
    [](const control_plane::Request& req, control_plane::Response& res) {
      const auto params = req.params();
      size_t validParamCount = 0;

      // valid params
      const std::string includeCollectivesStr = "includecollectives";
      const std::string onlyActiveStr = "onlyactive";

      std::unordered_map<std::string, bool> processedParams = {
          {includeCollectivesStr, true}, {onlyActiveStr, false}};

      for (const auto& [paramName, paramValue] : params) {
        auto it = processedParams.find(paramName);
        if (it != processedParams.end()) {
          validParamCount++;
          if (paramValue == "true") {
            it->second = true;
          } else if (paramValue == "false") {
            it->second = false;
          } else {
            res.setStatus(400);
            res.setContent(
                "Invalid value for " + paramName +
                    "; valid values are true or false",
                "text/plain");
            return;
          }
        }
      }
      if (validParamCount < params.size()) {
        res.setStatus(400);
        res.setContent(
            "Invalid parameters - unexpected param passed in", "text/plain");
        return;
      }
      res.setStatus(200);
      res.setContent(
          dump_nccl_trace_json(
              processedParams[includeCollectivesStr],
              processedParams[onlyActiveStr]),
          "application/json");
    }};

void DebugInfoWriter::write(const std::string& ncclTrace) {
  // Open a file for writing. The ios::binary flag is used to write data as
  // binary.
  std::ofstream file(filename_, std::ios::binary);

  // Check if the file was opened successfully.
  if (!file.is_open()) {
    LOG(ERROR) << "Error opening file for writing NCCLPG debug info: "
               << filename_;
    return;
  }

  file.write(ncclTrace.data(), ncclTrace.size());
  LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
}

DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
  if (writer_ == nullptr) {
    std::string fileNamePrefix = getCvarString(
        {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
    // Using std::unique_ptr here to auto-delete the writer object
    // when the pointer itself is destroyed.
    std::unique_ptr<DebugInfoWriter> writerPtr(
        new DebugInfoWriter(fileNamePrefix, rank));
    DebugInfoWriter::registerWriter(std::move(writerPtr));
  }
  return *writer_;
}

void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
  TORCH_CHECK_WITH(
      DistBackendError,
      hasWriterRegistered_.load() == false,
      "debugInfoWriter already registered");
  hasWriterRegistered_.store(true);
  writer_ = std::move(writer);
}

std::optional<size_t> NCCLTraceBuffer::record(
    size_t pg_id,
    const std::tuple<std::string, std::string>& pg_name,
    size_t collective_seq_id,
    size_t p2p_seq_id,
    size_t op_id,
    std::string profiling_name,
    const std::vector<at::Tensor>& inputs,
    const std::vector<at::Tensor>& outputs,
    Event* start,
    Event* end,
    std::chrono::milliseconds timeout_ms,
    std::shared_ptr<ProcessGroupStatus> pg_status,
    bool isP2P) {
  if (!enabled_) {
    return std::nullopt;
  }
  if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
    // Current pg_status is not in FR.
    all_pg_status_[pg_id] = pg_status;
  }
  auto traceback =
      torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
  std::lock_guard<std::mutex> guard(mutex_);

  auto te = Entry{
      id_,
      pg_id,
      pg_name,
      collective_seq_id,
      p2p_seq_id,
      op_id,
      std::move(profiling_name),
      std::move(traceback),
      start,
      end,
      c10::getTime(),
      timeout_ms.count(),
      isP2P,
      std::nullopt,
      std::nullopt,
      std::nullopt,
      {},
      {},
      {},
      {},
      {},
      false};

  for (const auto& input : inputs) {
    c10::IntArrayRef sizes = input.sizes();
    te.input_dtypes_.push_back(input.dtype().toScalarType());
    te.input_dims_.push_back(sizes.size());
    te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
  }

  for (const auto& output : outputs) {
    c10::IntArrayRef sizes = output.sizes();
    te.output_dtypes_.push_back(output.dtype().toScalarType());
    te.output_dims_.push_back(sizes.size());
    te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
  }
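
  // Ring-buffer insertion: append until max_entries_ is reached, then
  // overwrite the oldest slot; next_ always points at the slot that will be
  // recycled next once the buffer has wrapped.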
  if (entries_.size() < max_entries_) {
    entries_.emplace_back(std::move(te));
  } else {
    entries_[next_++] = std::move(te);
    if (next_ == max_entries_) {
      next_ = 0;
    }
  }
  return id_++;
}

void NCCLTraceBuffer::record_pg_ranks(
    const std::tuple<std::string, std::string>& pg_name,
    std::vector<uint64_t> ranks) {
  if (!enabled_) {
    return;
  }
  std::lock_guard<std::mutex> guard(mutex_);
  pg_name_to_ranks_[pg_name] = ranks;
}

void NCCLTraceBuffer::update_state(Entry& r) {
  if (r.start_ != nullptr) {
    bool started = r.start_->query();
    if (started && !r.time_discovered_started_) {
      r.time_discovered_started_ = c10::getTime();
    }
  }
  if (r.end_ != nullptr) {
    bool completed = r.end_->query();
    if (completed && !r.time_discovered_completed_) {
      r.time_discovered_completed_ = c10::getTime();
    }
  }
}
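
// Returns a snapshot of the ring buffer in recording order: once the buffer
// has wrapped, next_ indexes the oldest entry, so [next_, end) is spliced in
// before [begin, next_). Event pointers are cleared on the copies so callers
// can inspect them after the lock is released.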
std::vector<NCCLTraceBuffer::Entry> NCCLTraceBuffer::dump_entries() {
  std::lock_guard<std::mutex> guard(mutex_);
  std::vector<Entry> result;
  result.reserve(entries_.size());
  result.insert(result.end(), entries_.begin() + next_, entries_.end());
  result.insert(result.end(), entries_.begin(), entries_.begin() + next_);
  // query any remaining events
  for (auto& r : result) {
    update_state(r);
    r.start_ = r.end_ = nullptr;
  }
  return result;
}

void NCCLTraceBuffer::retire_id(
    std::optional<size_t> id,
    bool compute_duration) {
  if (!enabled_ || !id) {
    return;
  }

  bool can_compute_duration = false;
  Event* startEvent = nullptr;
  Event* endEvent = nullptr;
  std::optional<float> duration = std::nullopt;

  std::unique_lock<std::mutex> guard(mutex_);

  Entry* entry = &entries_.at(*id % max_entries_);
  if (entry->id_ == *id) {
    update_state(*entry);

    if (compute_duration) {
      can_compute_duration = entry->time_discovered_completed_.has_value() &&
          entry->start_ && entry->end_;
      startEvent = entry->start_;
      endEvent = entry->end_;
    }
    entry->retired_ = true;
    entry->start_ = entry->end_ = nullptr;
  }

  if (can_compute_duration) {
    // Compute the duration without holding the lock, because
    // cudaEventDuration() can hang, and we need to acquire the lock before we
    // can dump(), which we never want to block.
    guard.unlock();
    duration = getDurationFromEvent(*startEvent, *endEvent);
    guard.lock();

    // Refresh the entry pointer, see if the entry has been overwritten
    entry = &entries_.at(*id % max_entries_);
    if (entry->id_ != *id) {
      LOG(INFO) << "retire_id abandoned for id " << *id
                << ", event was overwritten while waiting to compute duration.";
      return;
    }
    if (duration.has_value()) {
      entry->duration_ = duration.value();
    }
  }
}

const c10::List<c10::IValue> NCCLTraceBuffer::getCollectiveTrace(
    bool includeStacktraces,
    bool onlyActive) {
  auto entries = new_list();
  // Entries are returned in the order they were recorded
  auto result = dump_entries();
  std::vector<torch::CapturedTraceback*> tracebacks;
  torch::SymbolizedTracebacks stracebacks;
  std::vector<c10::IValue> all_frames;
  if (includeStacktraces) {
    for (auto& e : result) {
      tracebacks.push_back(e.traceback_.get());
    }
    stracebacks = torch::symbolize(tracebacks);
    for (const auto& f : stracebacks.all_frames) {
      auto d = new_dict();
      d.insert(name_key, f.funcname);
      d.insert(filename_key, f.filename);
      d.insert(line_key, int64_t(f.lineno));
      all_frames.emplace_back(std::move(d));
    }
  }
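  // All tracebacks are symbolized in one batch above; stracebacks.tracebacks
  // holds, per entry, indices into the shared all_frames list.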
  for (auto i : c10::irange(result.size())) {
    auto dict = new_dict();
    auto& e = result.at(i);
    // Skip completed events when only active ones are requested.
    if (onlyActive && e.time_discovered_completed_.has_value()) {
      continue;
    }
    if (includeStacktraces) {
      auto& tb = stracebacks.tracebacks.at(i);
      auto frames = new_list();
      for (int64_t frame : tb) {
        frames.push_back(all_frames.at(frame));
      }
      dict.insert(frames_key, frames);
    }

    dict.insert(record_id_key, int64_t(e.id_));
    dict.insert(pg_id_key, int64_t(e.pg_id_));
    dict.insert(pg_name_key, e.pg_name_);
    dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_));
    dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_));
    dict.insert(op_id_key, int64_t(e.op_id_));
    dict.insert(profiling_name_key, e.profiling_name_);
    dict.insert(time_created_key, int64_t(e.time_created_));
    if (e.duration_) {
      dict.insert(duration_key, *e.duration_);
    }

    auto it = e.sizes_.begin();
    auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
      auto sizes = new_list();
      for (auto dim : dims) {
        auto arg_sizes = new_list();
        for (C10_UNUSED auto i : c10::irange(dim)) {
          arg_sizes.push_back(*it++);
        }
        sizes.push_back(arg_sizes);
      }
      return sizes;
    };

    dict.insert(input_sizes_key, read_sizes(e.input_dims_));
    std::vector<std::string> input_dtypes_strs;
    input_dtypes_strs.reserve(e.input_dtypes_.size());
    for (const auto& input_dtype : e.input_dtypes_) {
      input_dtypes_strs.push_back(c10::toString(input_dtype));
    }
    dict.insert(input_dtypes_key, input_dtypes_strs);
    dict.insert(output_sizes_key, read_sizes(e.output_dims_));
    std::vector<std::string> output_dtypes_strs;
    output_dtypes_strs.reserve(e.output_dtypes_.size());
    for (const auto& output_dtype : e.output_dtypes_) {
      output_dtypes_strs.push_back(c10::toString(output_dtype));
    }
    dict.insert(output_dtypes_key, output_dtypes_strs);
    if (e.time_discovered_completed_.has_value()) {
      dict.insert(state_key, completed_state);
    } else if (e.time_discovered_started_.has_value()) {
      dict.insert(state_key, started_state);
    } else {
      dict.insert(state_key, scheduled_state);
    }

    dict.insert(
        time_discovered_started_key,
        e.time_discovered_started_.has_value()
            ? int64_t(*e.time_discovered_started_)
            : c10::IValue());
    dict.insert(
        time_discovered_completed_key,
        e.time_discovered_completed_.has_value()
            ? int64_t(*e.time_discovered_completed_)
            : c10::IValue());
    dict.insert(retired_key, e.retired_);
    dict.insert(timeout_key, e.timeout_ms_);
    dict.insert(is_p2p_key, e.isP2P_);

    entries.push_back(dict);
  }
  return entries;
}

const c10::Dict<c10::IValue, c10::IValue> NCCLTraceBuffer::getPgConfig() {
  auto pg_config = new_dict();
  for (const auto& [pg_name, ranks] : pg_name_to_ranks_) {
    auto pg_info = new_dict();
    pg_info.insert("name", std::get<0>(pg_name));
    pg_info.insert("desc", std::get<1>(pg_name));
    pg_info.insert("ranks", ranks_str(ranks));
    pg_config.insert(std::get<0>(pg_name), pg_info);
  }
  return pg_config;
}

const std::map<std::string, std::map<std::string, std::string>> NCCLTraceBuffer::
    getPgConfigJson() {
  std::map<std::string, std::map<std::string, std::string>> result;
  for (const auto& [pg_name, ranks] : pg_name_to_ranks_) {
    auto pg_info = std::map<std::string, std::string>();
    pg_info["name"] = std::get<0>(pg_name);
    pg_info["desc"] = std::get<1>(pg_name);
    pg_info["ranks"] = ranks_str(ranks);
    result.emplace(std::get<0>(pg_name), pg_info);
  }
  return result;
}

const c10::Dict<c10::IValue, c10::IValue> NCCLTraceBuffer::getPgStatus() {
  auto all_pg_status = new_dict();
  for (const auto& [pg_id, status] : all_pg_status_) {
    auto pg_status = new_dict();
    pg_status.insert("last_enqueued_collective", status->lastEnqueuedSeq);
    pg_status.insert("last_started_collective", status->lastStartedSeq);
    pg_status.insert("last_completed_collective", status->lastCompletedSeq);
    all_pg_status.insert(std::to_string(pg_id), pg_status);
  }
  return all_pg_status;
}

const std::map<std::string, std::map<std::string, std::string>> NCCLTraceBuffer::
    getPgStatusJson() {
  std::map<std::string, std::map<std::string, std::string>> result;
  for (const auto& [pg_id, status] : all_pg_status_) {
    auto pg_status = std::map<std::string, std::string>();
    pg_status["last_enqueued_collective"] =
        std::to_string(status->lastEnqueuedSeq);
    pg_status["last_started_collective"] =
        std::to_string(status->lastStartedSeq);
    pg_status["last_completed_collective"] =
        std::to_string(status->lastCompletedSeq);
    result[std::to_string(pg_id)] = pg_status;
  }
  return result;
}

std::string NCCLTraceBuffer::dump_json(
    const std::optional<std::unordered_map<
        std::string,
        std::unordered_map<std::string, std::string>>>& ncclDumpMap,
    bool includeCollectives,
    bool onlyActive) {
  using json = nlohmann::json;
  json result;
  result[version_key_str] = version_val_str;
  result[pg_config_key_str] = getPgConfigJson();
  result[pg_status_key_str] = getPgStatusJson();

  // collective trace
  if (includeCollectives) {
    std::list<json> entries;
    for (auto& e : dump_entries()) {
      json j;
      if (onlyActive && e.time_discovered_completed_.has_value()) {
        continue;
      }
      j[record_id_key_str] = int64_t(e.id_);
      j[pg_id_key_str] = int64_t(e.pg_id_);
      j[pg_name_key_str] = e.pg_name_;
      j[collective_seq_id_key_str] = int64_t(e.collective_seq_id_);
      j[p2p_seq_id_key_str] = int64_t(e.p2p_seq_id_);
      j[op_id_key_str] = int64_t(e.op_id_);
      j[profiling_name_key_str] = e.profiling_name_;
      j[time_created_key_str] = int64_t(e.time_created_);
      if (e.duration_) {
        j[duration_key_str] = *e.duration_;
      }
      auto it = e.sizes_.begin();
      auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
        auto sizes = std::list<std::list<int>>();
        for (auto dim : dims) {
          auto arg_sizes = std::list<int>();
          for (auto i : c10::irange(dim)) {
            (void)i;
            arg_sizes.push_back(*it++);
          }
          sizes.push_back(arg_sizes);
        }
        return sizes;
      };
      j[input_sizes_key_str] = read_sizes(e.input_dims_);
      std::vector<std::string> input_dtypes_strs;
      input_dtypes_strs.reserve(e.input_dtypes_.size());
      for (const auto& input_dtype : e.input_dtypes_) {
        input_dtypes_strs.push_back(c10::toString(input_dtype));
      }
      j[input_dtypes_key_str] = input_dtypes_strs;
      j[output_sizes_key_str] = read_sizes(e.output_dims_);
      std::vector<std::string> output_dtypes_strs;
      output_dtypes_strs.reserve(e.output_dtypes_.size());
      for (const auto& output_dtype : e.output_dtypes_) {
        output_dtypes_strs.push_back(c10::toString(output_dtype));
      }
      j[output_dtypes_key_str] = output_dtypes_strs;
      if (e.time_discovered_completed_.has_value()) {
        j[state_key_str] = completed_state_str;
      } else if (e.time_discovered_started_.has_value()) {
        j[state_key_str] = started_state_str;
      } else {
        j[state_key_str] = scheduled_state_str;
      }
      j[time_discovered_started_key_str] =
          e.time_discovered_started_.has_value()
          ? int64_t(*e.time_discovered_started_)
          : 0;
      j[time_discovered_completed_key_str] =
          e.time_discovered_completed_.has_value()
          ? int64_t(*e.time_discovered_completed_)
          : 0;
      j[retired_key_str] = e.retired_;
      j[timeout_key_str] = e.timeout_ms_;
      j[is_p2p_key_str] = e.isP2P_;
      entries.emplace_back(j);
    }

    if (entries.size() > 0) {
      result[entries_key_str] = entries;
    }
  }

  if (ncclDumpMap.has_value()) {
    result[nccl_comm_key_str] = ncclDumpMap.value();
  }

  return result.dump();
}

std::string NCCLTraceBuffer::dump(
    const std::optional<std::unordered_map<
        std::string,
        std::unordered_map<std::string, std::string>>>& ncclDumpMap,
    bool includeCollectives,
    bool includeStackTraces,
    bool onlyActive) {
  auto result = new_dict();
  // common values
  result.insert(version_key, version_val);
  result.insert(pg_config_key, getPgConfig());
  result.insert(pg_status_key, getPgStatus());

  // collective trace
  if (includeCollectives) {
    result.insert(
        entries_key, getCollectiveTrace(includeStackTraces, onlyActive));
  }
  // convert ncclDumpMap into a dictionary
  auto per_comm_dict = new_dict();
  if (ncclDumpMap.has_value()) {
    for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) {
      auto inner_dict = new_dict();
      for (const auto& [key, value] : ncclDump) {
        inner_dict.insert(key, value);
      }
      per_comm_dict.insert(ncclId, inner_dict);
    }
  }
  if (per_comm_dict.size() > 0) {
    result.insert(nccl_comm_key, per_comm_dict);
  }
  return pickle_str(result);
}

std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);
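
// Measures the elapsed GPU time between a collective's start and end events.
// elapsed_time() requires both events to have been recorded and completed,
// hence the query() check on the end event.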
float getDurationFromEvent(
    at::cuda::CUDAEvent& ncclStartEvent,
    at::cuda::CUDAEvent& ncclEndEvent) {
  TORCH_CHECK(
      ncclEndEvent.query(),
      "getDuration can only be called after work has succeeded.");
  return ncclStartEvent.elapsed_time(ncclEndEvent);
}

} // namespace c10d

#endif // USE_C10D_NCCL