// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#ifdef USE_C10D_UCC

#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <condition_variable>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#endif
23 namespace c10d {
24 
25 #define TORCH_UCC_DEVICE_NOT_SET -2
26 
27 #ifdef USE_CUDA
28 #define SAVE_TENSORS(_TENSORS, _DATA)                       \
29   do {                                                      \
30     if ((_TENSORS)[0].device().is_cuda()) {                 \
31       for (const auto i : c10::irange((_TENSORS).size())) { \
32         c10::cuda::CUDACachingAllocator::recordStream(      \
33             (_TENSORS)[i].storage().data_ptr(), (*stream)); \
34       }                                                     \
35     } else {                                                \
36       (_DATA) = (_TENSORS);                                 \
37     }                                                       \
38   } while (0)
39 
40 #else
41 #define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
42 #endif
43 
44 constexpr const char* UCC_BACKEND_NAME = "ucc";
45 
// Pool of reusable CUDA events. WorkUCC objects return their fence event
// here on destruction so events can be recycled (see
// ProcessGroupUCC::getPooledEvent) instead of re-created per collective.
struct event_pool_t {
#ifdef USE_CUDA
  std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
#endif
  // Guards event_pool; presumably taken from both the main thread and the
  // progress thread — confirm against the .cpp before relying on this.
  std::mutex event_pool_mutex;
};

class Comm;
54 
55 // UCC does not support multiple CUDA devices per process.
56 class TORCH_API ProcessGroupUCC : public Backend {
57  private:
58   void set_timeout(ucc_coll_args_t& args);
59 
60  public:
61   class WorkData {
62    public:
63     std::vector<at::Tensor> src;
64     std::vector<at::Tensor> dst;
65     std::vector<at::Tensor> flat;
WorkData()66     WorkData() {}
67     virtual ~WorkData() = default;
68   };
69   class AlltoallWorkData : public WorkData {
70    public:
AlltoallWorkData(int size)71     AlltoallWorkData(int size)
72         : send_lengths(size),
73           send_offsets(size),
74           recv_lengths(size),
75           recv_offsets(size) {}
76     std::vector<uint64_t> send_lengths;
77     std::vector<uint64_t> send_offsets;
78     std::vector<uint64_t> recv_lengths;
79     std::vector<uint64_t> recv_offsets;
80   };
81 
82   class AllgathervWorkData : public WorkData {
83    public:
AllgathervWorkData(int size)84     AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {}
85     std::vector<uint64_t> recv_lengths;
86     std::vector<uint64_t> recv_offsets;
87   };
88 
89   class ScattervWorkData : public WorkData {
90    public:
ScattervWorkData(int size)91     ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {}
92     std::vector<uint64_t> send_lengths;
93     std::vector<uint64_t> send_offsets;
94   };
95 
96   class ProgressEntry {
97     friend class ProcessGroupUCC;
98     friend class Comm;
99 
100    public:
ProgressEntry(CommBase * comm,ucc_coll_req_h request)101     ProgressEntry(CommBase* comm, ucc_coll_req_h request)
102         : status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
103     // Finalizes UCC status or exception of collective request.
104     void finalize(std::exception_ptr eptr = nullptr);
105     ucc_status_t status_;
106     CommBase* comm_;
107     ucc_coll_req_h request_;
108     std::unique_ptr<WorkData> data;
109     c10::intrusive_ptr<c10::ivalue::Future> future_;
110     std::exception_ptr eptr_;
111   };
112 
113   class WorkUCC : public Work {
114     friend class ProcessGroupUCC;
115     friend class Comm;
116 
117    public:
WorkUCC(OpType opType,uint64_t seq,const char * prof_title,const std::optional<std::vector<at::Tensor>> & inputs,const c10::intrusive_ptr<ProcessGroupUCCLogger> & logger)118     WorkUCC(
119         OpType opType,
120         uint64_t seq,
121         const char* prof_title,
122         const std::optional<std::vector<at::Tensor>>& inputs,
123         const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
124         : Work(-1, opType, prof_title, inputs), logger_(logger), seq_(seq) {}
125     ~WorkUCC();
126     void setException();
127     void setAndThrowException();
128     bool isCompleted() override;
129     bool isSuccess() const override;
130     bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
131     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
132     std::vector<at::Tensor> result() override;
133     int sourceRank() const override;
134 #ifdef USE_CUDA
135     std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
136     event_pool_t* ep = nullptr;
137 #endif
138     int sourceRank_;
139 
140    protected:
141     std::shared_ptr<ProgressEntry> entry_;
142     c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;
143     uint64_t seq_;
144 
145    private:
146     // The future returned by getFuture.
147     c10::intrusive_ptr<at::ivalue::Future> future_;
148     // Store a reference to collective's outputs, used by result
149     std::shared_ptr<std::vector<at::Tensor>> outputs_;
150   };
151 
152   explicit ProcessGroupUCC(
153       const c10::intrusive_ptr<Store>& store,
154       int rank = -1,
155       int size = -1,
156       std::chrono::duration<float> timeout = kBackendDefaultTimeout);
157 
158   void initComm(c10::Device dev);
159 
160   ~ProcessGroupUCC() override;
161 
getBackendName() const162   const std::string getBackendName() const override {
163     return std::string(UCC_BACKEND_NAME);
164   }
165 
166 #ifdef USE_CUDA
167   std::unique_ptr<at::cuda::CUDAEvent> getPooledEvent();
168 #endif
169 
170   // Performs a health check by initializing dummy UCC & UCX communicators and
171   // then destroying them. This will help indicate and signal any
172   // UCC/UCX-related issues prior to the first collective. The actual
173   // initialization and subsequent destruction is ran on a separate thread and
174   // the main thread is signalled about timeouts/errors to report to the
175   // application.
176   void runHealthCheck();
177 
178   template <typename PreProcess, typename PostProcess>
179   c10::intrusive_ptr<Work> collective_post(
180       OpType opType,
181       PreProcess preproc,
182       PostProcess postproc,
183       ucc_coll_args_t& coll,
184       std::unique_ptr<ProcessGroupUCC::WorkData> data,
185       c10::Device dev,
186       std::vector<at::Tensor>& inputTensors,
187       std::vector<at::Tensor>& outputTensors,
188       const char* prof_title);
189 
190   c10::intrusive_ptr<Work> broadcast(
191       std::vector<at::Tensor>& data,
192       const BroadcastOptions& opts = BroadcastOptions()) override;
193 
194   c10::intrusive_ptr<Work> allreduce(
195       std::vector<at::Tensor>& tensors,
196       const AllreduceOptions& opts = AllreduceOptions()) override;
197 
198   c10::intrusive_ptr<Work> allreduce_coalesced(
199       std::vector<at::Tensor>& tensors,
200       const AllreduceCoalescedOptions& opts =
201           AllreduceCoalescedOptions()) override;
202 
203   c10::intrusive_ptr<Work> reduce(
204       std::vector<at::Tensor>& tensors,
205       const ReduceOptions& opts = ReduceOptions()) override;
206 
207   c10::intrusive_ptr<Work> allgather(
208       std::vector<std::vector<at::Tensor>>& outputTensors,
209       std::vector<at::Tensor>& inputTensors,
210       const AllgatherOptions& opts = AllgatherOptions()) override;
211 
212   c10::intrusive_ptr<Work> _allgather_base(
213       at::Tensor& outputBuffer,
214       at::Tensor& inputBuffer,
215       const AllgatherOptions& opts = AllgatherOptions()) override;
216 
217   c10::intrusive_ptr<Work> barrier(
218       const BarrierOptions& opts = BarrierOptions()) override;
219 
220   c10::intrusive_ptr<Work> gather(
221       std::vector<std::vector<at::Tensor>>& outputTensors,
222       std::vector<at::Tensor>& inputTensors,
223       const GatherOptions& opts = GatherOptions()) override;
224 
225   c10::intrusive_ptr<Work> scatter(
226       std::vector<at::Tensor>& outputTensors,
227       std::vector<std::vector<at::Tensor>>& inputTensors,
228       const ScatterOptions& opts = ScatterOptions()) override;
229 
230   c10::intrusive_ptr<Work> reduce_scatter(
231       std::vector<at::Tensor>& outputTensors,
232       std::vector<std::vector<at::Tensor>>& inputTensors,
233       const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
234 
235   c10::intrusive_ptr<Work> alltoall_base(
236       at::Tensor& outputTensor,
237       at::Tensor& inputTensor,
238       std::vector<int64_t>& outputSplitSizes,
239       std::vector<int64_t>& inputSplitSizes,
240       const AllToAllOptions& opts = AllToAllOptions()) override;
241 
242   c10::intrusive_ptr<Work> alltoall(
243       std::vector<at::Tensor>& outputTensors,
244       std::vector<at::Tensor>& inputTensors,
245       const AllToAllOptions& opts = AllToAllOptions()) override;
246 
247   c10::intrusive_ptr<Work> send(
248       std::vector<at::Tensor>& tensors,
249       int dstRank,
250       int tag) override;
251 
252   c10::intrusive_ptr<Work> recv(
253       std::vector<at::Tensor>& tensors,
254       int srcRank,
255       int tag) override;
256 
257   // Counting for the sequential number of UCC collective_post call.
258   uint64_t seq_{0};
259 
260   // Agrees on an initial sequence number for the whole group by having rank 0
261   // create it and broadcast it to other ranks using the store.
262   void setSequenceNumberForGroup() override;
263 
264   // Retrieves the current sequence number for the whole group, which should be
265   // in sync. If the returned number is not consistent across the group, it
266   // may indicate that there is some sort of collective desynchronization.
267   uint64_t getSequenceNumberForGroup() override;
268 
269   static c10::intrusive_ptr<Backend> createProcessGroupUCC(
270       const c10::intrusive_ptr<::c10d::Store>& store,
271       int rank,
272       int size,
273       const std::chrono::duration<float>& timeout);
274 
275  protected:
276   const std::chrono::duration<float> timeout_;
277   std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
278   std::shared_ptr<Comm> comm = {nullptr};
279   uint32_t comm_id;
280   ucc_team_h team{nullptr};
281   ucc_ee_h cuda_ee{nullptr};
282   ucc_ee_h cuda_ee_p2p[2]{nullptr, nullptr};
283 
284 #ifdef USE_CUDA
285   std::unique_ptr<at::cuda::CUDAStream> stream = nullptr;
286   std::unique_ptr<at::cuda::CUDAStream> stream_p2p[2] = {nullptr, nullptr};
287   event_pool_t ep;
288 #endif
289   c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
290 };
291 
292 class Comm {
293   c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
294   std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
295   CommUCC ucc_comm;
296   std::mutex mutex;
297   std::thread progress_thread;
298   std::condition_variable queue_produce_cv;
299   std::condition_variable queue_consume_cv;
300   std::deque<std::shared_ptr<ProcessGroupUCC::ProgressEntry>> progress_queue;
301   bool stop_progress_loop;
302   bool collective_inprogress;
303   torch_ucc_phase_t finalize_phase;
304 
305  public:
306   c10::DeviceIndex cuda_device_index;
307   Comm(
308       const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
309       std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
310       c10::Device dev,
311       bool is_health_check);
312 
313   ~Comm();
314 
315   void ucc_create_team(
316       ucc_team_h& team,
317       std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
318 
319   void ucc_destroy_team(ucc_team_h& team);
320 
321   c10::intrusive_ptr<Work> enqueue_p2p(
322       OpType opType,
323       ucc_coll_req_h request,
324       const char* prof_title);
325 
326 #ifdef USE_CUDA
327   void enqueue_cuda_collective(
328       std::unique_ptr<ProcessGroupUCC::WorkData> data,
329       c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
330       ucc_coll_args_t& coll,
331       ucc_team_h team,
332       ucc_ee_h ee);
333 #endif
334 
335   void enqueue_collective(
336       std::unique_ptr<ProcessGroupUCC::WorkData> data,
337       c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
338       ucc_coll_args_t& coll,
339       ucc_team_h team);
340 
341   static std::shared_ptr<Comm> get_comm(
342       uint32_t& id,
343       c10::Device dev,
344       std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
345       const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
346       bool is_health_check = false);
347 
348   void progress_loop();
349 };
350 
351 } // namespace c10d
352 
353 #endif // USE_C10D_UCC
354