1 #include <c10/util/StringUtil.h>
2 #include <fmt/format.h>
3 #include <torch/csrc/distributed/c10d/Utils.hpp>
4 #include <torch/csrc/distributed/c10d/debug.h>
5 #include <torch/csrc/distributed/c10d/logger.hpp>
6 #include <string>
7
8 #include <c10/util/CallOnce.h>
9
10 #ifdef USE_C10D_GLOO
11 #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
12 #endif
13
14 namespace c10d {
15
// Candidate environment-variable names handed to getCvarString() when
// recording the NCCL blocking-wait and async-error-handling settings. The
// TORCH_-prefixed spelling is listed first, followed by the NCCL_-prefixed
// spelling (presumably checked in list order — confirm against
// getCvarString).
static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT{
    "TORCH_NCCL_BLOCKING_WAIT",
    "NCCL_BLOCKING_WAIT"};
static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING{
    "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    "NCCL_ASYNC_ERROR_HANDLING"};
22
// Values of num_iterations_stats_recorded_ at which the runtime stats are
// sent to at::LogPyTorchDDPUsage. Because data collection only runs every
// ddp_runtime_logging_sample_rate iterations, the training iterations these
// checkpoints correspond to are roughly 10,
// (20-10) * ddp_runtime_logging_sample_rate,
// (50-10) * ddp_runtime_logging_sample_rate, and so on.
constexpr int LoggingIterations[] = {10, 20, 50, 100, 500, 800, 1000}; // NOLINT
29
operator <<(std::ostream & output,const Logger & logger)30 std::ostream& operator<<(std::ostream& output, const Logger& logger) {
31 auto& ddp_logging_data = (*logger.ddp_logging_data_);
32
33 std::string loggerInfo = fmt::format(
34 "[Rank {} / {}] [before iteration {}] Training {} unused_parameter_size={} \n "
35 "Avg forward compute time: {} \n Avg backward compute time: {} \n"
36 "Avg backward comm. time: {} \n Avg backward comm/comp overlap time: {}",
37 ddp_logging_data.ints_map["rank"],
38 ddp_logging_data.ints_map["world_size"],
39 ddp_logging_data.ints_map["iteration"],
40 ddp_logging_data.strs_map["module_name"],
41 ddp_logging_data.ints_map["unused_parameter_size"],
42 ddp_logging_data.ints_map["avg_forward_compute_time"],
43 ddp_logging_data.ints_map["avg_backward_compute_time"],
44 ddp_logging_data.ints_map["avg_backward_comm_time"],
45 ddp_logging_data.ints_map["avg_backward_compute_comm_overlap_time"]);
46
47 if (!ddp_logging_data.strs_map["comm_hook"].empty()) {
48 loggerInfo += fmt::format(
49 "\n Gradient comm. hook: {}", ddp_logging_data.strs_map["comm_hook"]);
50 }
51
52 if (ddp_logging_data.ints_map["join_uneven_inputs"]) {
53 loggerInfo += "\n Uneven input detection with join() enabled.";
54 }
55
56 return output << loggerInfo;
57 }
58
Logger(std::shared_ptr<c10d::Reducer> reducer)59 Logger::Logger(std::shared_ptr<c10d::Reducer> reducer)
60 : reducer_(std::move(reducer)) {
61 ddp_logging_data_ = std::make_unique<at::DDPLoggingData>();
62 }
63
64 c10::once_flag log_graph_static_flag;
65
log_if_graph_static(bool is_static)66 void Logger::log_if_graph_static(bool is_static) {
67 c10::call_once(log_graph_static_flag, [this, is_static]() {
68 ddp_logging_data_->ints_map["can_set_static_graph"] = is_static;
69 // It is useful to report the iteration that training finished at.
70 ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
71 at::LogPyTorchDDPUsage(*ddp_logging_data_);
72 });
73 }
74
75 // Environment variables
set_env_variables()76 void Logger::set_env_variables() {
77 ddp_logging_data_->strs_map["master_port"] =
78 getCvarString({"MASTER_PORT"}, "N/A");
79 ddp_logging_data_->strs_map["master_addr"] =
80 getCvarString({"MASTER_ADDR"}, "N/A");
81 ddp_logging_data_->strs_map["torch_distributed_debug"] =
82 getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, "N/A");
83 ddp_logging_data_->strs_map["cuda_visible_devices"] =
84 getCvarString({"CUDA_VISIBLE_DEVICES"}, "N/A");
85 if (reducer_->process_group_->getBackendName() == "nccl") {
86 ddp_logging_data_->strs_map["nccl_socket_ifname"] =
87 getCvarString({"NCCL_SOCKET_IFNAME"}, "N/A");
88 ddp_logging_data_->strs_map["nccl_blocking_wait"] =
89 getCvarString(TORCH_NCCL_BLOCKING_WAIT, "N/A");
90 ddp_logging_data_->strs_map["nccl_async_error_handling"] =
91 getCvarString(TORCH_NCCL_ASYNC_ERROR_HANDLING, "N/A");
92 ddp_logging_data_->strs_map["nccl_debug"] =
93 getCvarString({"NCCL_DEBUG"}, "N/A");
94 ddp_logging_data_->strs_map["nccl_nthreads"] =
95 getCvarString({"NCCL_NTHREADS"}, "N/A");
96 ddp_logging_data_->strs_map["nccl_ib_timeout"] =
97 getCvarString({"NCCL_IB_TIMEOUT"}, "N/A");
98 }
99 if (reducer_->process_group_->getBackendName() == "gloo") {
100 ddp_logging_data_->strs_map["gloo_socket_ifname"] =
101 getCvarString({"GLOO_SOCKET_IFNAME"}, "N/A");
102 ddp_logging_data_->strs_map["gloo_device_transport"] =
103 getCvarString({"GLOO_DEVICE_TRANSPORT"}, "N/A");
104
105 #ifdef USE_C10D_GLOO
106 auto gloo_pg = static_cast<c10d::ProcessGroupGloo*>(
107 reducer_->process_group_
108 ->getBackend(c10d::ProcessGroup::BackendType::GLOO)
109 .get());
110 auto n_threads = gloo_pg->getNumThreads();
111 ddp_logging_data_->ints_map["gloo_num_threads"] = n_threads;
112 #endif
113 }
114 }
115
set_parameter_stats()116 void Logger::set_parameter_stats() {
117 // The number of parameter tensors
118 ddp_logging_data_->ints_map["num_parameter_tensors"] =
119 reducer_->params_.size();
120 // Total parameters size (Bytes)
121 ddp_logging_data_->ints_map["total_parameter_size_bytes"] = 0;
122 // Parameters' data types, there may be multiple data
123 // types for mixed precision training.
124 std::set<std::string> unique_dtypes;
125 for (const auto& t : reducer_->params_) {
126 ddp_logging_data_->ints_map["total_parameter_size_bytes"] +=
127 t.numel() * t.element_size();
128 unique_dtypes.insert(std::string(t.dtype().name()));
129 }
130 ddp_logging_data_->strs_map["dtypes"] = c10::Join(", ", unique_dtypes);
131 }
132
get_per_bucket_variable_indices()133 std::vector<std::vector<size_t>> Logger::get_per_bucket_variable_indices() {
134 std::vector<std::vector<size_t>> per_bucket_variable_indices;
135 per_bucket_variable_indices.reserve(reducer_->buckets_.size());
136 for (const auto& bucket : reducer_->buckets_) {
137 const auto& indices = bucket.variable_indices;
138 per_bucket_variable_indices.push_back(indices);
139 }
140 return per_bucket_variable_indices;
141 }
142
get_bucket_sizes()143 std::vector<int64_t> Logger::get_bucket_sizes() {
144 std::vector<int64_t> bucket_sizes;
145 for (const auto& bucket : reducer_->buckets_) {
146 const auto& variables = bucket.variables;
147 int64_t bucket_size = 0;
148 for (const auto& v : variables) {
149 bucket_size += v.numel() * v.element_size();
150 }
151 bucket_sizes.push_back(bucket_size);
152 }
153 return bucket_sizes;
154 }
155
156 // Communication hook. Empty string if not set, in which case it will not be
157 // logged.
set_comm_hook(const std::string & hook)158 void Logger::set_comm_hook(const std::string& hook) {
159 ddp_logging_data_->strs_map["comm_hook"] = hook;
160 }
161
162 // Whether we are running under model.join() context manager for DDP uneven
163 // inputs.
set_uneven_input_join()164 void Logger::set_uneven_input_join() {
165 ddp_logging_data_->ints_map["join_uneven_inputs"] = true;
166 }
167
set_static_graph()168 void Logger::set_static_graph() {
169 ddp_logging_data_->ints_map["static_graph"] = reducer_->static_graph_;
170 }
171
172 // Data that can be got during DistributedDataParallel construction time
set_construction_data_and_log(const std::string & module_name,const std::vector<int> & device_ids,int output_device,bool broadcast_buffers,bool has_sync_bn,bool static_graph)173 void Logger::set_construction_data_and_log(
174 const std::string& module_name,
175 const std::vector<int>& device_ids,
176 int output_device,
177 bool broadcast_buffers,
178 bool has_sync_bn,
179 bool static_graph) {
180 // No lock is needed, as it will be called in DistributedDataParallel
181 // constructor.
182 if (static_graph) {
183 set_static_graph();
184 }
185 ddp_logging_data_->strs_map["module_name"] = module_name;
186 ddp_logging_data_->ints_map["world_size"] =
187 reducer_->process_group_->getSize();
188 ddp_logging_data_->ints_map["rank"] = reducer_->process_group_->getRank();
189 // In which iteration of the training loop the get_ddp_logging_data()
190 // is called to fetch the DDPLoggingData, 0 if the data is fetched
191 // before training loop.
192 ddp_logging_data_->ints_map["iteration"] = 0;
193 ddp_logging_data_->ints_map["is_multi_device_module"] =
194 reducer_->is_multi_device_module_;
195
196 set_parameter_stats();
197 // A list of bucket sizes (Bytes) calculated during construction time
198 ddp_logging_data_->strs_map["bucket_sizes"] =
199 c10::Join(", ", get_bucket_sizes());
200 set_env_variables();
201
202 // DistributedDataParallel constructor input parameters
203 ddp_logging_data_->strs_map["device_ids"] = c10::Join(", ", device_ids);
204 ddp_logging_data_->ints_map["output_device"] = output_device;
205 ddp_logging_data_->ints_map["broadcast_buffers"] = broadcast_buffers;
206 ddp_logging_data_->ints_map["has_sync_bn"] = has_sync_bn;
207 ddp_logging_data_->ints_map["bucket_cap_bytes"] = reducer_->bucket_bytes_cap_;
208 ddp_logging_data_->ints_map["find_unused_parameters"] =
209 reducer_->find_unused_parameters_;
210 ddp_logging_data_->ints_map["gradient_as_bucket_view"] =
211 reducer_->gradient_as_bucket_view_;
212 ddp_logging_data_->strs_map["backend_name"] =
213 reducer_->process_group_->getBackendName();
214
215 if (debug_level() != DebugLevel::Off) {
216 std::string initInfo = fmt::format(
217 "[Rank {}]: DDP Initialized with: \n",
218 ddp_logging_data_->ints_map["rank"]);
219 std::stringstream ddpLoggingDataInfo;
220 for (const auto& intItem : ddp_logging_data_->ints_map) {
221 ddpLoggingDataInfo << intItem.first << ": " << intItem.second << "\n";
222 }
223 for (const auto& strItem : ddp_logging_data_->strs_map) {
224 ddpLoggingDataInfo << strItem.first << ": " << strItem.second << "\n";
225 }
226 LOG(INFO) << initInfo << ddpLoggingDataInfo.str();
227 }
228
229 at::LogPyTorchDDPUsage(*ddp_logging_data_);
230 }
231
set_event_time(int64_t & event_time,Timer & timer,Timer::Event event)232 void Logger::set_event_time(
233 int64_t& event_time,
234 Timer& timer,
235 Timer::Event event) {
236 auto timestamp = timer.getTimestamp(event);
237 if (timestamp != std::nullopt) {
238 // TODO: should we set this as human-readable time instead of unixtime?
239 event_time = *timestamp;
240 }
241 }
242
calculate_avg_time(int64_t & avg_time,int64_t & time_duration,Timer & timer,Timer::Event start_event,Timer::Event end_event)243 void Logger::calculate_avg_time(
244 int64_t& avg_time,
245 int64_t& time_duration,
246 Timer& timer,
247 Timer::Event start_event,
248 Timer::Event end_event) {
249 TORCH_CHECK(num_iterations_stats_recorded_ > 0);
250 std::optional<int64_t> maybe_time_duration =
251 timer.measureDifference(start_event, end_event);
252 if (!maybe_time_duration.has_value()) {
253 return;
254 }
255 time_duration = maybe_time_duration.value();
256 avg_time = (time_duration + avg_time * (num_iterations_stats_recorded_ - 1)) /
257 num_iterations_stats_recorded_;
258 }
259
reset_performance_stats()260 void Logger::reset_performance_stats() {
261 ddp_logging_data_->ints_map["forward_compute_time"] = 0;
262 ddp_logging_data_->ints_map["backward_comm_time"] = 0;
263 ddp_logging_data_->ints_map["backward_compute_time"] = 0;
264 ddp_logging_data_->ints_map["backward_compute_comm_overlap_time"] = 0;
265 ddp_logging_data_->ints_map["forward_compute_time_start"] = 0;
266 ddp_logging_data_->ints_map["backward_compute_time_start"] = 0;
267 ddp_logging_data_->ints_map["backward_comm_time_start"] = 0;
268 ddp_logging_data_->ints_map["backward_compute_time_end"] = 0;
269 ddp_logging_data_->ints_map["backward_comm_time_end"] = 0;
270 }
271
set_runtime_stats_and_log()272 void Logger::set_runtime_stats_and_log() {
273 // Sync with reducer's data
274 std::lock_guard<std::mutex> lock(reducer_->mutex_);
275 // Set runtime stats at the sampling iterations.
276 if (!reducer_->should_collect_runtime_stats()) {
277 return;
278 }
279 num_iterations_stats_recorded_++;
280 // Set ith iteration when the runtime stats are set.
281 ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
282 // When get_ddp_logging_data() is called, "unused_parameter_size",
283 // "has_rebuilt_buckets" and "rebuilt_bucket_sizes" are updated in the latest
284 // sampling iteration.
285 // If unused_parameters_ is not empty, calculate its sizes.
286 // unused_parameters_ is calculated in forward call of
287 // each iteration.
288 if (reducer_->unused_parameters_.empty() &&
289 reducer_->find_unused_parameters_) {
290 // No unused params in this iteration
291 ddp_logging_data_->ints_map["unused_parameter_size"] = 0;
292 }
293 for (const auto& unused_index : reducer_->unused_parameters_) {
294 const auto& v = reducer_->params_[unused_index];
295 ddp_logging_data_->ints_map["unused_parameter_size"] +=
296 v.numel() * v.element_size();
297 }
298 // rebuilt_bucket_sizes will not change once buckets are rebuilt,
299 // so it only needs to set once during whole training loop.
300 // Rebuild buckets stats after 1st iteration
301 if (ddp_logging_data_->ints_map["has_rebuilt_buckets"] !=
302 reducer_->has_rebuilt_bucket_) {
303 ddp_logging_data_->ints_map["has_rebuilt_buckets"] =
304 reducer_->has_rebuilt_bucket_;
305 ddp_logging_data_->strs_map["rebuilt_bucket_sizes"] =
306 c10::Join(", ", get_bucket_sizes());
307 // Log per-bucket variable indices
308 std::vector<std::string> per_bucket_variable_indices;
309 auto indices = get_per_bucket_variable_indices();
310 per_bucket_variable_indices.reserve(indices.size());
311 for (const auto& bucket_indices : indices) {
312 per_bucket_variable_indices.push_back(c10::Join(" ", bucket_indices));
313 }
314 ddp_logging_data_->strs_map["rebuilt_per_bucket_param_indices"] =
315 c10::Join(", ", per_bucket_variable_indices);
316 }
317 // Log gradient ready order
318 if (!reducer_->grad_ready_order_indices_.empty()) {
319 // Note that the indices are for the previous iteration as
320 // this function is called in forward pass, and we last computed gradient
321 // ready order in the last backward pass.
322 ddp_logging_data_->strs_map["prev_iteration_grad_ready_order_indices"] =
323 c10::Join(", ", reducer_->grad_ready_order_indices_);
324 }
325
326 reset_performance_stats();
327
328 // Cuda time stats are only collected for single device modules.
329 if (reducer_->params_[0].is_cuda() && reducer_->is_multi_device_module_) {
330 TORCH_WARN_ONCE(
331 "Cuda time stats are not collected for multi-device modules.");
332 return;
333 }
334
335 if (!reducer_->timer_ &&
336 (!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu())) {
337 TORCH_WARN_ONCE(
338 "Time stats are currently only collected for CPU and CUDA devices. "
339 "Please refer to CpuTimer or CudaTimer for how to register timer "
340 "for other device type.");
341 return;
342 }
343 TORCH_INTERNAL_ASSERT(reducer_->timer_);
344 calculate_avg_time(
345 ddp_logging_data_->ints_map["avg_forward_compute_time"],
346 ddp_logging_data_->ints_map["forward_compute_time"],
347 *reducer_->timer_,
348 Timer::Event::kForwardStart,
349 Timer::Event::kBackwardComputeStart);
350 calculate_avg_time(
351 ddp_logging_data_->ints_map["avg_backward_compute_time"],
352 ddp_logging_data_->ints_map["backward_compute_time"],
353 *reducer_->timer_,
354 Timer::Event::kBackwardComputeStart,
355 Timer::Event::kBackwardComputeEnd);
356 calculate_avg_time(
357 ddp_logging_data_->ints_map["avg_backward_comm_time"],
358 ddp_logging_data_->ints_map["backward_comm_time"],
359 *reducer_->timer_,
360 Timer::Event::kBackwardCommStart,
361 Timer::Event::kBackwardCommEnd);
362 calculate_avg_time(
363 ddp_logging_data_->ints_map["avg_backward_compute_comm_overlap_time"],
364 ddp_logging_data_->ints_map["backward_compute_comm_overlap_time"],
365 *reducer_->timer_,
366 Timer::Event::kBackwardCommStart,
367 Timer::Event::kBackwardComputeEnd);
368
369 set_event_time(
370 ddp_logging_data_->ints_map["forward_compute_time_start"],
371 *reducer_->timer_,
372 Timer::Event::kForwardStart);
373 set_event_time(
374 ddp_logging_data_->ints_map["backward_compute_time_start"],
375 *reducer_->timer_,
376 Timer::Event::kBackwardComputeStart);
377 set_event_time(
378 ddp_logging_data_->ints_map["backward_comm_time_start"],
379 *reducer_->timer_,
380 Timer::Event::kBackwardCommStart);
381 set_event_time(
382 ddp_logging_data_->ints_map["backward_compute_time_end"],
383 *reducer_->timer_,
384 Timer::Event::kBackwardComputeEnd);
385 set_event_time(
386 ddp_logging_data_->ints_map["backward_comm_time_end"],
387 *reducer_->timer_,
388 Timer::Event::kBackwardCommEnd);
389
390 // Log runtime stats to stderr if TORCH_DISTRIBUTED_DEBUG=DETAIL is enabled.
391 if (debug_level() == DebugLevel::Detail) {
392 LOG(INFO) << *this;
393 }
394
395 // Log runtime (e.g. avg performance) stats at the beginning and also
396 // after a larger number of iterations. Choosing 10/1000/10000 is
397 // not scientific here, it assumes most of applications will run
398 // at least 10 iterations. stats could have smaller variance if
399 // selected num_iterations_ is larger.
400 if (std::find(
401 std::begin(LoggingIterations),
402 std::end(LoggingIterations),
403 num_iterations_stats_recorded_) != std::end(LoggingIterations)) {
404 at::LogPyTorchDDPUsage(*ddp_logging_data_);
405 }
406 }
407
get_ddp_logging_data()408 at::DDPLoggingData Logger::get_ddp_logging_data() {
409 std::lock_guard<std::mutex> lock(reducer_->mutex_);
410 return *ddp_logging_data_;
411 }
412
413 // initialization of static variables in C10dLogger
414 std::unique_ptr<C10dLogger> C10dLogger::logger_ = nullptr;
415 std::atomic<bool> C10dLogger::registered_(false);
416
getLogger()417 C10dLogger* C10dLogger::getLogger() {
418 if (!registered_.load()) {
419 return nullptr;
420 }
421 return logger_.get();
422 }
423
registerLogger(std::unique_ptr<C10dLogger> logger)424 void C10dLogger::registerLogger(std::unique_ptr<C10dLogger> logger) {
425 if (registered_.load()) {
426 LOG(WARNING) << "C10dLogger has already been registered.";
427 return;
428 }
429 registered_.store(true);
430 logger_ = std::move(logger);
431 }
432
log(const C10dLoggingData & data)433 void C10dLogger::log(const C10dLoggingData& data) {
434 for (const auto& [key, value] : data.integers) {
435 LOG(INFO) << key << ": " << value;
436 }
437 for (const auto& [key, value] : data.strings) {
438 LOG(INFO) << key << ": " << value;
439 }
440 return;
441 }
442 } // namespace c10d
443