xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/comm.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/distributed/c10d/comm.hpp>

#include <deque>

#include <ATen/core/functional.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/reducer.hpp>
#include <torch/csrc/utils/tensor_flatten.h>

namespace c10d {
namespace {

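// Kicks off an asynchronous broadcast of a bucket of tensors upon
// construction: the bucket is flattened into a single contiguous tensor,
// that tensor is broadcast from the root rank, and finish() later copies
// the result back into the original (unflattened) bucket tensors.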
class BroadcastWork {
 public:
  BroadcastWork(
      const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
      std::vector<at::Tensor> bucket_tensors,
      int root_rank = 0)
      : bucket_tensors_(std::move(bucket_tensors)),
        flat_tensor_({torch::utils::flatten_dense_tensors(bucket_tensors_)}) {
    BroadcastOptions broadcastOptions;
    broadcastOptions.rootRank = root_rank;
    work_ = process_group->broadcast(flat_tensor_, broadcastOptions);
  }

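  // Waits for the pending broadcast to complete and scatters the flattened
  // result back into the individual bucket tensors.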
  void finish() {
    work_->wait();

    // Copy the output of the broadcast operation back.
    auto output_tensors = torch::utils::unflatten_dense_tensors(
        flat_tensor_.front(), bucket_tensors_);
    TORCH_INTERNAL_ASSERT(output_tensors.size() == bucket_tensors_.size());
    for (const auto i : c10::irange(output_tensors.size())) {
      // If the output tensor is empty, there is no need to copy it back.
      // This avoids an error when both the bucket tensor and the output
      // tensor are empty but have different shapes; see
      // https://github.com/pytorch/pytorch/issues/87280
      if (output_tensors[i].numel() != 0) {
        bucket_tensors_[i].copy_(output_tensors[i], /*non_blocking=*/true);
      }
    }
  }

 protected:
  // The list of tensors to broadcast. They are guaranteed to be
  // placed on the same device and have the same dtype.
  std::vector<at::Tensor> bucket_tensors_;

  // The vector with a single flattened tensor containing the contents
  // of the tensors in bucket_tensors_. It must be stored in a vector
  // because c10d::ProcessGroup::broadcast takes a vector argument.
  std::vector<at::Tensor> flat_tensor_;

 private:
  // The broadcast work that is kicked off upon construction.
  c10::intrusive_ptr<c10d::Work> work_;
};

} // namespace

// Broadcast many tensors to all processes in the process group.
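//
// Example (a minimal sketch; `pg` and `params` are hypothetical names and the
// bucket size below is an arbitrary choice, not a recommended default):
//
//   c10::intrusive_ptr<c10d::ProcessGroup> pg = /* an existing process group */;
//   std::vector<at::Tensor> params = /* same tensors, same order, on every rank */;
//   broadcast_coalesced(pg, params, /*buffer_size=*/256 * 1024 * 1024, /*rank=*/0);
//
// After the call returns, every rank holds rank 0's values for all of the
// listed tensors.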
void broadcast_coalesced(
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
    at::TensorList tensors,
    size_t buffer_size,
    int rank) {
  // Coalesce tensors into buckets taking into account the maximum buffer size.
  // This routine is multi-device aware, so the tensors can be split across
  // multiple devices and can contain a mix of CPU and CUDA tensors.
  auto [buckets, _] =
      compute_bucket_assignment_by_size(tensors.vec(), {buffer_size});

  // Returns tensor at specified index in input tensor list.
  const auto lookup = [&tensors](size_t index) { return tensors[index]; };

  // We maintain a maximum of 2 in flight broadcast operations to avoid
  // allocating too much memory (in case the specified tensors are very large).
  std::deque<BroadcastWork> in_flight;
  constexpr auto max_in_flight = 2;
  for (const auto& bucket : buckets) {
    if (in_flight.size() >= max_in_flight) {
      in_flight.front().finish();
      in_flight.pop_front();
    }

    in_flight.emplace_back(process_group, c10::fmap(bucket, lookup), rank);
  }

  while (!in_flight.empty()) {
    in_flight.front().finish();
    in_flight.pop_front();
  }
}

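// Returns per-parameter views into the flat gradient buffer: each entry is a
// slice of buffer_ covering that parameter's offset and length, reshaped to
// the parameter's original sizes. No data is copied.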
std::vector<at::Tensor> GradBucket::getGradients() const {
  std::vector<at::Tensor> per_parameter_tensors;
  size_t num_parameters = offsets_.size();
  per_parameter_tensors.reserve(num_parameters);
  for (const auto i : c10::irange(num_parameters)) {
    per_parameter_tensors.push_back(
        buffer_.slice(0, offsets_[i], offsets_[i] + lengths_[i])
            .view(sizes_vec_[i]));
  }
  return per_parameter_tensors;
}

namespace detail {

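// Extracts the single result tensor from a communication hook's return value,
// which may be a Python object wrapping tensors, a Tensor, or a TensorList.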
at::Tensor parseCppCommHookResult(const c10::IValue& result) {
  if (result.isPyObject()) {
    std::vector<at::Tensor> tensors =
        result.toPyObjectHolder()->extractTensors();
    return tensors[0];
  }
  TORCH_INTERNAL_ASSERT(
      result.isTensor() || result.isTensorList(),
      "expected the hook result to be either a Tensor or a TensorList, found ",
      result.tagKind());

  if (result.isTensor()) {
    return result.toTensor();
  }

  return result.toTensorVector()[0];
}

} // namespace detail

} // namespace c10d