#pragma once

#ifdef USE_C10D_UCC

#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <condition_variable>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#endif

namespace c10d {

#define TORCH_UCC_DEVICE_NOT_SET -2

#ifdef USE_CUDA
// For CUDA tensors, record each tensor's storage with the caching allocator on
// the UCC stream so it is not freed or reused before the collective completes;
// for CPU tensors, stash the tensors in _DATA to keep them alive.
#define SAVE_TENSORS(_TENSORS, _DATA)                       \
  do {                                                      \
    if ((_TENSORS)[0].device().is_cuda()) {                 \
      for (const auto i : c10::irange((_TENSORS).size())) { \
        c10::cuda::CUDACachingAllocator::recordStream(      \
            (_TENSORS)[i].storage().data_ptr(), (*stream)); \
      }                                                     \
    } else {                                                \
      (_DATA) = (_TENSORS);                                 \
    }                                                       \
  } while (0)

#else
#define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
#endif

constexpr const char* UCC_BACKEND_NAME = "ucc";

// Pool of reusable CUDA events, guarded by a mutex.
struct event_pool_t {
#ifdef USE_CUDA
  std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
#endif
  std::mutex event_pool_mutex;
};

class Comm;

// UCC does not support multiple CUDA devices per process.
class TORCH_API ProcessGroupUCC : public Backend {
 private:
  void set_timeout(ucc_coll_args_t& args);

 public:
  // Tensor references kept alive for the duration of a collective.
  class WorkData {
   public:
    std::vector<at::Tensor> src;
    std::vector<at::Tensor> dst;
    std::vector<at::Tensor> flat;
    WorkData() {}
    virtual ~WorkData() = default;
  };
  class AlltoallWorkData : public WorkData {
   public:
    AlltoallWorkData(int size)
        : send_lengths(size),
          send_offsets(size),
          recv_lengths(size),
          recv_offsets(size) {}
    std::vector<uint64_t> send_lengths;
    std::vector<uint64_t> send_offsets;
    std::vector<uint64_t> recv_lengths;
    std::vector<uint64_t> recv_offsets;
  };

  class AllgathervWorkData : public WorkData {
   public:
    AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {}
    std::vector<uint64_t> recv_lengths;
    std::vector<uint64_t> recv_offsets;
  };

  class ScattervWorkData : public WorkData {
   public:
    ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {}
    std::vector<uint64_t> send_lengths;
    std::vector<uint64_t> send_offsets;
  };
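  // The *WorkData classes above carry per-rank element counts and
  // displacements for the variable-size ("v") collectives. A minimal sketch of
  // how the offsets are typically derived from the lengths (illustrative only;
  // the real conversion from split sizes lives in the implementation/utility
  // code):
  //
  //   uint64_t offset = 0;
  //   for (const auto r : c10::irange(send_lengths.size())) {
  //     send_offsets[r] = offset; // exclusive prefix sum of send_lengths
  //     offset += send_lengths[r];
  //   }
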
  // Bookkeeping for a single in-flight UCC request; entries are queued to and
  // driven by the progress thread.
  class ProgressEntry {
    friend class ProcessGroupUCC;
    friend class Comm;

   public:
    ProgressEntry(CommBase* comm, ucc_coll_req_h request)
        : status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
    // Finalizes UCC status or exception of collective request.
    void finalize(std::exception_ptr eptr = nullptr);
    ucc_status_t status_;
    CommBase* comm_;
    ucc_coll_req_h request_;
    std::unique_ptr<WorkData> data;
    c10::intrusive_ptr<c10::ivalue::Future> future_;
    std::exception_ptr eptr_;
  };

  // Work handle returned to the caller for each posted UCC operation.
  class WorkUCC : public Work {
    friend class ProcessGroupUCC;
    friend class Comm;

   public:
    WorkUCC(
        OpType opType,
        uint64_t seq,
        const char* prof_title,
        const std::optional<std::vector<at::Tensor>>& inputs,
        const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
        : Work(-1, opType, prof_title, inputs), logger_(logger), seq_(seq) {}
    ~WorkUCC();
    void setException();
    void setAndThrowException();
    bool isCompleted() override;
    bool isSuccess() const override;
    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
    std::vector<at::Tensor> result() override;
    int sourceRank() const override;
#ifdef USE_CUDA
    std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
    event_pool_t* ep = nullptr;
#endif
    int sourceRank_;

   protected:
    std::shared_ptr<ProgressEntry> entry_;
    c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;
    uint64_t seq_;

   private:
    // The future returned by getFuture.
    c10::intrusive_ptr<at::ivalue::Future> future_;
    // Store a reference to the collective's outputs, used by result.
    std::shared_ptr<std::vector<at::Tensor>> outputs_;
  };

  explicit ProcessGroupUCC(
      const c10::intrusive_ptr<Store>& store,
      int rank = -1,
      int size = -1,
      std::chrono::duration<float> timeout = kBackendDefaultTimeout);

  void initComm(c10::Device dev);

  ~ProcessGroupUCC() override;

  const std::string getBackendName() const override {
    return std::string(UCC_BACKEND_NAME);
  }

#ifdef USE_CUDA
  std::unique_ptr<at::cuda::CUDAEvent> getPooledEvent();
#endif

  // Performs a health check by initializing dummy UCC & UCX communicators and
  // then destroying them. This helps surface and signal any UCC/UCX-related
  // issues prior to the first collective. The actual initialization and
  // subsequent destruction are run on a separate thread, and the main thread
  // is signalled about timeouts/errors to report to the application.
  void runHealthCheck();
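  // A minimal sketch of the "worker thread + timed wait" pattern the comment
  // above describes (illustrative only; HealthCheckState is a hypothetical
  // struct holding a mutex, condition variable, and done flag, and the real
  // logic lives in the implementation file):
  //
  //   auto state = std::make_shared<HealthCheckState>();
  //   std::thread([state] {
  //     // ... create and destroy throwaway UCC/UCX communicators ...
  //     { std::lock_guard<std::mutex> lk(state->m); state->done = true; }
  //     state->cv.notify_one();
  //   }).detach(); // detached so a hung check cannot block the main thread
  //   std::unique_lock<std::mutex> lk(state->m);
  //   TORCH_CHECK(
  //       state->cv.wait_for(lk, timeout_, [&] { return state->done; }),
  //       "ProcessGroupUCC health check timed out");
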
  // Shared implementation behind the collective operations declared below:
  // runs preproc, posts the UCC collective described by `coll` on the
  // communicator for `dev`, runs postproc, and returns a Work tracking
  // completion.
  template <typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<Work> collective_post(
      OpType opType,
      PreProcess preproc,
      PostProcess postproc,
      ucc_coll_args_t& coll,
      std::unique_ptr<ProcessGroupUCC::WorkData> data,
      c10::Device dev,
      std::vector<at::Tensor>& inputTensors,
      std::vector<at::Tensor>& outputTensors,
      const char* prof_title);

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  // Counter for the sequence number of UCC collective_post calls.
  uint64_t seq_{0};

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to the other ranks using the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it may
  // indicate some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;
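  // A minimal sketch of the rank-0 "create and broadcast via the store" scheme
  // described above (illustrative only; the key name and byte-conversion
  // helpers are hypothetical, and `store` stands for the bootstrap Store
  // passed to the constructor):
  //
  //   const std::string kSeqKey = "ucc_seq_num";   // hypothetical key
  //   if (rank_ == 0) {
  //     store->set(kSeqKey, toByteVector(seq_));   // toByteVector: hypothetical
  //   }
  //   seq_ = fromByteVector(store->get(kSeqKey));  // all ranks read rank 0's value
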
  // Factory that constructs a ProcessGroupUCC backend from the bootstrap
  // store.
  static c10::intrusive_ptr<Backend> createProcessGroupUCC(
      const c10::intrusive_ptr<::c10d::Store>& store,
      int rank,
      int size,
      const std::chrono::duration<float>& timeout);

 protected:
  const std::chrono::duration<float> timeout_;
  std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
  std::shared_ptr<Comm> comm = {nullptr};
  uint32_t comm_id;
  ucc_team_h team{nullptr};
  ucc_ee_h cuda_ee{nullptr};
  ucc_ee_h cuda_ee_p2p[2]{nullptr, nullptr};

#ifdef USE_CUDA
  std::unique_ptr<at::cuda::CUDAStream> stream = nullptr;
  std::unique_ptr<at::cuda::CUDAStream> stream_p2p[2] = {nullptr, nullptr};
  event_pool_t ep;
#endif
  c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};

// Owns the UCC communicator and the progress thread that drives queued
// collectives; instances are obtained (and may be shared) via the static
// get_comm().
class Comm {
  c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
  std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
  CommUCC ucc_comm;
  std::mutex mutex;
  std::thread progress_thread;
  std::condition_variable queue_produce_cv;
  std::condition_variable queue_consume_cv;
  std::deque<std::shared_ptr<ProcessGroupUCC::ProgressEntry>> progress_queue;
  bool stop_progress_loop;
  bool collective_inprogress;
  torch_ucc_phase_t finalize_phase;

 public:
  c10::DeviceIndex cuda_device_index;
  Comm(
      const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
      std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
      c10::Device dev,
      bool is_health_check);

  ~Comm();

  void ucc_create_team(
      ucc_team_h& team,
      std::shared_ptr<torch_ucc_oob_coll_info_t> oob);

  void ucc_destroy_team(ucc_team_h& team);

  c10::intrusive_ptr<Work> enqueue_p2p(
      OpType opType,
      ucc_coll_req_h request,
      const char* prof_title);

#ifdef USE_CUDA
  void enqueue_cuda_collective(
      std::unique_ptr<ProcessGroupUCC::WorkData> data,
      c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
      ucc_coll_args_t& coll,
      ucc_team_h team,
      ucc_ee_h ee);
#endif

  void enqueue_collective(
      std::unique_ptr<ProcessGroupUCC::WorkData> data,
      c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
      ucc_coll_args_t& coll,
      ucc_team_h team);

  static std::shared_ptr<Comm> get_comm(
      uint32_t& id,
      c10::Device dev,
      std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
      const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
      bool is_health_check = false);

  void progress_loop();
};

} // namespace c10d

#endif // USE_C10D_UCC