xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/reducer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/reducer.hpp>
2 
3 #include <torch/csrc/distributed/c10d/Utils.hpp>
4 #include <torch/csrc/distributed/c10d/default_comm_hooks.hpp>
5 
6 #include <functional>
7 
8 #include <c10/core/DeviceGuard.h>
9 #include <c10/core/ScalarType.h>
10 #include <c10/core/StreamGuard.h>
11 #include <c10/util/Exception.h>
12 #include <c10/util/Logging.h>
13 #include <c10/util/hash.h>
14 #include <c10/util/irange.h>
15 #include <torch/csrc/autograd/engine.h>
16 #include <torch/csrc/autograd/function_hook.h>
17 #include <torch/csrc/autograd/functions/accumulate_grad.h>
18 #include <torch/csrc/autograd/profiler.h>
19 #include <torch/csrc/autograd/utils/grad_layout_contract.h>
20 #include <torch/csrc/autograd/utils/lambda_post_hook.h>
21 #include <torch/csrc/distributed/c10d/comm.hpp>
22 #include <torch/csrc/distributed/c10d/logger.hpp>
23 #include <utility>
24 
25 namespace c10d {
26 namespace {
27 
28 constexpr int kUnsetDivFactor = -1;
29 
30 // Macro that wraps TORCH_CHECK with DDP logging.
31 #define REDUCER_CHECK(cond, logger_, ...)             \
32   if (C10_UNLIKELY_OR_CONST(!(cond))) {               \
33     if (!logger_.expired()) {                         \
34       logger_.lock()->set_error_and_log(__VA_ARGS__); \
35     }                                                 \
36     TORCH_CHECK(false, ##__VA_ARGS__);                \
37   }
38 
39 } // namespace
40 
41 C10_DEFINE_TYPED_REGISTRY( // NOLINT
42     TimerRegistry,
43     c10::DeviceType,
44     Timer,
45     std::unique_ptr,
46     c10::Device);
47 
48 namespace {
49 
50 class CpuTimer : public Timer {
51  public:
CpuTimer(c10::Device)52   explicit CpuTimer(c10::Device /* unused */) {}
53 
measureDifference(Event start,Event end)54   std::optional<int64_t> measureDifference(Event start, Event end) override {
55     int64_t start_time = getTimeRef(start);
56     int64_t end_time = getTimeRef(end);
57     // If cpu_end_time is not recorded in this iteration,
58     // avg_time will return invalid value.
59     // For some cases like DDP runs on non-sync mode, backward compute
60     // end time can not be recorded in this iteration and thus can not
61     // calculate the valid avg_time.
62     // In this case, skip calculating the avg_time and return.
63     if (end_time < start_time) {
64       return std::nullopt;
65     }
66     return end_time - start_time;
67   }
68 };
69 
70 C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer);
71 
extractTensors(const c10::IValue & result)72 std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
73   if (result.isPyObject()) {
74     return result.toPyObjectHolder()->extractTensors();
75   }
76   TORCH_INTERNAL_ASSERT(
77       result.isTensor() || result.isTensorList(),
78       "expected the hook result is either a Tensor or a TensorList found ",
79       result.tagKind());
80 
81   if (result.isTensor()) {
82     return {result.toTensor()};
83   }
84 
85   return result.toTensorVector();
86 }
87 
88 } // namespace
89 
Reducer(std::vector<at::Tensor> params,std::vector<std::vector<size_t>> bucket_indices,const std::vector<size_t> & per_bucket_size_limits,c10::intrusive_ptr<c10d::ProcessGroup> process_group,std::vector<bool> expect_sparse_gradients,int64_t bucket_bytes_cap,bool find_unused_parameters,bool gradient_as_bucket_view,std::unordered_map<size_t,std::string> param_names,int64_t first_bucket_bytes_cap)90 Reducer::Reducer(
91     std::vector<at::Tensor> params,
92     std::vector<std::vector<size_t>> bucket_indices,
93     const std::vector<size_t>& per_bucket_size_limits,
94     c10::intrusive_ptr<c10d::ProcessGroup> process_group,
95     std::vector<bool> expect_sparse_gradients,
96     int64_t bucket_bytes_cap,
97     bool find_unused_parameters,
98     bool gradient_as_bucket_view,
99     std::unordered_map<size_t, std::string> param_names,
100     int64_t first_bucket_bytes_cap)
101     : params_(std::move(params)),
102       process_group_(std::move(process_group)),
103       expect_sparse_gradients_(std::move(expect_sparse_gradients)),
104       expect_autograd_hooks_(false),
105       require_finalize_(false),
106       next_bucket_(0),
107       has_marked_unused_parameters_(false),
108       find_unused_parameters_(find_unused_parameters),
109       gradient_as_bucket_view_(gradient_as_bucket_view),
110       local_used_map_reduced_(false),
111       num_iterations_(0),
112       num_bwd_calls_(0),
113       first_autograd_hook_called_(false),
114       num_buckets_ready_(0),
115       has_rebuilt_bucket_(false),
116       bucket_bytes_cap_(bucket_bytes_cap),
117       div_factor_(kUnsetDivFactor),
118       static_graph_(false),
119       comm_hook_(nullptr),
120       ddp_debug_level_(debug_level()),
121       param_names_(std::move(param_names)),
122       first_bucket_bytes_cap_(first_bucket_bytes_cap) {
123   C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
124   TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");
125 
126   if (ddp_debug_level_ != c10d::DebugLevel::Off) {
127     LOG(INFO) << "Reducer initialized with bucket_bytes_cap: "
128               << bucket_bytes_cap_
129               << " first_bucket_bytes_cap: " << first_bucket_bytes_cap;
130   }
131   // Check whether the module is multi_device_module
132   {
133     std::set<int> unique_devices;
134     for (const auto& v : params_) {
135       auto device_idx = int(v.device().index());
136       if (unique_devices.find(device_idx) == unique_devices.end()) {
137         unique_devices.insert(device_idx);
138         if (unique_devices.size() > 1) {
139           is_multi_device_module_ = true;
140           break;
141         }
142       }
143     }
144   }
145 
146   // For CUDA, record events only for single device module.
147   c10::Device device = params_[0].device();
148   if (!(device.is_cuda() && is_multi_device_module_)) {
149     timer_ = TimerRegistry()->Create(device.type(), device);
150   }
151 
152   // If `expect_sparse_gradients` is not specified, initialize it such that
153   // we do not expect sparse gradients for any parameter.
154   if (expect_sparse_gradients_.empty()) {
155     expect_sparse_gradients_ = std::vector<bool>(params_.size(), false);
156   }
157   TORCH_INTERNAL_ASSERT(expect_sparse_gradients_.size() == params_.size());
158 
159   // Initialize variable bucketing.
160   // This can be reinitialized later after capturing runtime information.
161   {
162     std::lock_guard<std::mutex> lock(mutex_);
163     initialize_buckets(std::move(bucket_indices));
164   }
165 
166   // All variables are expected to have their `grad_fn` set to the gradient
167   // accumulation function (since they are leafs in the autograd graph).
168   // We store pointers to these functions such that we can check if they are
169   // used in an autograd pass. If they are not, we know their grad tensors
170   // can be marked as ready for reduction.
171   {
172     const auto variable_count = params_.size();
173     grad_accumulators_.resize(variable_count);
174     for (const auto variable_index : c10::irange(variable_count)) {
175       auto& variable = params_[variable_index];
176 
177       // The gradient accumulator function is lazily initialized once.
178       // Therefore we can use its presence in the autograd graph as
179       // evidence that the parameter has participated in an iteration.
180       auto grad_accumulator = torch::autograd::impl::grad_accumulator(variable);
181 
182 #ifndef _WIN32
183       using torch::distributed::autograd::ThreadLocalDistAutogradContext;
184 #endif
185       // Hook to execute after the gradient accumulator has executed.
186       hooks_.emplace_back(
187           grad_accumulator->add_post_hook(
188               std::make_unique<torch::autograd::utils::LambdaPostHook>(
189                   [this, variable_index](
190                       const torch::autograd::variable_list& outputs,
191                       const torch::autograd::variable_list& /* unused */) {
192 #ifndef _WIN32
193                     this->rpc_context_.set(
194                         ThreadLocalDistAutogradContext::getContextPtr());
195 #endif
196                     this->autograd_hook(variable_index);
197                     return outputs;
198                   },
199                   [=](torch::autograd::CompiledNodeArgs& args) {
200                     // Make post_hook an noop if compiled_autograds is enabled.
201                   })),
202           grad_accumulator);
203 
204       // Map raw function pointer to parameter index.
205       // This is used later on when the autograd graph is traversed
206       // to check for parameters for which no gradient is computed, if
207       // find_unused_parameters=True.
208       // Note that the mapping of gradient accumulator to variable should be
209       // one to one as we deduplicate shared parameters before constructing
210       // Reducer.
211       if (find_unused_parameters_) {
212         gradAccToVariableMap_[grad_accumulator.get()] = variable_index;
213       }
214 
215       numGradHooksTriggeredMap_[variable_index] = 0;
216 
217       // The gradient accumulator is stored as weak_ptr in the autograd
218       // metadata of the variable, so we have to keep it alive here for
219       // the raw pointer to be valid.
220       REDUCER_CHECK(
221           grad_accumulators_[variable_index] == nullptr,
222           logger_,
223           c10::str(
224               "Reducer tried to register duplicate grad accumulator for variable ",
225               variable_index));
226 
227       grad_accumulators_[variable_index] = std::move(grad_accumulator);
228     }
229   }
230 
231   // Initialize backward stats vector.
232   {
233     const auto variable_count = params_.size();
234     backward_stats_.resize(variable_count);
235   }
236 
237   // See Note [Skip allreducing local_used_map_dev]
238   if (find_unused_parameters_) {
239     initialize_local_used_map();
240   }
241 }
242 
243 // Note [Skip allreducing local_used_map_dev]
244 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
245 // If find_unused_parameters_ is set to false, there is no need to allreduce
246 // local_used_map_dev_, because all parameters will be reduced anyway.
247 // Therefore, we can avoid allocating memory for local_used_map and
248 // local_used_map_dev_ if find_unused_parameters_ is false.
249 
250 // Note [DDP Communication Hook]
251 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
252 // If DDP communication hook is not registered, the reducer reduces the buckets
253 // by just calling allreduce. If registered, it calls the hook and uses future
254 // work handle. If registered, reducer also skips dividing grads by world size.
255 // The reason for this is that the communication hook is expected to completely
256 // override how we perform communication and the user should have complete
257 // control over how the grads are handled.
258 //
259 // DDP communication hook is an enhancement that provides a hook which can be
260 // used to override how DDP communicates gradients across ranks, this can be
261 // used for algorithms like Gradient Compression/GossipGrad. This hook can be
262 // registered from Python API using `register_comm_hook`. `PythonCommHook`
263 // enables registering a Python hook and is a subclass of `CommHookInterface`.
264 // Additionally, there are also some built-in C++ hook implementations that can
265 // be specified by calling `register_builtin_comm_hook` from Python API.
266 
~Reducer()267 Reducer::~Reducer() noexcept(false) {
268   remove_autograd_hooks();
269 }
270 
dynamic_graph_find_unused()271 bool Reducer::dynamic_graph_find_unused() {
272   return !static_graph_ && find_unused_parameters_;
273 }
274 
static_graph_first_iteration()275 bool Reducer::static_graph_first_iteration() {
276   return static_graph_ && num_bwd_calls_ == 1;
277 }
278 
static_graph_after_first_iteration()279 bool Reducer::static_graph_after_first_iteration() {
280   return static_graph_ && num_bwd_calls_ > 1;
281 }
282 
ddp_graph_static()283 bool Reducer::ddp_graph_static() {
284   std::lock_guard<std::mutex> lock(mutex_);
285   return ddp_graph_static_;
286 }
287 
initialize_local_used_map()288 void Reducer::initialize_local_used_map() {
289   const auto variable_count = params_.size();
290   at::TensorOptions options;
291   options = options.dtype(at::kInt);
292 
293   // Deliberately don't pin the memory even if local_used_map_dev_ will
294   // be cuda. See Note [local_used_map_ -> local_used_map_dev copying]
295   local_used_map_ = at::zeros({static_cast<long>(variable_count)}, options);
296 
297   // This tensor needs to be on the same device as the replica params because
298   // backend such as NCCL may not support CPU tensors, and hence it might not
299   // work if we always put it on CPU. The dist backend for MTIA doesn't support
300   // int32 allreduce for now, so it has to be placed on CPU.
301   options = options.device(
302       (params_[0].is_mtia()) ? c10::Device(c10::DeviceType::CPU)
303                              : params_[0].device());
304   local_used_map_dev_ = at::empty({static_cast<long>(variable_count)}, options);
305 }
306 
check_grad_layout(const at::Tensor & grad,const at::Tensor & bucket_view)307 void Reducer::check_grad_layout(
308     const at::Tensor& grad,
309     const at::Tensor& bucket_view) {
310   // Ensure that the gradient type matches the bucket type, or mixed precision
311   // type if we are training with mixed precision.
312   auto type = mixed_precision_param_dtype_
313       ? *mixed_precision_param_dtype_
314       : bucket_view.options().dtype().toScalarType();
315   REDUCER_CHECK(
316       grad.options().dtype().toScalarType() == type,
317       logger_,
318       c10::str(
319           "Expected ", type, ", got ", grad.options().dtype().toScalarType()));
320 
321   TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
322   TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
323   // AccumulateGrad doesn't HAVE to obey the grad layout contract.
324   // The penalty for disobedience is reduced performance, not numerical
325   // death. Warnings here help diagnose poor DDP performance.
326   if (grad.strides() != bucket_view.strides()) {
327     TORCH_WARN_ONCE(
328         "Grad strides do not match bucket view strides. "
329         "This may indicate grad was not created according to the "
330         "gradient layout contract, or that the param's strides "
331         "changed since DDP was constructed.  This is not an error, "
332         "but may impair performance.\n"
333         "grad.sizes() = ",
334         grad.sizes(),
335         ", strides() = ",
336         grad.strides(),
337         "\n",
338         "bucket_view.sizes() = ",
339         bucket_view.sizes(),
340         ", strides() = ",
341         bucket_view.strides());
342   }
343   if (!gradient_as_bucket_view_) {
344     TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view));
345   }
346 }
347 
mark_variable_ready_dense(size_t variable_index)348 void Reducer::mark_variable_ready_dense(size_t variable_index) {
349   const auto& bucket_index = variable_locators_[variable_index];
350   auto& bucket = buckets_[bucket_index.bucket_index];
351   auto& variable = bucket.variables[bucket_index.intra_bucket_index];
352   auto& bucket_view = bucket.bucket_views_in[bucket_index.intra_bucket_index];
353 
354   // Copy the contents of the gradient tensor to the corresponding part of the
355   // bucket's flattened gradient tensor.
356   // If the gradient is not set, we assume it wasn't computed as part of the
357   // current backwards pass, and we zero the part of the bucket it would
358   // otherwise hold.
359   runGradCallbackForVariable(variable, [&](auto& grad) {
360     if (grad.defined()) {
361       this->check_grad_layout(grad, bucket_view);
362       // When gradient_as_bucket_view_ is false, or even when
363       // gradient_as_bucket_view_ is true, in rare cases users may set grad to
364       // be None after every iteration. In these cases, grad and bucket_view are
365       // pointing to different storages and thus need to copy grads to
366       // bucket_view. If gradient_as_bucket_view_ is set as true, let grad point
367       // to bucket_view. If grad has already been set as views of buckets in
368       // previous iterations, no copy is needed.
369       if (!grad.is_alias_of(bucket_view)) {
370         if (comm_hook_ == nullptr) {
371           auto wrapped =
372               at::native::wrapped_scalar_tensor(double(1.) / div_factor_);
373           if (!grad.requires_grad()) {
374             // Divides while copying into the bucket view to save one scan over
375             // all the input parameters.
376             RECORD_FUNCTION(
377                 "torch::distributed::reducer::mul_out",
378                 std::vector<c10::IValue>({bucket_view}))
379             at::mul_out(bucket_view, grad, wrapped);
380           } else {
381             // If DDP is running with create_graph=True, gradients require_grad
382             // themselves in order to compute higher order derivatives. However,
383             // DDP will not sync up these gradients currently (see
384             // https://github.com/pytorch/pytorch/issues/63812).
385             C10_LOG_EVERY_N(WARNING, 1000)
386                 << "Using DistributedDataParallel with create_graph=True "
387                 << " is not well-supported. The higher-order gradient will "
388                 << " not be synchronized across ranks, and backpropagation "
389                 << " through all_reduce operations will not occur. If you require "
390                 << " DDP to work with higher-order gradients for your use case, "
391                 << " please ping https://github.com/pytorch/pytorch/issues/63929";
392             auto div_result = at::mul(grad, wrapped);
393             RECORD_FUNCTION(
394                 "torch::distributed::reducer::copy_",
395                 std::vector<c10::IValue>({bucket_view}))
396             bucket_view.copy_(div_result);
397           }
398         } else {
399           RECORD_FUNCTION(
400               "torch::distributed::reducer::copy_",
401               std::vector<c10::IValue>({bucket_view}))
402           bucket_view.copy_(grad);
403         }
404 
405         if (gradient_as_bucket_view_) {
406           // Let grad point to bucket_view buffer.
407           grad = bucket_view;
408           // The grad is modified and need to be written back.
409           return true;
410         }
411       } else {
412         // If grad and bucket view point to the same storage, no need to copy.
413         if (comm_hook_ == nullptr) {
414           bucket_view.div_(div_factor_);
415         }
416       }
417     } else {
418       // Gradient is undefined. When find_unused_parameters=True, ensure it is
419       // not marked as locally used, otherwise we will be allreducing zero's
420       // instead of not touching .grad field of parameter.
421       if (this->dynamic_graph_find_unused() ||
422           this->static_graph_first_iteration()) {
423         REDUCER_CHECK(
424             local_used_map_[variable_index].item<int>() == 0,
425             logger_,
426             "Encountered gradient which is undefined, but still allreduced by "
427             "DDP reducer. This indicates a bug in DDP implementation, please "
428             "report a bug with a repro to PyTorch.");
429       }
430       bucket_view.zero_();
431     }
432     // The grad is not modified and doesn't need to be written back.
433     return false;
434   });
435 }
436 
mark_variable_ready_sparse(size_t variable_index)437 void Reducer::mark_variable_ready_sparse(size_t variable_index) {
438   const auto& bucket_index = variable_locators_[variable_index];
439   auto& bucket = buckets_[bucket_index.bucket_index];
440   auto& variable = bucket.variables[bucket_index.intra_bucket_index];
441 
442   runGradCallbackForVariable(variable, [&](auto& grad) {
443     REDUCER_CHECK(
444         grad.defined(), logger_, "Expected sparse gradient to be defined.");
445     REDUCER_CHECK(
446         grad.options().layout() == c10::kSparse,
447         logger_,
448         "Expected variable to have sparse gradient.");
449 
450     // Copy the indices of sparse metadata
451     if (sparse_metadata_) {
452       grad = grad.coalesce();
453       REDUCER_CHECK(
454           !param_names_.empty(), logger_, "No parameter names were found");
455       std::string& param_name = param_names_[variable_index];
456       auto iter = sparse_metadata_->find(param_name);
457       REDUCER_CHECK(
458           iter != sparse_metadata_->end(),
459           logger_,
460           "param: " + param_name + " not found in sparse metadata");
461       bucket.sparse_tensor_indices =
462           iter->second.to(at::kLong).unsqueeze(0).to(grad.device());
463       auto indices = at::searchsorted(
464           bucket.sparse_tensor_indices.value(), grad.indices(), false, false);
465       // For indices we are using the ones set by sparse_metadata
466       grad = at::sparse_coo_tensor(indices, grad.values(), grad.sizes());
467     }
468 
469     // Sparse tensors cannot be grouped together with other sparse tensors in a
470     // single reduction operation like we can for dense tensors. Therefore, the
471     // `offsets` and `lengths` vectors in the bucket struct are empty, and
472     // there is no pre-existing accumulation tensor.
473     // Directly assign the sparse tensor to the `gradients` field.
474     bucket.gradients = grad;
475     // If no DDP comm hook is registered, the allreduce only sums up the
476     // value, and a separate division is required.
477     if (comm_hook_ == nullptr) {
478       bucket.gradients.div_(div_factor_);
479     }
480     // The grad is modified in place and needs to be written back.
481     return true;
482   });
483 }
484 
get_grad_buckets(bool return_zero_tensors) const485 std::vector<c10d::GradBucket> Reducer::get_grad_buckets(
486     bool return_zero_tensors) const {
487   std::lock_guard<std::mutex> lock(mutex_);
488   std::vector<c10d::GradBucket> gradBuckets;
489   gradBuckets.reserve(buckets_.size());
490   for (const auto i : c10::irange(buckets_.size())) {
491     auto& bucket = buckets_[i];
492     auto variables_for_bucket = get_variables_for_bucket(i, bucket);
493     gradBuckets.emplace_back(
494         i,
495         buckets_.size(),
496         return_zero_tensors ? at::zeros_like(bucket.gradients)
497                             : bucket.gradients,
498         bucket.offsets,
499         bucket.lengths,
500         bucket.sizes_vec,
501         variables_for_bucket,
502         std::nullopt);
503   }
504   return gradBuckets;
505 }
506 
set_forward_pass_work_handle(c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,bool useStaticWorldSize)507 void Reducer::set_forward_pass_work_handle(
508     c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,
509     bool useStaticWorldSize) {
510   std::lock_guard<std::mutex> lock(mutex_);
511   forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle);
512   forwardPassWorkHandle_.useStaticWorldSize = useStaticWorldSize;
513 }
514 
get_local_used_map_on_device() const515 at::Tensor Reducer::get_local_used_map_on_device() const {
516   std::lock_guard<std::mutex> lock(mutex_);
517   return local_used_map_dev_;
518 }
519 
push_rebuilt_params_for_all_indices()520 void Reducer::push_rebuilt_params_for_all_indices() {
521   std::lock_guard<std::mutex> lock(mutex_);
522   if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
523     return;
524   }
525   const auto variable_count = params_.size();
526   for (const auto variable_index : c10::irange(variable_count)) {
527     push_rebuilt_params(variable_index);
528   }
529 }
530 
push_rebuilt_params(const size_t & index)531 void Reducer::push_rebuilt_params(const size_t& index) {
532   rebuilt_params_.push_back(params_[index]);
533   rebuilt_param_indices_.push_back(static_cast<int64_t>(index));
534 }
535 
set_divide_factor()536 void Reducer::set_divide_factor() {
537   // If it was scheduled, wait on allreduce in forward pass that tells us
538   // division factor based on no. of currently participating processes.
539   if (div_factor_ == kUnsetDivFactor) {
540     div_factor_ = process_group_->getSize();
541     auto& workHandle = forwardPassWorkHandle_.workHandle;
542     if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) {
543       workHandle->wait();
544       // PyProcessGroup::PyWork doesn't expose value, so fetch it from the
545       // future
546       auto results = extractTensors(workHandle->getFuture()->value());
547 
548       // Guard against the results being empty
549       TORCH_INTERNAL_ASSERT(!results.empty());
550       at::Tensor& res = results.front();
551       div_factor_ = res.item().to<int>();
552     }
553   }
554 }
555 
556 // This is called before training and converts the gradients to the dtype they
557 // should be reduced in.
set_mixed_precision_param_dtype(c10::ScalarType dtype)558 void Reducer::set_mixed_precision_param_dtype(c10::ScalarType dtype) {
559   mixed_precision_param_dtype_ = dtype;
560   for (auto& bucket : buckets_) {
561     bucket.gradients = bucket.gradients.to(dtype);
562   }
563 }
564 
565 // Right now delay_all_reduce is only called when static_graph_=true and
566 // num_iterations_==1.
delay_all_reduce()567 void Reducer::delay_all_reduce() {
568   std::lock_guard<std::mutex> lock(this->mutex_);
569 
570   if (should_collect_runtime_stats()) {
571     record_backward_compute_end_time();
572     record_backward_comm_start_time();
573   }
574 
575   // launch all reduce local used map
576   all_reduce_local_used_map();
577 
578   // prepare to set unused_parameters_, if it is static graph,
579   // unused_parameters_ will not change after 1st iteration.
580   unused_parameters_.clear();
581 
582   require_finalize_ = true;
583   // copy all gradients to buckets
584   for (const auto variable_index : c10::irange(params_.size())) {
585     // set unused_parameters_
586     if (numGradHooksTriggeredMap_[variable_index] == 0) {
587       unused_parameters_.push_back(variable_index);
588     }
589     set_divide_factor();
590     if (expect_sparse_gradients_[variable_index]) {
591       mark_variable_ready_sparse(variable_index);
592     } else {
593       mark_variable_ready_dense(variable_index);
594     }
595   }
596 
597   // To avoid confusion around why static graph is picking up
598   // some parameters as unused on a rank vs not, we log
599   // unused parameter names for each rank for better
600   // debugability when TORCH_DISTRIBUTED_DEBUG is set to
601   // INFO or DETAIL
602   if (ddp_debug_level_ != c10d::DebugLevel::Off) {
603     // construct one string to output
604     std::ostringstream unused_params_stream;
605 
606     for (const auto& unused_index : unused_parameters_) {
607       auto param_name = param_names_.find(unused_index);
608       TORCH_INTERNAL_ASSERT(
609           param_name != param_names_.end(),
610           "Expected to find parameter name from unused parameters map in debug mode.");
611       // Add the param_name
612       unused_params_stream << "{" << param_name->second << "," << unused_index
613                            << "}";
614     }
615 
616     // Each rank prints out all the unused parameters detected
617     if (!unused_parameters_.empty()) {
618       LOG(INFO) << "[Rank " << process_group_->getRank() << "]: "
619                 << "Parameter(s) (in the format of {param_name, index}): "
620                 << unused_params_stream.str()
621                 << " is(are) unused during first iteration. Since"
622                 << " static_graph=True is enabled for DDP, we expect"
623                 << " this set of unused parameters to remain consistent"
624                 << " on this rank throughout the training.";
625     }
626   }
627 
628   // launch all reduces for all buckets
629   for (auto& bucket : buckets_) {
630     all_reduce_bucket(bucket);
631   }
632 
633   finalize_backward();
634 }
635 
set_logger(std::weak_ptr<c10d::Logger> logger)636 void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) {
637   logger_ = std::move(logger);
638 }
639 
640 // The function `autograd_hook` is called after the gradient for a
641 // model parameter has been accumulated into its gradient tensor.
642 // This function is only to be called from the autograd thread.
autograd_hook(size_t index)643 void Reducer::autograd_hook(size_t index) {
644   std::lock_guard<std::mutex> lock(this->mutex_);
645   if (!first_autograd_hook_called_) {
646     first_autograd_hook_called_ = true;
647     num_bwd_calls_++;
648   }
649 
650   // See Note [Skip allreducing local_used_map_dev]
651   if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
652     // Since it gets here, this param has been used for this iteration. We want
653     // to mark it in local_used_map_. During no_sync session, the same var can
654     // be set multiple times, which is OK as does not affect correctness. As
655     // long as it is used once during no_sync session, it is marked as used.
656     // Only set it as locally used if the grad is defined. Otherwise, hooks can
657     // be fired  with undefined grads, such as when not all outputs are used in
658     // DDP when computing loss. In this case, we don't want to mark it as
659     // locally used to ensure we don't touch the parameter's .grad field.
660     auto& variable = get_param_from_index(index);
661     runGradCallbackForVariable(variable, [&](auto& grad) {
662       if (grad.defined()) {
663         local_used_map_[static_cast<int64_t>(index)] = 1;
664       }
665       // The gradient is never modified.
666       return false;
667     });
668   }
669 
670   if (static_graph_first_iteration()) {
671     numGradHooksTriggeredMap_[index] += 1;
672     return;
673   }
674 
675   // Ignore if we don't expect to be called.
676   // This may be the case if the user wants to accumulate gradients
677   // for number of iterations before reducing them.
678   if (!expect_autograd_hooks_) {
679     return;
680   }
681 
682   grad_ready_order_indices_.push_back(static_cast<int64_t>(index));
683 
684   // If `find_unused_parameters_` is true there may be model parameters that
685   // went unused when computing the model output, they won't be part of the
686   // autograd graph, and won't receive gradients. These parameters are
687   // discovered in the `prepare_for_backward` function and their indexes stored
688   // in the `unused_parameters_` vector.
689   if (!has_marked_unused_parameters_) {
690     has_marked_unused_parameters_ = true;
691     for (const auto& unused_index : unused_parameters_) {
692       mark_variable_ready(unused_index);
693     }
694   }
695 
696   // Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
697   // static_graph_ is true or find_unused_parameters_ is false,
698   // 3) this backward pass needs to run allreduce.
699   // Here, we just dump tensors and their parameter indices into
700   // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
701   // order, and then at the end of finalize_backward(), buckets will be
702   // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
703   // will be broadcasted and initialized.
704   // If it is static graph, after 1st iteration, check if a variable
705   // is ready for communication based on numGradHooksTriggeredMap_.
706   if (static_graph_after_first_iteration()) {
707     REDUCER_CHECK(
708         numGradHooksTriggeredMapPerIteration_[index] > 0,
709         logger_,
710         "Your training graph has changed in this iteration, ",
711         "e.g., one parameter is unused in first iteration, but ",
712         "then got used in the second iteration. this is not ",
713         "compatible with static_graph set to True.");
714     if (--numGradHooksTriggeredMapPerIteration_[index] == 0) {
715       if (should_rebuild_buckets()) {
716         push_rebuilt_params(index);
717       }
718       // Finally mark variable for which this function was originally called.
719       mark_variable_ready(index);
720     }
721   } else {
722     if (should_rebuild_buckets()) {
723       push_rebuilt_params(index);
724     }
725     // Finally mark variable for which this function was originally called.
726     mark_variable_ready(index);
727   }
728 }
729 
all_reduce_local_used_map()730 void Reducer::all_reduce_local_used_map() {
731   // See Note [Skip allreducing local_used_map_dev]
732   // H2D from local_used_map_ to local_used_map_dev_
733   if (local_used_map_dev_.is_cuda() || local_used_map_dev_.is_privateuseone()) {
734     // Note [local_used_map_ -> local_used_map_dev copying]
735     // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
736     // We do async H2D to avoid the blocking overhead. The async copy and
737     // allreduce respect the current stream, so will be sequenced
738     // correctly.
739     //
740     // Correct sequencing with respect to host operations is also
741     // essential. The H2D copy_ is stream ordered, while the host's
742     // changes to local_used_map_ are host ordered. If a large backlog of
743     // cuda/privateuseone-stream work pushes the copy_ far into the future, and
744     // if no blocking calls occur between now and finalize_backward()** such
745     // that finalize_backward() re-zeroes local_used_map_ on the host
746     // before the stream executes the copy_, copy_ will read those zeros
747     // instead of the values we thought we told it to read here. Copying
748     // local_used_map_ to a pinned temporary (which the pinned caching
749     // allocator should supply asynchronously) avoids this nasty, rare
750     // race condition.
751     //
752     // ** In the hoped-for case where all params are used, DDP itself
753     // won't do any blocking work between now and the re-zeroing, so the
754     // danger is real.
755     //
756     // Defensively ensures local_used_map_tmp is distinct from
757     // local_used_map_
758     auto local_used_map_tmp = at::native::empty_like(
759         local_used_map_,
760         c10::optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
761         local_used_map_.options().layout_opt(),
762         local_used_map_.options().device_opt(),
763         true /* pinned_memory */);
764     // Paranoid asserts here because in some workloads, the pinned
765     // allocator behaves in a way we don't understand, and may be bugged.
766     // See https://github.com/pytorch/pytorch/pull/54474
767     TORCH_INTERNAL_ASSERT(local_used_map_tmp.is_pinned());
768     TORCH_INTERNAL_ASSERT(
769         local_used_map_tmp.data_ptr() != local_used_map_.data_ptr());
770     local_used_map_tmp.copy_(local_used_map_);
771     local_used_map_dev_.copy_(local_used_map_tmp, true);
772   } else if (local_used_map_dev_.is_mtia()) {
773     // MTIA probably will have special logic in the future, following code might
774     // be changed drastically. Therefore, a new if case is created for MTIA, for
775     // now, the implementation is similar to the CUDA/privateuseone one, except
776     // for the pin memory step.
777     auto local_used_map_tmp = at::native::empty_like(
778         local_used_map_,
779         c10::optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
780         local_used_map_.options().layout_opt(),
781         local_used_map_.options().device_opt());
782     local_used_map_tmp.copy_(local_used_map_);
783     local_used_map_dev_.copy_(local_used_map_tmp, true);
784   } else {
785     local_used_map_dev_.copy_(local_used_map_, true);
786   }
787   std::vector<at::Tensor> temp_local_used_map_dev_vec_ = {local_used_map_dev_};
788   local_used_work_ = process_group_->allreduce(temp_local_used_map_dev_vec_);
789 }
790 
get_param_from_index(size_t index)791 at::Tensor& Reducer::get_param_from_index(size_t index) {
792   const auto& bucket_index = variable_locators_[index];
793   auto& bucket = buckets_[bucket_index.bucket_index];
794   // Cannot simply access variable via `bucket.variables[variable_index]` since
795   // return value is used in `runGradCallbackForVariable()` which does not
796   // accept const tensors.
797   auto& variable = bucket.variables[bucket_index.intra_bucket_index];
798   return variable;
799 }
800 
checkAndRaiseMarkedTwiceError(size_t index)801 void Reducer::checkAndRaiseMarkedTwiceError(size_t index) {
802   // Something is wrong if all variables contained in this bucket have
803   // already been marked as ready.
804   // We don't expect the same variable to be marked ready twice.
805   bool marked_twice =
806       perIterationReadyParams_.find(index) != perIterationReadyParams_.end();
807 
808   if (marked_twice) {
809     // Report index of param that has been marked twice. In debug mode, also
810     // report fully qualified parameter name.
811     auto param_name = param_names_.find(index);
812     const bool found_param_name = param_name != param_names_.end();
813     TORCH_INTERNAL_ASSERT(
814         ddp_debug_level_ == c10d::DebugLevel::Off || found_param_name,
815         "Expected to find parameter name in debug mode.");
816     std::string paramInfo = c10::str(
817         "Parameter at index ",
818         index,
819         found_param_name ? c10::str(" with name ", param_name->second) : "",
820         " has been marked as ready twice. This means that multiple autograd engine ",
821         " hooks have fired for this particular parameter during this iteration.");
822     // param_names_ is empty in debug mode.
823     if (!found_param_name) {
824       paramInfo += c10::str(
825           " You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either",
826           " INFO or DETAIL to print parameter names for further debugging.");
827     }
828     std::string common_error = c10::str(
829         "Expected to mark a variable ready only once. ",
830         "",
831         "This error is caused by one of the following reasons: ",
832         "1) Use of a module parameter outside the `forward` function. ",
833         "Please make sure model parameters are not shared across multiple ",
834         "concurrent forward-backward passes. or try to use _set_static_graph() ",
835         "as a workaround if this module graph does not change ",
836         "during training loop.",
837         "2) Reused parameters in multiple reentrant backward passes. For ",
838         "example, if you use multiple `checkpoint` functions to wrap the ",
839         "same part of your model, it would result in the same set of ",
840         "parameters been used by different reentrant backward passes ",
841         "multiple times, and hence marking a variable ready multiple times. ",
842         "DDP does not support such use cases in default. You can try to ",
843         "use _set_static_graph() as a workaround if your module graph ",
844         "does not change over iterations.");
845 
846     common_error += c10::str("\n", paramInfo);
847 
848     REDUCER_CHECK(
849         has_marked_unused_parameters_,
850         logger_,
851         common_error,
852         "3) Incorrect unused parameter detection. The return value of the ",
853         "`forward` function is inspected by the distributed data parallel ",
854         "wrapper to figure out if any of the module's parameters went ",
855         "unused. For unused parameters, DDP would not expect gradients from ",
856         "then. However, if an unused parameter becomes part of the autograd ",
857         "graph at a later point in time (e.g., in a reentrant backward when ",
858         "using `checkpoint`), the gradient will show up unexpectedly. If all ",
859         "parameters in the model participate in the backward pass, you can ",
860         "disable unused parameter detection by passing the keyword argument ",
861         "`find_unused_parameters=False` to ",
862         "`torch.nn.parallel.DistributedDataParallel`. If unused parameters ",
863         "in the model do not change over iterations, You can try to use ",
864         "_set_static_graph() as a workaround if this module graph does not ",
865         "change during training loop.");
866     REDUCER_CHECK(!has_marked_unused_parameters_, logger_, common_error);
867   }
868 }
869 
mark_variable_ready(size_t variable_index)870 void Reducer::mark_variable_ready(size_t variable_index) {
871   REDUCER_CHECK(
872       variable_index < variable_locators_.size(),
873       logger_,
874       "Out of range variable index.");
875 
876   checkAndRaiseMarkedTwiceError(variable_index);
877   perIterationReadyParams_.insert(variable_index);
878   backward_stats_[variable_index] =
879       current_time_in_nanos() - backward_compute_start_time_;
880 
881   // Any time we mark a variable ready (be it in line due to unused parameters,
882   // or via an autograd hook), we require a call to the finalize function. If
883   // this doesn't happen before the next iteration (or call to
884   // `prepare_for_backwards`), we know something is wrong.
885   require_finalize_ = true;
886 
887   const auto& bucket_index = variable_locators_[variable_index];
888   auto& bucket = buckets_[bucket_index.bucket_index];
889 
890   set_divide_factor();
891 
892   if (bucket.expect_sparse_gradient) {
893     mark_variable_ready_sparse(variable_index);
894   } else {
895     mark_variable_ready_dense(variable_index);
896   }
897 
898   // TODO(@pietern): Make this work for both CPU/CUDA tensors.
899   // When using CPU tensors we don't need to do this.
900   // Record event so that we can wait for all of them.
901   // auto& event = bucket.events[bucket_index.intra_bucket_index];
902   // event.record();
903 
904   // Check if this was the final gradient for this bucket.
905   if (--bucket.pending == 0) {
906     mark_bucket_ready(bucket_index.bucket_index);
907   }
908 
909   // Run finalizer function and kick off reduction for local_used_map once the
910   // final bucket was marked ready.
911   if (next_bucket_ == buckets_.size()) {
912     if (dynamic_graph_find_unused()) {
913       all_reduce_local_used_map();
914     }
915 
916     torch::autograd::Engine::get_default_engine().queue_callback([this] {
917       std::lock_guard<std::mutex> lock(this->mutex_);
918       if (should_collect_runtime_stats()) {
919         record_backward_compute_end_time();
920       }
921       // Check that all buckets were completed and had their work kicked off.
922       TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
923       if (static_graph_after_first_iteration() && should_rebuild_buckets()) {
924         for (const auto& unused_index : unused_parameters_) {
925           push_rebuilt_params(unused_index);
926         }
927       }
928       this->finalize_backward();
929     });
930   }
931 }
932 
run_comm_hook(GradBucket & grad_bucket)933 c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_comm_hook(
934     GradBucket& grad_bucket) {
935   if (comm_hook_ == nullptr) {
936     return run_allreduce_hook(grad_bucket);
937   } else {
938     return comm_hook_->runHook(grad_bucket);
939   }
940 }
941 
run_allreduce_hook(GradBucket & grad_bucket)942 c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_allreduce_hook(
943     GradBucket& grad_bucket) {
944   _AllReduceBySumCommHook allreduce_hook(process_group_);
945   return allreduce_hook.runHook(grad_bucket);
946 }
947 
all_reduce_bucket(Bucket & bucket)948 void Reducer::all_reduce_bucket(Bucket& bucket) {
949   auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket);
950   // TODO(@pietern): Ensure proper synchronization with the CUDA events
951   // that recorded copies into this `gradients` tensor. If these copies are
952   // executed on non-default streams, the current stream for the device
953   // that holds the `gradients` tensor must wait on these events.
954   //
955   // As long as autograd uses the default stream for every device,
956   // these operations are implicitly sequenced, and we don't need to
957   // do any extra synchronization here.
958   const auto& tensor = bucket.gradients;
959 
960   // TODO(@egienvalue): remove special case after view ops are fully
961   // supported on MTIA.
962   // If the bucket.gradients is on MTIA, bucket.bucket_views_in might not
963   // point to the same storage as bucket.gradients due to the special
964   // memory layout. It has to explicitly copy the data back to 1-D gradients.
965   if (tensor.is_mtia()) {
966     for (const auto i : c10::irange(bucket.variables.size())) {
967       const auto offset = bucket.offsets[i];
968       const auto length = bucket.lengths[i];
969       if (!bucket.bucket_views_in[i].is_alias_of(tensor)) {
970         tensor
971             .narrow(
972                 0, static_cast<int64_t>(offset), static_cast<int64_t>(length))
973             .copy_(bucket.bucket_views_in[i].flatten());
974       }
975     }
976   }
977 
978   GradBucket grad_bucket(
979       next_bucket_,
980       buckets_.size(),
981       tensor,
982       bucket.offsets,
983       bucket.lengths,
984       bucket.sizes_vec,
985       variables_for_bucket,
986       bucket.sparse_tensor_indices);
987   bucket.future_work = run_comm_hook(grad_bucket);
988 }
989 
get_variables_for_bucket(size_t bucket_index,const Bucket & bucket) const990 std::vector<at::Tensor> Reducer::get_variables_for_bucket(
991     size_t bucket_index,
992     const Bucket& bucket) const {
993   // Check if we have cached mapping previously.
994   if (has_rebuilt_bucket_ &&
995       cached_variables_for_bucket_.find(bucket_index) !=
996           cached_variables_for_bucket_.end()) {
997     return cached_variables_for_bucket_[bucket_index];
998   }
999   std::vector<at::Tensor> variables_for_bucket;
1000   variables_for_bucket.reserve(bucket.variable_indices.size());
1001   for (const auto& variable_index : bucket.variable_indices) {
1002     // Grab bucket index where gradient is located using variable_locators_.
1003     auto& bucket_index_for_variable = variable_locators_[variable_index];
1004     // Grab the actual model parameter.
1005     auto& variable =
1006         bucket.variables[bucket_index_for_variable.intra_bucket_index];
1007     variables_for_bucket.emplace_back(variable);
1008   }
1009 
1010   if (has_rebuilt_bucket_) {
1011     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1012         cached_variables_for_bucket_.find(bucket_index) ==
1013         cached_variables_for_bucket_.end());
1014     cached_variables_for_bucket_.insert(
1015         {bucket_index, std::move(variables_for_bucket)});
1016     return cached_variables_for_bucket_[bucket_index];
1017   } else {
1018     return variables_for_bucket;
1019   }
1020 }
1021 
1022 // Called when the bucket at the specified index is ready to be reduced.
mark_bucket_ready(size_t bucket_index)1023 void Reducer::mark_bucket_ready(size_t bucket_index) {
1024   TORCH_INTERNAL_ASSERT(bucket_index >= next_bucket_);
1025 
1026   // Buckets are reduced in sequence. Ignore this bucket if
1027   // it's not its turn to be reduced.
1028   if (bucket_index > next_bucket_) {
1029     return;
1030   }
1031 
1032   // Keep going, until we either:
1033   // - have kicked off reduction for all buckets, or
1034   // - found a bucket that's not yet ready for reduction.
1035   for (; next_bucket_ < buckets_.size() && buckets_[next_bucket_].pending == 0;
1036        next_bucket_++) {
1037     num_buckets_ready_++;
1038     if (num_buckets_ready_ == 1 && should_collect_runtime_stats()) {
1039       record_backward_comm_start_time();
1040     }
1041     auto& bucket = buckets_[next_bucket_];
1042     all_reduce_bucket(bucket);
1043   }
1044 }
1045 
install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs)1046 void Reducer::install_futures(
1047     c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs) {
1048   // Append instead of overwrite so that this method can be called multiple
1049   // times in one iteration.
1050   if (!installed_futures_) {
1051     installed_futures_ = std::move(futs);
1052   } else {
1053     installed_futures_->append(futs);
1054   }
1055 }
1056 
initialize_buckets(std::vector<std::vector<size_t>> bucket_indices)1057 void Reducer::initialize_buckets(
1058     std::vector<std::vector<size_t>> bucket_indices) {
1059   // If initialize_buckets is called inside DDP constructor, then
1060   // it does not matter rpc context ptr is nullptr or not, as grad
1061   // will not be mutated.
1062   // If initialize_buckets is called during training loop, e.g, inside
1063   // rebuild_buckets(), since grad could be mutated and be pointed to
1064   // bucket_view, then it needs to check rpc context ptr is nullptr or not,
1065   // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
1066   // mutate grad in rpc context.
1067 #ifndef _WIN32
1068   using torch::distributed::autograd::ThreadLocalDistAutogradContext;
1069   this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
1070 #endif
1071 
1072   // This shouldn't be called if we're expecting autograd hooks to fire.
1073   REDUCER_CHECK(
1074       !expect_autograd_hooks_,
1075       logger_,
1076       "`initialize_buckets` must NOT be called during autograd execution.");
1077 
1078   // Clear current bucket assignment.
1079   buckets_.clear();
1080   variable_locators_.clear();
1081 
1082   // Ensure we have a bucket index for every variable.
1083   variable_locators_.resize(params_.size());
1084 
1085   // Iterate over buckets.
1086   const auto bucket_count = bucket_indices.size();
1087   buckets_.reserve(bucket_count);
1088   for (const auto bucket_index : c10::irange(bucket_count)) {
1089     Bucket bucket;
1090 
1091     // TODO(@pietern): Validate indices.
1092     // Must be non-empty, unique, and unique across buckets.
1093     REDUCER_CHECK(
1094         !bucket_indices[bucket_index].empty(),
1095         logger_,
1096         "Empty bucket specified.");
1097 
1098     // Variables that expect sparse gradients must have their own bucket.
1099     if (bucket_indices[bucket_index].size() == 1) {
1100       const auto variable_index = bucket_indices[bucket_index].front();
1101       bucket.expect_sparse_gradient = expect_sparse_gradients_[variable_index];
1102     } else {
1103       for (const auto variable_index : bucket_indices[bucket_index]) {
1104         REDUCER_CHECK(
1105             !expect_sparse_gradients_[variable_index],
1106             logger_,
1107             "Buckets with more than one variable cannot include variables ",
1108             "that expect a sparse gradient.");
1109       }
1110     }
1111 
1112     if (bucket.expect_sparse_gradient) {
1113       const auto variable_index = bucket_indices[bucket_index].front();
1114       const auto& variable = params_[variable_index];
1115       TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
1116       bucket.variables = {variable};
1117     } else {
1118       at::TensorOptions options;
1119       // The start index of the variable in the flattened tensor.
1120       size_t offset = 0;
1121 
1122       // Reserve enough space for the per-variable fields stored in the bucket
1123       // for efficiency.
1124       const size_t num_variables = bucket_indices[bucket_index].size();
1125       bucket.variables.reserve(num_variables);
1126       bucket.offsets.reserve(num_variables);
1127       bucket.lengths.reserve(num_variables);
1128       bucket.sizes_vec.reserve(num_variables);
1129 
1130       // Iterate over bucket variables.
1131       for (const auto variable_index : bucket_indices[bucket_index]) {
1132         TORCH_INTERNAL_ASSERT(
1133             variable_index < params_.size(),
1134             "Out of range variable index specified.");
1135         const auto& variable = params_[variable_index];
1136         if (!options.has_device()) {
1137           options = options.device(variable.device());
1138         } else {
1139           REDUCER_CHECK(
1140               variable.device() == options.device(),
1141               logger_,
1142               "All parameters in a bucket must be ",
1143               "placed on the same device.");
1144         }
1145         if (!options.has_dtype()) {
1146           options = options.dtype(variable.dtype());
1147         } else {
1148           REDUCER_CHECK(
1149               variable.dtype() == options.dtype(),
1150               logger_,
1151               "All parameters in a bucket must have the same dtype.");
1152         }
1153         const auto length = variable.numel();
1154         bucket.variables.push_back(variable);
1155         bucket.offsets.push_back(offset);
1156         bucket.lengths.push_back(length);
1157         bucket.sizes_vec.push_back(variable.sizes());
1158         offset += length;
1159       }
1160 
1161       // Allocate the bucket's flattened `gradients` tensor.
1162       // Make gradient type in the reduced precision if mixed precision is
1163       // enabled. This ensures that the type is correct when e.g. rebuilding
1164       // buckets.
1165       if (mixed_precision_param_dtype_) {
1166         options = options.dtype(*mixed_precision_param_dtype_);
1167       }
1168       bucket.gradients = at::empty({static_cast<long>(offset)}, options);
1169 
1170       // Note:  "Gradient Layout Contract"
1171       //
1172       // Here, create views into the `gradients` tensor for each variable's
1173       // grad. Views serve as entry points to `copy_()` each grad's data in/out
1174       // of the flattened `gradients` tensor.
1175       //
1176       // Gradients may have dense memory but non-row-major-contiguous strides
1177       // (e.g. channels_last or channels_last_3d). For coalesced accesses
1178       // during copy_s, it's beneficial for each view's layout to match its
1179       // grad's layout.
1180       //
1181       // Specifically, we expect torch/csrc/autograd/functions/accumulate_grad.h
1182       // produces grads that obey the "Gradient Layout Contract":
1183       //   (1) if variable.is_non_overlapping_and_dense(), the stashed grad's
1184       //       strides match variable.
1185       //   (2) else, stashed grad is rowmajor contiguous.
1186       // and create views to match.
1187       //
1188       // If AccumulateGrad breaks the contract, and produces a grad with an
1189       // unexpected layout, performance will degrade due to poor memory access
1190       // patterns when copy_ing grad data in and out of its bucket view.
1191       // However, numerics remain correct, because the bucket view is the same
1192       // on either end of the raw allreduce.  bucket_view_in.copy(grad)
1193       // tranposes
1194       // (+ densifies) to the bucket view's layout, the data is allreduced,
1195       // then grad.copy_(bucket_view_out) transposes it back to grad's layout.
1196       //
1197       // The only way the numerics can go haywire is if the bucket views
1198       // themselves have different layouts across processes.
1199       // Bucket views' sizes and strides are set based on param layouts, using
1200       // the same logic that (we expect) AccumulateGrad uses for their grads.
1201       // Therefore, the only way a bucket view could have different layouts in
1202       // different processes is if its param has a different layout in
1203       // different processes. We can check that param layouts match across
1204       // processes in Reducer's constructor by allreducing some metadata.
1205       // Checking just once won't catch if someone messes with
1206       // param layouts over time, but not messing with params after DDP
1207       // construction is already a documented constraint.
1208       initialize_bucket_views(bucket);
1209     }
1210 
1211     // Map participating variables to this bucket.
1212     size_t intra_bucket_index = 0;
1213     for (const auto variable_index : bucket_indices[bucket_index]) {
1214       TORCH_INTERNAL_ASSERT(
1215           variable_index < variable_locators_.size(),
1216           "Out of range variable index specified.");
1217       variable_locators_[variable_index] =
1218           VariableLocator(bucket_index, intra_bucket_index++);
1219     }
1220     bucket.variable_indices = std::move(bucket_indices[bucket_index]);
1221 
1222     buckets_.push_back(std::move(bucket));
1223   }
1224 }
1225 
1226 // (see Note:  "Gradient Layout Contract" in initialize_buckets).
initialize_bucket_views(Reducer::Bucket & bucket)1227 void Reducer::initialize_bucket_views(Reducer::Bucket& bucket) {
1228   const auto& gradients = bucket.gradients;
1229   for (const auto i : c10::irange(bucket.variables.size())) {
1230     auto& v = bucket.variables[i];
1231     const auto offset = bucket.offsets[i];
1232     const auto length = bucket.lengths[i];
1233     // TODO(@egienvalue): remove special case after view ops are fully
1234     // supported on MTIA.
1235     // In general, on MTIA, due to the special memory layout, it doesn't
1236     // support as_strided which creates a view tensor and aten::view will
1237     // create a new tensor on MTIA for now.
1238     if (v.is_non_overlapping_and_dense() && !v.is_mtia()) {
1239       // If the param's memory is dense, match its layout, anticipating
1240       // the autograd engine (AccumulateGrad) will also create gradients
1241       // matching its layout.
1242       bucket.bucket_views_in.push_back(
1243           gradients.as_strided(v.sizes(), v.strides(), offset));
1244     } else {
1245       // Fall back to a C-style contiguous view, again anticipating
1246       // AccumulateGrad will do the same when stashing grads for non-dense
1247       // params.
1248       bucket.bucket_views_in.push_back(
1249           gradients
1250               .narrow(
1251                   0, static_cast<int64_t>(offset), static_cast<int64_t>(length))
1252               .view(v.sizes()));
1253     }
1254     // By default `bucket_views_out` and `bucket_views_in` are
1255     // essentially the same thing.
1256     bucket.bucket_views_out = bucket.bucket_views_in;
1257 
1258     // If gradient_as_bucket_view_ is set as true, then there are two cases to
1259     // handle: initialize_bucket_views could be called inside initialize_buckets
1260     // when rebuild_buckets, if grad has already been defined/calculated in
1261     // previous iteration, old grad needs to be copied into new bucket_view and
1262     // let grad point to the new bucket_view, initialize_bucket_views could also
1263     // be called inside initialize_buckets during construction. Grads are not
1264     // defined during construction time, in this case, do not let grad point to
1265     // bucket_view, because grads should be kept as being undefined for globally
1266     // unused parameters.
1267     if (gradient_as_bucket_view_) {
1268       auto& bucket_view = bucket.bucket_views_in.back();
1269       runGradCallbackForVariable(v, [&](auto& grad) {
1270         if (grad.defined() && !grad.is_alias_of(bucket_view)) {
1271           bucket_view.copy_(grad);
1272           grad = bucket_view;
1273           // The grad is modified and needs to be written back.
1274           return true;
1275         }
1276         // The grad is not modified and does not need to be written back.
1277         return false;
1278       });
1279     }
1280   }
1281 }
1282 
1283 // (see Note:  "Gradient Layout Contract" in initialize_buckets).
populate_bucket_views_out(Reducer::Bucket & bucket,at::Tensor & tensor)1284 void Reducer::populate_bucket_views_out(
1285     Reducer::Bucket& bucket,
1286     at::Tensor& tensor) {
1287   bucket.bucket_views_out.clear();
1288   for (const auto i : c10::irange(bucket.variables.size())) {
1289     const auto& v = bucket.variables[i];
1290     const auto offset = bucket.offsets[i];
1291     const auto length = bucket.lengths[i];
1292     // TODO(@egienvalue): remove special case after view ops are fully
1293     // supported on MTIA.
1294     // In general, on MTIA, due to the special memory layout, it doesn't
1295     // support as_strided which creates a view tensor and aten::view will
1296     // create a new tensor on MTIA for now.
1297     if (v.is_non_overlapping_and_dense() && !v.is_mtia()) {
1298       // If the param's memory is dense, match its layout, anticipating
1299       // the autograd engine (AccumulateGrad) will also create gradients
1300       // matching its layout.
1301       bucket.bucket_views_out.push_back(
1302           tensor.as_strided(v.sizes(), v.strides(), offset));
1303     } else {
1304       // Fall back to a C-style contiguous view, again anticipating
1305       // AccumulateGrad will do the same when stashing grads for non-dense
1306       // params.
1307       bucket.bucket_views_out.push_back(
1308           tensor
1309               .narrow(
1310                   0, static_cast<int64_t>(offset), static_cast<int64_t>(length))
1311               .view(v.sizes()));
1312     }
1313   }
1314 }
1315 
prepare_for_forward()1316 void Reducer::prepare_for_forward() {
1317   std::lock_guard<std::mutex> lock(mutex_);
1318   num_iterations_++;
1319   if (should_collect_runtime_stats()) {
1320     record_forward_compute_start_time();
1321   }
1322 }
1323 
reset_bucket_counting()1324 void Reducer::reset_bucket_counting() {
1325   next_bucket_ = 0;
1326   // Reset num_buckets_ready_ at the beginning of backward computation
1327   // in each iteration.
1328   num_buckets_ready_ = 0;
1329 
1330   for (auto& bucket : buckets_) {
1331     bucket.pending = bucket.variables.size();
1332   }
1333 
1334   if (static_graph_) {
1335     numGradHooksTriggeredMapPerIteration_ = numGradHooksTriggeredMap_;
1336   }
1337 }
1338 
1339 // Traverse the autograd graph starting at the specified output.
1340 // All parameters for which we have a pointer to their gradient accumulation
1341 // functions, but don't show up in the autograd graph will be marked ready for
1342 // for reduction as soon as the first autograd hook is called. This is not
1343 // done immediately because the model output may be ignored, and we only
1344 // want to start performing reductions on `torch.autograd.backward()`.
search_unused_parameters(const std::vector<torch::autograd::Variable> & outputs)1345 void Reducer::search_unused_parameters(
1346     const std::vector<torch::autograd::Variable>& outputs) {
1347   std::unordered_set<torch::autograd::Node*> seen;
1348   std::vector<torch::autograd::Node*> queue;
1349 
1350   RECORD_FUNCTION(
1351       "torch.distributed.ddp.reducer::search_unused_parameters",
1352       std::vector<c10::IValue>());
1353 
1354   // Seed queue with the grad functions of all outputs.
1355   for (const auto& output : outputs) {
1356     const auto& grad_fn = output.grad_fn();
1357     if (grad_fn) {
1358       queue.push_back(grad_fn.get());
1359     }
1360   }
1361 
1362   // Traverse the autograd graph starting at the specified output.
1363   while (!queue.empty()) {
1364     auto fn = queue.back();
1365     queue.pop_back();
1366     for (const auto& edge : fn->next_edges()) {
1367       if (auto next_ptr = edge.function.get()) {
1368         const bool was_inserted = seen.insert(next_ptr).second;
1369         if (was_inserted) {
1370           queue.push_back(next_ptr);
1371         }
1372       }
1373     }
1374   }
1375 
1376   // Find accumulator functions that don't show up in this graph.
1377   for (const auto& it : gradAccToVariableMap_) {
1378     // If the accumulator function is present in the graph, we know
1379     // a gradient will be computed for the corresponding parameter.
1380     if (seen.count(it.first) == 0) {
1381       if (ddp_debug_level_ == c10d::DebugLevel::Detail) {
1382         const auto param_info = param_names_.find(it.second);
1383         TORCH_INTERNAL_ASSERT(
1384             param_info != param_names_.end(),
1385             "Did not find variable index ",
1386             it.second,
1387             " in DDP parameter name mapping!");
1388         const auto param_name = param_info->second;
1389         LOG(INFO) << "[Rank " << process_group_->getRank() << "]: "
1390                   << "Parameter " << param_name << " at index " << it.second
1391                   << " is marked as unused.";
1392       }
1393       unused_parameters_.push_back(it.second);
1394     }
1395   }
1396 
1397   // Warn user about unnecessary perf hit if all parameters were used in
1398   // forward.
1399   if (unused_parameters_.empty()) {
1400     TORCH_WARN_ONCE(
1401         "find_unused_parameters=True was specified in DDP constructor, "
1402         "but did not find any unused parameters in the forward pass. This flag "
1403         "results in an extra traversal of the autograd graph every iteration, "
1404         " which can adversely affect performance. If your model indeed never "
1405         "has any unused parameters in the forward pass, consider turning this "
1406         "flag off. Note that this warning may be a false positive if your model "
1407         "has flow control causing later iterations to have unused parameters.");
1408   }
1409   if (!static_graph_ && ddp_graph_static_) {
1410     if (num_iterations_ > 1) {
1411       // Graph is still static if the set of unused parameters did not change.
1412       ddp_graph_static_ =
1413           prev_iteration_unused_parameters_ == unused_parameters_;
1414 
1415       if (!ddp_graph_static_) {
1416         // Log graph is not static. Logger takes care of ensuring this is done
1417         // only once to avoid overhead.
1418         logger_.lock()->log_if_graph_static(false);
1419       }
1420     }
1421     prev_iteration_unused_parameters_ = unused_parameters_;
1422   }
1423 }
1424 
prepare_for_backward(const std::vector<torch::autograd::Variable> & outputs)1425 void Reducer::prepare_for_backward(
1426     const std::vector<torch::autograd::Variable>& outputs) {
1427   std::lock_guard<std::mutex> lock(mutex_);
1428 
1429   backward_compute_start_time_ = current_time_in_nanos();
1430   if (should_collect_runtime_stats()) {
1431     record_backward_compute_start_time();
1432   }
1433 
1434   // Reset accounting.
1435   expect_autograd_hooks_ = true;
1436   // Clear gradient ready order as it can be different in the next iteration.
1437   grad_ready_order_indices_.clear();
1438 
1439   reset_bucket_counting();
1440 
1441   // Reset unused parameter accounting.
1442   has_marked_unused_parameters_ = false;
1443   // Reset per iteration marked ready parameters.
1444   perIterationReadyParams_.clear();
1445 
1446   // If static graph is not set, search graph to detect unused parameters.
1447   // When static graph is set, unused_parameters_ will be detected and will
1448   // not change after 1st iteration.
1449   // If static_graph_ = false and find_unused_parameters_ is false,
1450   // we assume that autograd hooks for ALL variables will be called,
1451   // and we don't have to search the autograd graph for presence of these hooks.
1452   if (dynamic_graph_find_unused()) {
1453     unused_parameters_.clear();
1454     search_unused_parameters(outputs);
1455   }
1456 }
1457 
copy_bucket_to_grad(at::Tensor & variable,Reducer::Bucket & bucket,size_t intra_bucket_index,bool global_unused)1458 void Reducer::copy_bucket_to_grad(
1459     at::Tensor& variable,
1460     Reducer::Bucket& bucket,
1461     size_t intra_bucket_index,
1462     bool global_unused) {
1463   const auto& bucket_view = bucket.bucket_views_out[intra_bucket_index];
1464   runGradCallbackForVariable(variable, [&](auto& grad) {
1465     // If a parameter is globally unused, we keep its grad untouched.
1466     if (!global_unused) {
1467       if (!grad.defined()) {
1468         // Creates grad according to the "Gradient Layout Contract"
1469         // (see torch/csrc/autograd/functions/accumulate_grad.h)
1470         grad =
1471             torch::autograd::utils::clone_obey_contract(bucket_view, variable);
1472       } else {
1473         grad.copy_(bucket_view);
1474       }
1475       // The grad is modified and needs to be written back.
1476       return true;
1477     }
1478     // The grad is not modified.
1479     return false;
1480   });
1481 }
1482 
getUnmarkedParamsForIteration()1483 std::vector<std::string> Reducer::getUnmarkedParamsForIteration() {
1484   std::vector<std::string> unMarkedParamNames;
1485   for (const auto& it : param_names_) {
1486     if (perIterationReadyParams_.find(it.first) ==
1487         perIterationReadyParams_.end()) {
1488       unMarkedParamNames.push_back(it.second);
1489     }
1490   }
1491   return unMarkedParamNames;
1492 }
1493 
getUnmarkedParamIndicesForIteration()1494 std::vector<size_t> Reducer::getUnmarkedParamIndicesForIteration() {
1495   std::vector<size_t> unmarked_param_indices;
1496   const auto variable_count = params_.size();
1497   for (const auto variable_index : c10::irange(variable_count)) {
1498     if (perIterationReadyParams_.find(variable_index) ==
1499         perIterationReadyParams_.end()) {
1500       unmarked_param_indices.push_back(variable_index);
1501     }
1502   }
1503   return unmarked_param_indices;
1504 }
1505 
1506 // A bucket with one or more dense tensors needs to be unflattened.
finalize_bucket_dense(Bucket & bucket)1507 void Reducer::finalize_bucket_dense(Bucket& bucket) {
1508   for (const auto intra_bucket_index : c10::irange(bucket.variables.size())) {
1509     auto& variable = bucket.variables[intra_bucket_index];
1510 
1511     bool global_unused = false;
1512     // See Note [Skip allreducing local_used_map_dev]
1513     if (static_graph_ || find_unused_parameters_) {
1514       // Determine if this param has been used globally or not.
1515       //
1516       // If the variable was used locally, it is also used globally and then
1517       // we don't need to wait for the reduction. Otherwise we lazily wait for
1518       // the reduction to complete, only when we see a variable that was
1519       // unused locally. Then we end up delaying the synchronization point
1520       // that local_used_work_->wait() implies. If we don't have any unused
1521       // parameters at all, we can skip waiting for the work to complete
1522       // altogether, and cause negligible performance overhead for models
1523       // where all parameters are used. Such lazily waiting means minimizing
1524       // performance impact for the big majority of models where all
1525       // parameters are always used. Then we only pay the overhead cost if
1526       // there is indeed a parameter that is locally unused, because we need
1527       // to check if it's also globally unused.
1528       int64_t variable_index =
1529           static_cast<int64_t>(bucket.variable_indices[intra_bucket_index]);
1530       // Note: global_unused might not be global yet. As we lazily wait for
1531       // the reduction to complete, it becomes really global only if we get to
1532       // the point as below where we wait for the reduction work, make D2H
1533       // copy, and update global_unused with the real global consensus, i.e.
1534       // local_used_map_reduced_ is true.
1535       global_unused = local_used_map_[variable_index].item<int>() == 0;
1536       if (global_unused && !local_used_map_reduced_) {
1537         // Wait for local_used_map reduction to complete.
1538         local_used_work_->wait();
1539         // D2H from local_used_map_dev_ to local_used_map_
1540         // Blocking copy, if local_used_map_dev_ is cuda
1541         local_used_map_.copy_(local_used_map_dev_);
1542 
1543         global_unused = local_used_map_[variable_index].item<int>() == 0;
1544         local_used_map_reduced_ = true;
1545       }
1546     }
1547 
1548     if (!gradient_as_bucket_view_) {
1549       if (optim_in_backward_) {
1550         // Return early if optimizer has already run.
1551         runGradCallbackForVariable(variable, [&](auto& grad) { return true; });
1552       } else {
1553         RECORD_FUNCTION(
1554             "torch.distributed.ddp.reducer::copy_bucket_to_grad",
1555             std::vector<c10::IValue>({variable}));
1556         copy_bucket_to_grad(
1557             variable, bucket, intra_bucket_index, global_unused);
1558       }
1559     } else {
1560       const auto& bucket_view_out = bucket.bucket_views_out[intra_bucket_index];
1561       auto& bucket_view_in = bucket.bucket_views_in[intra_bucket_index];
1562       // If a communication hook is registered, then `bucket_view_out` stores
1563       // the allreduced results in a newly allocated tensor, so we copy
1564       // `bucket_view_out` back to `bucket_view_in` for this gradient.
1565       if (!bucket_view_in.is_alias_of(bucket_view_out)) {
1566         bucket_view_in.copy_(bucket_view_out);
1567       }
1568       runGradCallbackForVariable(variable, [&](auto& grad) {
1569         if (optim_in_backward_) {
1570           // Return early if optimizer has already run.
1571           return true;
1572         }
1573         // If a parameter is globally unused, we keep its grad untouched.
1574         if (!global_unused) {
1575           // If grad is globally used but locally unused, let grad point to
1576           // bucket_view_in
1577           if (!grad.defined()) {
1578             grad = bucket_view_in;
1579           } else {
1580             if (!grad.is_alias_of(bucket_view_in)) {
1581               REDUCER_CHECK(
1582                   false,
1583                   logger_,
1584                   "Detected at least one parameter gradient is not the "
1585                   "expected DDP bucket view with gradient_as_bucket_view=True. "
1586                   "This may happen (for example) if multiple allreduce hooks "
1587                   "were registered onto the same parameter. If you hit this error, "
1588                   "please file an issue with a minimal repro.");
1589             }
1590           }
1591           // The grad is modified and needs to be written back.
1592           return true;
1593         }
1594         // The grad is not modified.
1595         return false;
1596       });
1597     }
1598   }
1599 }
1600 
finalize_backward()1601 void Reducer::finalize_backward() {
1602   // No longer expect autograd hooks to fire after this function returns.
1603   TORCH_INTERNAL_ASSERT(expect_autograd_hooks_);
1604   expect_autograd_hooks_ = false;
1605   // reset for the next iteration
1606   first_autograd_hook_called_ = false;
1607 
1608   // No longer require call to finalize after this function returns.
1609   TORCH_INTERNAL_ASSERT(require_finalize_);
1610   require_finalize_ = false;
1611 
1612   // Wait for asynchronous reduction to complete, and unflatten the bucket's
1613   // flattened `gradients` tensor.
1614   for (auto& bucket : buckets_) {
1615     // See Note [DDP Communication Hook]
1616     TORCH_INTERNAL_ASSERT(
1617         bucket.future_work,
1618         "Expected bucket.future_work not to be null. "
1619         "This may indicate that communication hook was not properly installed.");
1620     bucket.future_work->wait();
1621     auto future_result = comm_hook_ == nullptr
1622         ? detail::parseCppCommHookResult(bucket.future_work->value())
1623         : comm_hook_->parseHookResult(bucket.future_work->value());
1624     if (bucket.expect_sparse_gradient) {
1625       // sparse metadata is set so the bucket should have sparse_tensor_indices
1626       if (sparse_metadata_) {
1627         REDUCER_CHECK(
1628             bucket.sparse_tensor_indices.value().numel() ==
1629                 bucket.gradients.sizes()[0],
1630             logger_,
1631             "Sparse metadata and gradient size mismatch");
1632         auto sparse_result = at::sparse_coo_tensor(
1633             bucket.sparse_tensor_indices.value(),
1634             future_result,
1635             bucket.gradients.sizes());
1636         bucket.gradients.copy_(sparse_result);
1637       } else {
1638         bucket.gradients.copy_(future_result);
1639       }
1640     } else {
1641       // Reinitialize only `bucket_views_out` with the future_result by
1642       // following the same logic in `initialize_buckets`.
1643       populate_bucket_views_out(bucket, future_result);
1644     }
1645 
1646     // Unset allreduce division factor, as it may change in next backwards pass
1647     // when running with DDP join mode.
1648     div_factor_ = kUnsetDivFactor;
1649 
1650     if (!bucket.expect_sparse_gradient) {
1651       // We don't need to finalize the sparse bucket since the sparse grad and
1652       // the bucket essentially point to the same storage. As a result, once
1653       // the allreduce is done, the sparse grads are automatically updated.
1654       finalize_bucket_dense(bucket);
1655     }
1656   }
1657 
1658   if (installed_futures_ != std::nullopt) {
1659     c10::collectAll(*installed_futures_)->wait();
1660     installed_futures_ = std::nullopt;
1661   }
1662 
1663   // See Note [Skip allreducing local_used_maps_dev]
1664   if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
1665     // Due to the lazy wait, it is possible that reduction of the current
1666     // iteration is still going when the one for next iteration gets kicked off.
1667     // For such case, we want to wait explicitly to make sure the reduction does
1668     // complete before kicking off next one. Otherwise the previous one may
1669     // interfere, write to the device-side memory and clobber the content of
1670     // local_unused_maps_dev_.
1671     if (!local_used_map_reduced_) {
1672       local_used_work_->wait();
1673     }
1674   }
1675 
1676   if (dynamic_graph_find_unused()) {
1677     // Reset unused parameter accounting.
1678     // See Note [local_used_map_ -> local_used_map_dev copying]
1679     local_used_map_.fill_(0);
1680     local_used_map_reduced_ = false;
1681   }
1682 
1683   if (should_collect_runtime_stats()) {
1684     record_backward_comm_end_time();
1685   }
1686 
1687   sparse_metadata_.reset();
1688 }
1689 
runGradCallbackForVariable(at::Tensor & variable,GradCallback && cb)1690 void Reducer::runGradCallbackForVariable(
1691     at::Tensor& variable,
1692     GradCallback&& cb) {
1693 #ifdef _WIN32
1694   cb(variable.mutable_grad());
1695 #else
1696   auto context_ptr = rpc_context_.context_ptr.load();
1697   if (context_ptr == nullptr) {
1698     cb(variable.mutable_grad());
1699   } else {
1700     // Under distributed autograd
1701     context_ptr->runGradCallbackForVariable(variable, std::move(cb));
1702   }
1703 #endif
1704 }
1705 
1706 #ifndef _WIN32
set(ContextPtr && new_context_ptr)1707 void Reducer::RpcContext::set(ContextPtr&& new_context_ptr) {
1708   // We should set 'new_context_ptr' even if it's nullptr. That means the
1709   // reducer is under a local backward run.
1710   const auto new_context_raw_ptr = new_context_ptr.get();
1711   if (context_ptr.exchange(new_context_raw_ptr) != new_context_raw_ptr) {
1712     // Set the shared ptr to the context only if it's set first time.
1713     // All call sites should use the same context ptr.
1714     // Use an atomic to avoid data race from multiple threads.
1715     context_ptr_holder = std::move(new_context_ptr);
1716   }
1717 }
1718 #endif
1719 
sync_bucket_indices(std::vector<std::vector<size_t>> & bucket_indices)1720 void Reducer::sync_bucket_indices(
1721     std::vector<std::vector<size_t>>& bucket_indices) {
1722   auto num_buckets = bucket_indices.size();
1723   std::vector<size_t> bucket_sizes;
1724   bucket_sizes.reserve(num_buckets);
1725   int64_t total_size = 0;
1726   for (const auto i : c10::irange(num_buckets)) {
1727     auto bucket_size = bucket_indices.at(i).size();
1728     bucket_sizes.push_back(bucket_size);
1729     total_size += static_cast<int64_t>(bucket_size);
1730   }
1731 
1732   at::TensorOptions options;
1733   options = options.dtype(at::kInt);
1734   options = options.device(params_[0].device());
1735 
1736   // Group indices and num_bucket together into indices_tensor
1737   // Broadcast this tensor first, as its size is equal among all processes
1738   auto indices_tensor = at::empty({total_size + 1}, at::kInt);
1739   auto indices_accessor = indices_tensor.accessor<int, 1>();
1740   auto indices_accessor_Index = 0;
1741   for (const auto i : c10::irange(num_buckets)) {
1742     const auto& bucket_size = bucket_indices.at(i).size();
1743     for (const auto j : c10::irange(bucket_size)) {
1744       indices_accessor[indices_accessor_Index++] =
1745           static_cast<int>(bucket_indices[i][j]);
1746     }
1747   }
1748   indices_accessor[indices_accessor_Index] = static_cast<int>(num_buckets);
1749 
1750   // Copy CPU tensor to device tensor, as the process_group_ could be NCCL and
1751   // it can only broadcast device tensors.
1752   auto indices_tensor_device = at::empty({total_size + 1}, options);
1753   indices_tensor_device.copy_(indices_tensor, /*non_blocking=*/true);
1754   std::vector<at::Tensor> indices_tensor_list = {indices_tensor_device};
1755   process_group_->broadcast(indices_tensor_list)->wait();
1756   indices_tensor.copy_(indices_tensor_list.front(), /*non_blocking=*/false);
1757 
1758   // Update num_buckets after receiving it from rank 0
1759   num_buckets = indices_accessor[indices_accessor_Index];
1760 
1761   // Broadcast bucket_sizes
1762   auto bucket_sizes_tensor = at::empty({(int64_t)num_buckets}, at::kInt);
1763   auto bucket_sizes_accessor = bucket_sizes_tensor.accessor<int, 1>();
1764   for (const auto i : c10::irange(num_buckets)) {
1765     // For rank != 0, it is possible that local num buckets bucket_sizes.size()
1766     // is smaller than broadcasted num_buckets
1767     bucket_sizes_accessor[i] =
1768         bucket_sizes.at(std::min(i, (bucket_sizes.size() - 1)));
1769   }
1770   auto bucket_sizes_tensor_device = at::empty({(int64_t)num_buckets}, options);
1771   bucket_sizes_tensor_device.copy_(bucket_sizes_tensor, /*non_blocking=*/true);
1772   std::vector<at::Tensor> bucket_sizes_tensor_list = {
1773       bucket_sizes_tensor_device};
1774   process_group_->broadcast(bucket_sizes_tensor_list)->wait();
1775   bucket_sizes_tensor.copy_(
1776       bucket_sizes_tensor_list.front(), /*non_blocking=*/false);
1777 
1778   // Clear bucket_indices first, and then update bucket_indices using received
1779   // num_buckets, bucket_sizes_tensor and indices_tensor from rank 0
1780   bucket_indices.clear();
1781   bucket_indices.reserve(num_buckets);
1782   indices_accessor_Index = 0;
1783   for (const auto i : c10::irange(num_buckets)) {
1784     const auto& bucket_size = bucket_sizes_accessor[static_cast<int64_t>(i)];
1785     std::vector<size_t> bucket;
1786     bucket.reserve(bucket_size);
1787     for (const auto j : c10::irange(bucket_size)) {
1788       (void)j;
1789       bucket.push_back(indices_accessor[indices_accessor_Index++]);
1790     }
1791     bucket_indices.emplace_back(std::move(bucket));
1792   }
1793 }
1794 
rebuild_buckets()1795 bool Reducer::rebuild_buckets() {
1796   // Ensure reduction for previous backwards pass is finished. If user's model
1797   // has unused parameters for example, this will raise an error recommending to
1798   // run with find_unused_parameters=True, instead of the size mismatch
1799   // exception below.
1800   std::lock_guard<std::mutex> lock(mutex_);
1801   ensure_prior_reduction_finished();
1802   if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
1803     return false;
1804   }
1805 
1806   TORCH_INTERNAL_ASSERT(
1807       rebuilt_params_.size() == rebuilt_param_indices_.size(),
1808       c10::str(
1809           "rebuilt parameter tensors size is not same as rebuilt parameter indices size: ",
1810           rebuilt_params_.size(),
1811           " versus ",
1812           rebuilt_param_indices_.size()));
1813   TORCH_INTERNAL_ASSERT(
1814       params_.size() == rebuilt_param_indices_.size(),
1815       c10::str(
1816           "rebuilt parameter indices size is not same as original model parameters size.",
1817           "Original model param size is: ",
1818           params_.size(),
1819           " versus rebuilt params size of: ",
1820           rebuilt_param_indices_.size()));
1821   std::vector<size_t> bucket_size_limits;
1822   bucket_size_limits.push_back(first_bucket_bytes_cap_);
1823   bucket_size_limits.push_back(bucket_bytes_cap_);
1824   auto ddp_set_last_bucket_as_small =
1825       (getCvarString({"DDP_SET_LAST_BUCKET_CAP"}, "N/A") == "1");
1826 
1827   if (ddp_set_last_bucket_as_small) {
1828     // Reverse so that first_bucket_bytes_cap_ (smaller bucket) becomes the last
1829     // bucket. We cannot simply pass in {bucket_bytes_cap_,
1830     // first_bucket_bytes_cap} as the bucket order as we would immediately
1831     // advance to the 2nd element after the first bucket, whereas we only want
1832     // the last bucket to have a smaller size.
1833     std::reverse(rebuilt_params_.begin(), rebuilt_params_.end());
1834     std::reverse(rebuilt_param_indices_.begin(), rebuilt_param_indices_.end());
1835   }
1836   auto [rebuilt_bucket_indices, per_bucket_size_limits] =
1837       compute_bucket_assignment_by_size(
1838           rebuilt_params_,
1839           bucket_size_limits,
1840           expect_sparse_gradients_,
1841           rebuilt_param_indices_,
1842           logger_);
1843 
1844   if (ddp_set_last_bucket_as_small) {
1845     // Reverse again because buckets were rebuilt in the opposite of gradient
1846     // ready order.
1847     std::reverse(rebuilt_bucket_indices.begin(), rebuilt_bucket_indices.end());
1848     std::reverse(per_bucket_size_limits.begin(), per_bucket_size_limits.end());
1849   }
1850 
1851   if (ddp_debug_level_ != c10d::DebugLevel::Off) {
1852     TORCH_INTERNAL_ASSERT(
1853         rebuilt_bucket_indices.size() == per_bucket_size_limits.size())
1854     LOG(INFO) << rebuilt_bucket_indices.size()
1855               << " buckets rebuilt with size limits: "
1856               << c10::Join(", ", per_bucket_size_limits) << " bytes.";
1857   }
1858 
1859   // For rebuilt bucket indices, it needs to be synced across all ranks.
1860   // Broadcast the newly rebuilt bucket indices from rank 0 in default.
1861   // After syncing up rebuilt bucket indices, initialize buckets for reducer.
1862   sync_bucket_indices(rebuilt_bucket_indices);
1863 
1864   has_rebuilt_bucket_ = true;
1865   rebuilt_params_.clear();
1866   rebuilt_param_indices_.clear();
1867 
1868   initialize_buckets(std::move(rebuilt_bucket_indices));
1869 
1870   return true;
1871 }
1872 
setSparseMetadata(std::map<std::string,at::Tensor> & metadata)1873 void Reducer::setSparseMetadata(std::map<std::string, at::Tensor>& metadata) {
1874   sparse_metadata_ =
1875       std::make_unique<std::map<std::string, at::Tensor>>(metadata);
1876 }
1877 
1878 // See Note [DDP Communication Hook]
register_comm_hook(std::unique_ptr<CommHookInterface> iface)1879 void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
1880   REDUCER_CHECK(
1881       comm_hook_ == nullptr,
1882       logger_,
1883       "register_comm_hook or register_builtin_comm_hook can only be called once.");
1884 
1885   comm_hook_ = std::move(iface);
1886 }
1887 
1888 // See Note [DDP Communication Hook]
register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type)1889 void Reducer::register_builtin_comm_hook(
1890     c10d::BuiltinCommHookType comm_hook_type) {
1891   REDUCER_CHECK(
1892       comm_hook_ == nullptr,
1893       logger_,
1894       "register_builtin_comm_hook or register_comm_hook can only be called once.");
1895 
1896   switch (comm_hook_type) {
1897     case c10d::BuiltinCommHookType::ALLREDUCE:
1898       comm_hook_ = std::make_unique<c10d::AllReduceCommHook>(process_group_);
1899       LOG(INFO) << "Built-in communication hook ALLREDUCE is registered.";
1900       break;
1901     case c10d::BuiltinCommHookType::FP16_COMPRESS:
1902       comm_hook_ = std::make_unique<c10d::FP16CompressCommHook>(process_group_);
1903       LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered.";
1904       break;
1905     default:
1906       TORCH_WARN_ONCE(
1907           "Unknown built-in DDP comm hook type is provided. No comm hook will be used.");
1908   }
1909 }
1910 
ensure_prior_reduction_finished()1911 void Reducer::ensure_prior_reduction_finished() {
1912   // Check that any prior reduction has finished.
1913   // The variable `require_finalize_` is true until all gradients
1914   // have been computed and reduction of all buckets has been kicked off.
1915   if (require_finalize_) {
1916     // Collect unmarked parameter indices, additionally, in debug mode retrieve
1917     // parameter names.
1918     auto unmarked_param_indices = getUnmarkedParamIndicesForIteration();
1919     // We should have some unmarked parameter indices, otherwise we would not
1920     // have run into this error branch.
1921     TORCH_INTERNAL_ASSERT(!unmarked_param_indices.empty());
1922 
1923     std::string kBaseErrorMsg =
1924         "Expected to have finished reduction in the prior iteration before "
1925         "starting a new one. "
1926         ""
1927         "This error indicates that your module has parameters that were "
1928         "not used in producing loss. ";
1929     std::string kOutputsNotUsedInLossErrorMsg =
1930         "making sure all "
1931         "`forward` function outputs participate in calculating loss. ";
1932     std::string kDDPBugErrorMsg =
1933         "\nIf you already have done the above, then the distributed "
1934         "data parallel module wasn't able to locate the output tensors in the "
1935         "return value of your module's `forward` function. "
1936         "Please include the loss function and the structure of the return "
1937         "value of `forward` of your module when reporting this issue (e.g. "
1938         "list, dict, iterable).";
1939 
1940     if (static_graph_) {
1941       kBaseErrorMsg =
1942           "Expected to have finished reduction in the prior iteration before "
1943           "starting a new one. "
1944           "This error indicates that your training graph has changed "
1945           "in this iteration, e.g., one parameter is used in first "
1946           "iteration, but then got unused in the second iteration. "
1947           "this is not compatible with static_graph set to True.";
1948     } else if (!find_unused_parameters_) {
1949       // Parameters may have been unused in forward pass, or not all outputs
1950       // were used in producing loss.
1951       kBaseErrorMsg +=
1952           "You can enable unused parameter detection by passing the "
1953           "keyword argument `find_unused_parameters=True` to "
1954           "`torch.nn.parallel.DistributedDataParallel`, and by \n";
1955       kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
1956       kBaseErrorMsg += kDDPBugErrorMsg;
1957     } else {
1958       // Note that it does not really matter whether unused_parameters_.empty(),
1959       // since user may have enabled detection but this particular iteration
1960       // could have used or not used all parameters.
1961       kBaseErrorMsg +=
1962           "Since `find_unused_parameters=True` is enabled, this likely "
1963           " means that not all `forward` outputs participate in computing loss. You can fix this by ";
1964       kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
1965       kBaseErrorMsg += kDDPBugErrorMsg;
1966     }
1967 
1968     const std::string unmarked_param_indices_info = c10::str(
1969         "\n",
1970         "Parameter indices which did not receive grad for rank ",
1971         process_group_->getRank(),
1972         ": ",
1973         unmarked_param_indices);
1974 
1975     if (ddp_debug_level_ == DebugLevel::Off) {
1976       // Without debug mode, log unmarked_param_indices, as well as
1977       // recommendation to use debug mode to print parameter names.
1978       kBaseErrorMsg += unmarked_param_indices_info;
1979       kBaseErrorMsg +=
1980           "\n In addition, you can set the environment variable "
1981           "TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information "
1982           "about which particular parameters did not receive gradient on this rank "
1983           "as part of this error";
1984     } else {
1985       // Retrieve set of parameter names that did not receive gradient.
1986       auto unmarkedParams = getUnmarkedParamsForIteration();
1987       TORCH_INTERNAL_ASSERT(!unmarkedParams.empty());
1988       for (const auto& s : unmarkedParams) {
1989         LOG(INFO) << "[Rank " << process_group_->getRank() << "] "
1990                   << "Parameter: " << s
1991                   << " did not get gradient in backwards pass.";
1992       }
1993       const std::string unmarkedParamInfo = c10::Join(", ", unmarkedParams);
1994       // In debug mode, log param names and indices that went unused.
1995       kBaseErrorMsg += c10::str(
1996           "\n",
1997           "Parameters which did not receive grad for rank ",
1998           process_group_->getRank(),
1999           ": ",
2000           unmarkedParamInfo);
2001       kBaseErrorMsg += unmarked_param_indices_info;
2002     }
2003     REDUCER_CHECK(false, logger_, kBaseErrorMsg);
2004   }
2005 }
2006 
set_ddp_runtime_logging_sample_rate(int sample_rate)2007 void Reducer::set_ddp_runtime_logging_sample_rate(int sample_rate) {
2008   ddp_runtime_logging_sample_rate_ = sample_rate;
2009 }
2010 
get_ddp_runtime_logging_sample_rate()2011 int Reducer::get_ddp_runtime_logging_sample_rate() {
2012   return ddp_runtime_logging_sample_rate_;
2013 }
2014 
should_collect_runtime_stats()2015 bool Reducer::should_collect_runtime_stats() {
2016   if (num_iterations_ > 0 &&
2017       (num_iterations_ <= 10 ||
2018        num_iterations_ % get_ddp_runtime_logging_sample_rate() == 0)) {
2019     return true;
2020   }
2021   return false;
2022 }
2023 
record_forward_compute_start_time()2024 void Reducer::record_forward_compute_start_time() {
2025   if (timer_) {
2026     timer_->record(Timer::Event::kForwardStart);
2027   }
2028 }
2029 
record_backward_compute_start_time()2030 void Reducer::record_backward_compute_start_time() {
2031   if (timer_) {
2032     timer_->record(Timer::Event::kBackwardComputeStart);
2033   }
2034 }
2035 
record_backward_compute_end_time()2036 void Reducer::record_backward_compute_end_time() {
2037   if (timer_) {
2038     timer_->record(Timer::Event::kBackwardComputeEnd);
2039   }
2040 }
2041 
record_backward_comm_start_time()2042 void Reducer::record_backward_comm_start_time() {
2043   if (timer_) {
2044     timer_->record(Timer::Event::kBackwardCommStart);
2045   }
2046 }
2047 
record_backward_comm_end_time()2048 void Reducer::record_backward_comm_end_time() {
2049   if (timer_) {
2050     timer_->record(Timer::Event::kBackwardCommEnd);
2051   }
2052 }
2053 
set_static_graph()2054 void Reducer::set_static_graph() {
2055   std::lock_guard<std::mutex> lock(mutex_);
2056   REDUCER_CHECK(
2057       num_iterations_ == 0,
2058       logger_,
2059       "set_static_graph() should be called before training loop starts "
2060       "and after DistributedDataParallel is constructed.");
2061   static_graph_ = true;
2062   // when static_graph_ is set as true, always initialize_local_used_map
2063   // and detect the global unused parameters in the first iteration.
2064   initialize_local_used_map();
2065 }
2066 
2067 namespace {
2068 
2069 // Tensors may be coalesced into buckets. Buckets must contain tensors of
2070 // the same type, on the same device, so a bucket can identified by a
2071 // composite key of a tensor's type identifier and its device.
2072 struct BucketKey {
BucketKeyc10d::__anon78824b6c0d11::BucketKey2073   BucketKey(c10::ScalarType type, c10::Device device)
2074       : type(type), device(device) {}
2075 
2076   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const*)
2077   const c10::ScalarType type;
2078   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const*)
2079   const c10::Device device;
2080 
2081   // See torch/csrc/utils/hash.h for dispatch code.
hashc10d::__anon78824b6c0d11::BucketKey2082   static size_t hash(const BucketKey& key) {
2083     return c10::get_hash(key.type, key.device);
2084   }
2085 };
2086 
operator ==(const BucketKey & lhs,const BucketKey & rhs)2087 inline bool operator==(const BucketKey& lhs, const BucketKey& rhs) {
2088   return lhs.type == rhs.type && lhs.device == rhs.device;
2089 }
2090 
2091 } // namespace
2092 
2093 std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>>
compute_bucket_assignment_by_size(const std::vector<at::Tensor> & tensors,const std::vector<size_t> & bucket_size_limits,const std::vector<bool> & expect_sparse_gradient,const std::vector<int64_t> & tensor_indices,const std::optional<std::weak_ptr<c10d::Logger>> & logger)2094 compute_bucket_assignment_by_size(
2095     const std::vector<at::Tensor>& tensors,
2096     const std::vector<size_t>& bucket_size_limits,
2097     const std::vector<bool>& expect_sparse_gradient,
2098     const std::vector<int64_t>& tensor_indices,
2099     const std::optional<std::weak_ptr<c10d::Logger>>& logger) {
2100   // Either expect_sparse_gradient is not specified or it has as many elements
2101   // as the vector with tensors.
2102   TORCH_INTERNAL_ASSERT(
2103       expect_sparse_gradient.empty() ||
2104       (tensors.size() == expect_sparse_gradient.size()));
2105   TORCH_INTERNAL_ASSERT(!tensors.empty());
2106   // Store bucket indices and their sizes together, because we later sort the
2107   // resulting indices by minimum tensor index and want to keep sizes
2108   // consistent.
2109   std::vector<std::tuple<std::vector<size_t>, size_t>> result;
2110   // Sparse tensors go in their own bucket, so they do not have an enforced size
2111   // limit.
2112   size_t kNoSizeLimit = 0;
2113   result.reserve(tensors.size());
2114 
2115   // Keep iterator into the size_limit vector by tensor type and device.
2116   // This is done so that we can use the consecutive bucket limits per type.
2117   std::unordered_map<
2118       BucketKey,
2119       std::vector<size_t>::const_iterator,
2120       c10::hash<BucketKey>>
2121       bucket_size_limit_iterators;
2122 
2123   // Keep vector of indices and size accumulator by tensor type and device.
2124   std::unordered_map<BucketKey, BucketAccumulator, c10::hash<BucketKey>>
2125       buckets;
2126 
2127   for (const auto i : c10::irange(tensors.size())) {
2128     const auto& tensor = tensors[i];
2129     auto msg = std::string("No support for sparse tensors.");
2130     if (logger.has_value()) {
2131       REDUCER_CHECK(!tensor.is_sparse(), logger.value(), msg);
2132     } else {
2133       TORCH_CHECK(!tensor.is_sparse(), msg);
2134     }
2135 
2136     // when tensor_indices is empty, the index of tensors[i] assigned to
2137     // bucket is i, otherwise the tensor index is tensor_indices[i].
2138     auto tensor_index = i;
2139     if (!tensor_indices.empty()) {
2140       tensor_index = tensor_indices[i];
2141     }
2142     // If we expect a sparse gradient to be produced for this tensor, it cannot
2143     // be grouped together with other gradients and gets its own bucket.
2144     if (!expect_sparse_gradient.empty() &&
2145         expect_sparse_gradient[tensor_index]) {
2146       result.emplace_back(std::vector<size_t>({tensor_index}), kNoSizeLimit);
2147       continue;
2148     }
2149 
2150     auto key = BucketKey(tensor.scalar_type(), tensor.device());
2151     auto& bucket = buckets[key];
2152     bucket.indices.push_back(tensor_index);
2153     bucket.size += tensor.numel() * tensor.element_size();
2154 
2155     // Initialize bucket size limit iterator if necessary.
2156     if (bucket_size_limit_iterators.count(key) == 0) {
2157       bucket_size_limit_iterators[key] = bucket_size_limits.begin();
2158     }
2159 
2160     auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key];
2161     const auto bucket_size_limit = *bucket_size_limit_iterator;
2162     bucket.size_limit = bucket_size_limit;
2163     if (bucket.size >= bucket_size_limit) {
2164       result.emplace_back(std::move(bucket.indices), bucket.size_limit);
2165       bucket = BucketAccumulator();
2166 
2167       // Advance to the next bucket size limit for this type/device.
2168       auto next = bucket_size_limit_iterator + 1;
2169       if (next != bucket_size_limits.end()) {
2170         bucket_size_limit_iterator = next;
2171       }
2172     }
2173   }
2174 
2175   // Add remaining buckets.
2176   for (auto& it : buckets) {
2177     auto& bucket = it.second;
2178     if (!bucket.indices.empty()) {
2179       result.emplace_back(std::move(bucket.indices), bucket.size_limit);
2180     }
2181   }
2182 
2183   // If tensor_indices is not empty, the order of the tensors is in the gradient
2184   // ready order, so no need to sort.
2185   // If tensor_indices is empty, sort resulting buckets by the minimum tensor
2186   // index they include. We assume that the order of the tensors is the order in
2187   // which they are used (or the reverse order in which their gradients are
2188   // produced). This sorting step ensures that the buckets are ready in
2189   // consecutive order.
2190   if (tensor_indices.empty()) {
2191     std::sort(
2192         result.begin(),
2193         result.end(),
2194         [](const std::tuple<std::vector<size_t>, size_t>& a,
2195            const std::tuple<std::vector<size_t>, size_t>& b) {
2196           auto indices_a = std::get<0>(a);
2197           auto indices_b = std::get<0>(b);
2198           const auto amin =
2199               std::min_element(indices_a.begin(), indices_a.end());
2200           const auto bmin =
2201               std::min_element(indices_b.begin(), indices_b.end());
2202           return *amin < *bmin;
2203         });
2204   }
2205 
2206   // Return bucket indices and size limits as separate entries in tuple, as some
2207   // APIs only need to consume bucket indices.
2208   std::vector<std::vector<size_t>> bucket_indices;
2209   bucket_indices.reserve(result.size());
2210   std::vector<size_t> per_bucket_size_limits;
2211   per_bucket_size_limits.reserve(result.size());
2212   for (const auto& bucket_indices_with_size : result) {
2213     bucket_indices.emplace_back(std::get<0>(bucket_indices_with_size));
2214     per_bucket_size_limits.emplace_back(std::get<1>(bucket_indices_with_size));
2215   }
2216   return std::make_tuple(bucket_indices, per_bucket_size_limits);
2217 }
2218 
2219 // Verifies corresponding params in the model replica have the same
2220 // sizes/strides across processes.
verify_params_across_processes(const c10::intrusive_ptr<c10d::ProcessGroup> & process_group,const std::vector<at::Tensor> & params,const std::optional<std::weak_ptr<c10d::Logger>> & logger)2221 void verify_params_across_processes(
2222     const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
2223     const std::vector<at::Tensor>& params,
2224     const std::optional<std::weak_ptr<c10d::Logger>>& logger) {
2225   // First verify number of parameters to avoid inconsistent inputs into
2226   // broadcast which can cause a crash.
2227   // See https://github.com/pytorch/pytorch/issues/73547
2228 
2229   at::TensorOptions param_size_options;
2230   param_size_options = param_size_options.dtype(at::kLong);
2231   param_size_options = param_size_options.device(params[0].device());
2232   // Note: Not using tensor building API because of
2233   // https://github.com/pytorch/pytorch/issues/74114
2234   at::Tensor param_size_tensor =
2235       at::tensor({static_cast<int64_t>(params.size())}, param_size_options);
2236 
2237   // Allgather and verify parameter size.
2238   std::vector<std::vector<at::Tensor>> param_size_output_tensors;
2239   param_size_output_tensors.emplace_back();
2240   auto world_size = process_group->getSize();
2241   for (C10_UNUSED const auto i : c10::irange(world_size)) {
2242     param_size_output_tensors.front().emplace_back(
2243         at::empty_like(param_size_tensor));
2244   }
2245 
2246   std::vector<at::Tensor> param_size_vec{param_size_tensor};
2247   process_group->allgather(param_size_output_tensors, param_size_vec)->wait();
2248   auto result_size_tensors = param_size_output_tensors.front();
2249   for (const auto i : c10::irange(world_size)) {
2250     auto param_size_for_rank = result_size_tensors[i][0].item<int>();
2251     TORCH_CHECK(
2252         static_cast<size_t>(param_size_for_rank) == params.size(),
2253         c10::str(
2254             "DDP expects same model across all ranks, but Rank ",
2255             process_group->getRank(),
2256             " has ",
2257             params.size(),
2258             " params, while rank ",
2259             i,
2260             " has inconsistent ",
2261             param_size_for_rank,
2262             " params."));
2263   }
2264 
2265   // Continue with parameter shape verification.
2266   size_t i = 0;
2267   for (const auto& t : params) {
2268     i += 2 * t.dim();
2269   }
2270   at::TensorOptions options;
2271   options = options.dtype(at::kLong);
2272   auto metadata = at::empty({static_cast<long>(i)}, options);
2273 
2274   // Technically, process 0 is the broadcast source, so only process 0 needs
2275   // to populate metadata.  But no harm keeping work aligned across processes.
2276   auto metadata_accessor = metadata.accessor<int64_t, 1>();
2277   i = 0;
2278   for (const auto& t : params) {
2279     for (const auto& sz : t.sizes()) {
2280       metadata_accessor[static_cast<int64_t>(i++)] = sz;
2281     }
2282     for (const auto& str : t.strides()) {
2283       metadata_accessor[static_cast<int64_t>(i++)] = str;
2284     }
2285   }
2286 
2287   auto metadata_dev = metadata.clone().to(params[0].device());
2288   std::vector<at::Tensor> vec{metadata_dev};
2289   process_group->broadcast(vec)->wait();
2290 
2291   // Technically, process 0 doesn't need to double-check metadata, because it
2292   // was the source.  But no harm keeping work aligned.
2293   auto control = at::empty({static_cast<long>(i)}, options);
2294   control.copy_(metadata_dev, /*non_blocking=*/false);
2295   auto control_accessor = control.accessor<int64_t, 1>();
2296   i = 0;
2297   for (const auto p : c10::irange(params.size())) {
2298     const auto& t = params[p];
2299     for (const auto& sz : t.sizes()) {
2300       auto msg = c10::str(
2301           "[",
2302           process_group->getRank(),
2303           "]: params[",
2304           p,
2305           "] in this process",
2306           " with sizes ",
2307           t.sizes(),
2308           " appears not to match sizes of the same param in process 0.");
2309       if (logger.has_value()) {
2310         REDUCER_CHECK(sz == control_accessor[i++], logger.value(), msg)
2311       } else {
2312         TORCH_CHECK(sz == control_accessor[i++], msg)
2313       }
2314     }
2315     for (const auto& str : t.strides()) {
2316       auto msg = c10::str(
2317           "params[",
2318           p,
2319           "] in this process",
2320           " with sizes ",
2321           t.sizes(),
2322           " appears not to match strides of the same param in process 0.");
2323       if (logger.has_value()) {
2324         REDUCER_CHECK(str == control_accessor[i++], logger.value(), msg)
2325       } else {
2326         TORCH_CHECK(str == control_accessor[i++], msg)
2327       }
2328     }
2329   }
2330 }
2331 
remove_autograd_hooks()2332 void Reducer::remove_autograd_hooks() {
2333   // Remove all hooks on variables registered by this Reducer. This is necessary
2334   // to make DDP failure recoverable. Otherwise, multiple Reducer instances
2335   // (from recoveries) will add their hooks to the original model, and those
2336   // hooks will try to invoke methods on a deleted Reducer objects.
2337   for (auto& hook : hooks_) {
2338     auto& key = hook.first;
2339     auto& grad_accumulator = hook.second;
2340 
2341     TORCH_INTERNAL_ASSERT(
2342         grad_accumulator->del_post_hook(key),
2343         "Reducer attempts to delete a non-existing hook.");
2344   }
2345   hooks_.clear();
2346 }
2347 
check_finalized()2348 void Reducer::check_finalized() {
2349   std::lock_guard<std::mutex> lock(mutex_);
2350   ensure_prior_reduction_finished();
2351 }
2352 
update_process_group(c10::intrusive_ptr<c10d::ProcessGroup> new_process_group)2353 void Reducer::update_process_group(
2354     c10::intrusive_ptr<c10d::ProcessGroup> new_process_group) {
2355   std::lock_guard<std::mutex> lock(mutex_);
2356   process_group_ = std::move(new_process_group);
2357 }
2358 
reset_state()2359 void Reducer::reset_state() {
2360   std::lock_guard<std::mutex> lock(mutex_);
2361   // Force rebuild of buckets.
2362   has_rebuilt_bucket_ = false;
2363   rebuilt_params_.clear();
2364   rebuilt_param_indices_.clear();
2365 
2366   // Ensure forward can run despite previous backward not succeeding.
2367   expect_autograd_hooks_ = false;
2368   require_finalize_ = false;
2369   first_autograd_hook_called_ = false;
2370 
2371   // Unset allreduce division factor, as it may change in next backwards pass
2372   // when running with DDP join mode.
2373   div_factor_ = kUnsetDivFactor;
2374 
2375   // Reset unused parameter accounting.
2376   // See Note [local_used_map_ -> local_used_map_dev copying]
2377   if (find_unused_parameters_) {
2378     local_used_map_.zero_();
2379     local_used_map_reduced_ = false;
2380   }
2381 }
2382 
2383 } // namespace c10d
2384