xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef USE_C10D_UCC
2 
3 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
4 #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
5 #include <torch/csrc/distributed/c10d/UCCTracing.hpp>
6 #include <torch/csrc/distributed/c10d/UCCUtils.hpp>
7 #include <list>
8 #include <memory>
9 #include <unordered_map>
10 #include <unordered_set>
11 
12 namespace c10d {
13 
14 namespace {
15 
16 const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
17     {c10::kCPU, UCC_MEMORY_TYPE_HOST},
18     {c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
19 };
20 
21 ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) {
22   if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end())
23     return ucc_mtype_map.at(_c10_type);
24   else
25     return UCC_MEMORY_TYPE_UNKNOWN;
26 }
27 
28 const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = {
29     {at::kByte, UCC_DT_UINT8},
30     {at::kChar, UCC_DT_INT8},
31     {at::kHalf, UCC_DT_FLOAT16},
32     {at::kBFloat16, UCC_DT_BFLOAT16},
33     {at::kDouble, UCC_DT_FLOAT64},
34     {at::kFloat, UCC_DT_FLOAT32},
35     {at::kInt, UCC_DT_INT32},
36     {at::kLong, UCC_DT_INT64},
37     {at::kBool, UCC_DT_UINT8},
38 };
39 
40 ucc_datatype_t to_ucc_dType(at::Tensor _tensor) {
41   if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) {
42     TORCH_CHECK(
43         false, "Size of Boolean type larger than 1 is not supported in UCC");
44   }
45   try {
46     return ucc_dtype_map.at(_tensor.scalar_type());
47   } catch (const std::out_of_range&) {
48     TORCH_CHECK(false, "Unsupported data type for UCC");
49   }
50 }
51 
52 const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
53     {ReduceOp::SUM, UCC_OP_SUM},
54     {ReduceOp::PRODUCT, UCC_OP_PROD},
55     {ReduceOp::MIN, UCC_OP_MIN},
56     {ReduceOp::MAX, UCC_OP_MAX},
57     {ReduceOp::BAND, UCC_OP_BAND},
58     {ReduceOp::BOR, UCC_OP_BOR},
59     {ReduceOp::BXOR, UCC_OP_BXOR},
60     {ReduceOp::AVG, UCC_OP_AVG},
61 };
62 
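// For boolean tensors (transported as UCC_DT_UINT8), to_ucc_reduceOp() remaps
// SUM to MAX and PRODUCT to MIN so they act as logical OR/AND on the 0/1
// values; ReduceOp::AVG is rejected for boolean inputs.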
63 ucc_reduction_op_t to_ucc_reduceOp(
64     const ReduceOp _op,
65     const at::ScalarType _dt) {
66   if (_dt == at::kBool) {
67     if (_op == ReduceOp::SUM) {
68       // bitwise or
69       return UCC_OP_MAX;
70     } else if (_op == ReduceOp::PRODUCT) {
71       // bitwise and
72       return UCC_OP_MIN;
73     } else if (_op == ReduceOp::AVG) {
74       TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
75     }
76   }
77 
78   try {
79     return ucc_op_map.at(_op);
80   } catch (const std::out_of_range&) {
81     TORCH_CHECK(false, "Unsupported ReduceOp for UCC");
82   }
83 }
84 
85 struct torch_ucc_config_t {
86   c10::once_flag flag;
87   std::array<bool, 32> blocking_wait;
88   bool enable_comms_logger;
89   bool use_future;
90   // Sharing UCC communicator among multiple PGs to save resource.
91   bool shared_comm;
92   // Using allgatherv to achieve allgather, without flattening the list of
93   // (potentially non-contiguous) tensors.
94   bool use_allgatherv;
95   bool enable_health_check;
96 } torch_ucc_config;
97 
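// Default values for the recognized TORCH_UCC_* environment variables;
// read_config() copies any value found in the environment over these defaults.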
98 std::unordered_map<std::string, std::string> torch_ucc_envs_map = {
99     // TORCH_UCC_BLOCKING_WAIT allowed syntax:
100     // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled
101     // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled
102     // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled
103     //                                                   on selected operations
104     // Supported operations:
105     // [allgather,allgather_base,allreduce,alltoall,broadcast,
106     //  gather,reduce,reduce_scatter,scatter,send,recv]
107     {"TORCH_UCC_BLOCKING_WAIT", "none"},
108 
109     {"TORCH_UCC_USE_FUTURE", "1"},
110     {"TORCH_UCC_PROFILING_ENABLE", "0"},
111     {"TORCH_UCC_SHARED_COMM", "1"},
112     {"TORCH_UCC_USE_ALLGATHERV", "0"},
113     {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
114     {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
115 };
116 
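// Translate the TORCH_UCC_BLOCKING_WAIT setting into the set of OpTypes that
// should wait synchronously: "none" yields an empty set, "all" selects every
// supported operation, otherwise each comma-separated name is looked up in
// str2op.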
117 std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
118   const static std::unordered_map<std::string, OpType> str2op = {
119       {"allgather", OpType::ALLGATHER},
120       {"allgather_base", OpType::_ALLGATHER_BASE},
121       {"allreduce", OpType::ALLREDUCE},
122       {"alltoall_base", OpType::ALLTOALL_BASE},
123       {"broadcast", OpType::BROADCAST},
124       {"gather", OpType::GATHER},
125       {"reduce", OpType::REDUCE},
126       {"reduce_scatter", OpType::REDUCE_SCATTER},
127       {"scatter", OpType::SCATTER},
128       {"send", OpType::SEND},
129       {"recv", OpType::RECV},
130   };
131   auto op_list = parse_list(op_list_string);
132   if (op_list == std::vector<std::string>{"none"}) {
133     return {};
134   }
135   std::vector<OpType> result;
136   if (op_list == std::vector<std::string>{"all"}) {
137     for (auto entry : str2op) {
138       result.push_back(entry.second);
139     }
140   } else {
141     for (auto op_string : op_list) {
142       result.push_back(str2op.at(op_string));
143     }
144   }
145   return result;
146 }
147 
148 } // namespace
149 
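// Populate torch_ucc_config once per process; invoked via c10::call_once from
// the ProcessGroupUCC constructor.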
150 void read_config() {
151   // default configuration
152   torch_ucc_config.blocking_wait.fill(false);
153   torch_ucc_config.use_future = true;
154   torch_ucc_config.shared_comm = false;
155   torch_ucc_config.use_allgatherv = false;
156   torch_ucc_config.enable_health_check = false;
157   torch_ucc_config.enable_comms_logger = false;
158 
159   // read all torch_ucc env. variables and update the map
160   char* env;
161   for (auto& torch_ucc_env : torch_ucc_envs_map) {
162     env = std::getenv(torch_ucc_env.first.c_str());
163     if (env) {
164       torch_ucc_envs_map[torch_ucc_env.first] = std::string(env);
165     }
166   }
167 
168   auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
169   for (auto op : parse_blocking_wait(blocking_wait_str)) {
170     torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
171   }
172   // barrier is always blocking
173   torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true;
174 
175   torch_ucc_config.use_future =
176       std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
177   torch_ucc_config.shared_comm =
178       std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
179   torch_ucc_config.use_allgatherv =
180       std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
181   torch_ucc_config.enable_health_check =
182       std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
183   torch_ucc_config.enable_comms_logger =
184       std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
185 }
186 
187 void check_device(c10::Device dev1, c10::Device dev2) {
188   if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
189     throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
190   }
191 }
192 
193 void check_tensor(const std::vector<at::Tensor>& tensors) {
194   if (tensors.size() != 1) {
195     throw std::invalid_argument(
196         "ProcessGroupUCC takes 1 tensor. Got " +
197         std::to_string(tensors.size()) + ". ");
198   }
199   if (!tensors[0].is_contiguous()) {
200     throw std::invalid_argument(
201         "ProcessGroupUCC input tensor has to be contiguous");
202   }
203   if (tensors[0].is_sparse()) {
204     throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense");
205   }
206   // TODO: check cuda case
207 }
208 
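// On destruction, return the CUDA fence event to the owning event pool so it
// can be reused by later collectives.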
209 ProcessGroupUCC::WorkUCC::~WorkUCC() {
210 #ifdef USE_CUDA
211   if (fence && ep) {
212     std::lock_guard<std::mutex> lock(ep->event_pool_mutex);
213     ep->event_pool.push(std::move(fence));
214   }
215 #endif
216 }
217 
218 void ProcessGroupUCC::WorkUCC::setException() {
219   if (exception() || !entry_) {
220     return;
221   }
222   exception_ = entry_->eptr_;
223 }
224 
225 void ProcessGroupUCC::WorkUCC::setAndThrowException() {
226   setException();
227   if (exception()) {
228     std::rethrow_exception(exception());
229   }
230 }
231 
232 bool ProcessGroupUCC::WorkUCC::isCompleted() {
233   if (!entry_) {
234     return true;
235   }
236   setException();
237   // status_ <= 0 to avoid listing all possible status codes.  The main thread
238   // needs to be unblocked when UCC (in progress thread) returns success (== 0)
239   // or any error code (< 0).
240   return exception() || entry_->status_ <= 0;
241 }
242 
243 bool ProcessGroupUCC::WorkUCC::isSuccess() const {
244   if (!entry_) {
245     return true;
246   }
247   return !exception() && entry_->status_ == 0;
248 }
249 
250 bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
251   if (torch_ucc_config.enable_comms_logger && logger_) {
252     logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_);
253   }
254 #ifdef USE_CUDA
255   if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
256     // block user stream
257     setAndThrowException();
258     fence->block(at::cuda::getCurrentCUDAStream());
259     return true;
260   }
261 #endif
262   // wait for complete.  For blocking case, the main thread will be blocked in
263   // this loop until the progress thread changes the status of this request.
264   // If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status.  The
265   // main thread will throw out the exception then. There is no "abort"
266   // function in UCC currently.
267   while (!isCompleted())
268     ;
269   setAndThrowException();
270   // manually call profiling end callbacks if they are set,
271   // since progress thread does not own WorkUCC
272   if (Work::recordFunctionEndCallback_) {
273     Work::recordFunctionEndCallback_();
274     Work::recordFunctionEndCallback_ = nullptr;
275   }
276   return true;
277 }
278 
279 c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
280   return future_;
281 }
282 
283 int ProcessGroupUCC::WorkUCC::sourceRank() const {
284   if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
285     // Throw an error
286     return Work::sourceRank();
287   }
288   return sourceRank_;
289 }
290 
291 std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
292   return *outputs_;
293 }
294 
295 void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) {
296   ucc_status_t status = UCC_OK;
297 
298   if (request_ != nullptr) {
299     status = request_->status;
300     comm_->free_request(request_);
301   }
302   if (eptr) {
303     eptr_ = eptr;
304   } else {
305     status_ = status;
306   }
307   if (future_) {
308     if (eptr) {
309       future_->setError(eptr);
310     } else {
311       future_->markCompleted(
312           c10::IValue(data ? data->dst : std::vector<at::Tensor>()));
313     }
314   }
315 }
316 
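// The Comm constructor spawns the dedicated progress thread (named
// "ucc-progress" where pthread_setname_np is available); the destructor waits
// for the queue to drain before stopping and joining it.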
317 Comm::Comm(
318     const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_,
319     std::shared_ptr<torch_ucc_oob_coll_info_t> oob_,
320     c10::Device dev,
321     bool is_health_check)
322     : logger(logger_),
323       oob(oob_),
324       ucc_comm(oob, logger),
325       finalize_phase(
326           is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
327       cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) {
328   if (dev.is_cuda()) {
329     cuda_device_index = dev.index();
330   }
331   stop_progress_loop = false;
332   collective_inprogress = false;
333   progress_thread = std::thread(&Comm::progress_loop, this);
334 #ifdef _GNU_SOURCE
335   pthread_setname_np(progress_thread.native_handle(), "ucc-progress");
336 #endif
337 }
338 
339 Comm::~Comm() {
340   std::unique_lock<std::mutex> lock(mutex);
341   queue_consume_cv.wait(
342       lock, [&] { return progress_queue.empty() && !collective_inprogress; });
343   stop_progress_loop = true;
344   lock.unlock();
345   queue_produce_cv.notify_all();
346   progress_thread.join();
347 }
348 
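// Negotiate a process-wide communicator id through the store: non-zero ranks
// publish their candidate id, rank 0 takes the maximum and republishes it, and
// all ranks then adopt rank 0's value. With TORCH_UCC_SHARED_COMM=1 a single
// Comm instance is cached in a static weak_ptr and shared across process groups.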
349 std::shared_ptr<Comm> Comm::get_comm(
350     uint32_t& id,
351     c10::Device dev,
352     std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
353     const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
354     bool is_health_check) {
355   static std::mutex m;
356   static std::weak_ptr<Comm> comm;
357   static uint32_t comm_id;
358 
359   std::lock_guard<std::mutex> lock(m);
360   id = comm_id;
361 
362   std::string group_id = "group_id";
363   if (is_health_check) {
364     group_id = c10::str(dev.type()) + "/" + group_id;
365   }
366 
367   std::vector<uint8_t> remote_comm_id;
368   oob->store->deleteKey(group_id + std::to_string(0));
369   if (oob->rank != 0) {
370     std::vector<uint8_t> val = std::vector<uint8_t>(
371         reinterpret_cast<uint8_t*>(&id),
372         reinterpret_cast<uint8_t*>(&id) + sizeof(id));
373     oob->store->set(group_id + std::to_string(oob->rank), val);
374   } else {
375     for (int i = 1; i < oob->size; i++) {
376       remote_comm_id = oob->store->get(group_id + std::to_string(i));
377       oob->store->deleteKey(group_id + std::to_string(i));
378       // Find the highest id.
379       id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data())));
380     }
381     std::vector<uint8_t> val = std::vector<uint8_t>(
382         reinterpret_cast<uint8_t*>(&id),
383         reinterpret_cast<uint8_t*>(&id) + sizeof(id));
384     oob->store->set(group_id + std::to_string(oob->rank), val);
385   }
386   remote_comm_id = oob->store->get(group_id + std::to_string(0));
387   oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data()));
388   // Advance comm_id (static variable) to the next id.
389   comm_id = oob->comm_id + 1;
390 
391   if (torch_ucc_config.shared_comm) {
392     std::shared_ptr<Comm> shared_comm = comm.lock();
393     if (!shared_comm) {
394       shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check);
395       comm = shared_comm;
396     } else {
397       if (dev.is_cuda() && !is_health_check) {
398         if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
399             (shared_comm->cuda_device_index != dev.index())) {
400           TORCH_UCC_LOG_ERROR(
401               is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
402               "ucc communicator was initialized with a different cuda device, "
403               "multi device is not supported");
404           throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
405         }
406         shared_comm->cuda_device_index = dev.index();
407       }
408     }
409     return shared_comm;
410   } else {
411     return std::make_shared<Comm>(logger, oob, dev, is_health_check);
412   }
413 }
414 
415 void Comm::ucc_create_team(
416     ucc_team_h& team,
417     std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
418   ucc_status_t st;
419   ucc_team_params_t team_params;
420   team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE |
421       UCC_TEAM_PARAM_FIELD_OOB;
422   team_params.oob.allgather = oob_allgather;
423   team_params.oob.req_test = oob_allgather_test;
424   team_params.oob.req_free = oob_allgather_free;
425   team_params.oob.coll_info = oob.get();
426   team_params.oob.n_oob_eps = oob->size;
427   team_params.oob.oob_ep = oob->rank;
428   team_params.ep = oob->rank;
429   team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
430   TORCH_UCC_CHECK(
431       ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team),
432       "failed to post team create");
433   do {
434     st = ucc_team_create_test(team);
435     ucc_context_progress(ucc_comm.context);
436   } while (st == UCC_INPROGRESS);
437   TORCH_UCC_CHECK(st, "failed to create UCC team");
438 }
439 
440 void Comm::ucc_destroy_team(ucc_team_h& team) {
441   std::unique_lock<std::mutex> lock(mutex);
442   queue_consume_cv.wait(
443       lock, [&] { return progress_queue.empty() && !collective_inprogress; });
444 
445   ucc_status_t status;
446   while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) {
447   }
448   // Report a failed destroy once the loop has exited.
449   if (UCC_OK != status) {
450     TORCH_UCC_LOG_ERROR(
451         finalize_phase,
452         c10::str("ucc team destroy error: ", ucc_status_string(status)));
453   }
454 
455   lock.unlock();
456 }
457 
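// Initialize and post the collective on the calling thread, then hand the
// resulting request to the progress thread through progress_queue.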
458 void Comm::enqueue_collective(
459     std::unique_ptr<ProcessGroupUCC::WorkData> data,
460     c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
461     ucc_coll_args_t& coll,
462     ucc_team_h team) {
463   ucc_coll_req_h request;
464   TORCH_UCC_CHECK(
465       ucc_collective_init(&coll, &request, team), "failed to init collective");
466   TORCH_UCC_CHECK_REQUEST(
467       request, ucc_collective_post(request), "failed to post collective");
468 
469   auto entry =
470       std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
471   entry->data = std::move(data);
472   entry->future_ = work->getFuture();
473   work->entry_ = entry;
474   std::unique_lock<std::mutex> lock(mutex);
475   progress_queue.push_back(entry);
476   lock.unlock();
477   queue_produce_cv.notify_one();
478 }
479 
480 #ifdef USE_CUDA
481 void Comm::enqueue_cuda_collective(
482     std::unique_ptr<ProcessGroupUCC::WorkData> data,
483     c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
484     ucc_coll_args_t& coll,
485     ucc_team_h team,
486     ucc_ee_h ee) {
487   ucc_coll_req_h request;
488   TORCH_UCC_CHECK(
489       ucc_collective_init(&coll, &request, team),
490       "failed to init cuda collective");
491   ucc_ev_t comp_ev, *post_ev;
492   comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
493   comp_ev.ev_context = nullptr;
494   comp_ev.ev_context_size = 0;
495   comp_ev.req = request;
496   TORCH_UCC_CHECK_REQUEST(
497       request,
498       ucc_collective_triggered_post(ee, &comp_ev),
499       "failed to post triggered collective");
500   ucc_status_t st = ucc_ee_get_event(ee, &post_ev);
501   TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
502   ucc_ee_ack_event(ee, post_ev);
503   auto entry =
504       std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
505   entry->data = std::move(data);
506   work->entry_ = entry;
507   std::unique_lock<std::mutex> lock(mutex);
508   progress_queue.push_back(entry);
509   lock.unlock();
510   queue_produce_cv.notify_one();
511 }
512 #endif
513 
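// Progress-thread main loop: pop entries from progress_queue, bind the CUDA
// device (and retain the primary context) on first use, drive ucc progress
// until each request completes, and convert any failure into an exception_ptr
// consumed by ProgressEntry::finalize().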
514 void Comm::progress_loop() {
515   std::unique_lock<std::mutex> lock(mutex);
516 #ifdef USE_CUDA
517   bool device_set = false;
518 #endif
519   while (!stop_progress_loop) {
520     if (progress_queue.empty()) {
521       queue_produce_cv.wait(lock);
522       continue;
523     }
524     collective_inprogress = true;
525     auto work = progress_queue.front();
526     progress_queue.pop_front();
527     lock.unlock();
528 #ifdef USE_CUDA
529     if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
530       c10::cuda::set_device(cuda_device_index);
531       CUcontext pctx = nullptr;
532       at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
533       if (C10_UNLIKELY(!pctx)) {
534         at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(
535             &pctx, cuda_device_index);
536         at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
537       }
538       device_set = true;
539     }
540 #endif
541     std::exception_ptr eptr;
542     try {
543       while (work->request_->status > 0) {
544         ucc_comm.progress();
545       }
546       if (work->request_->status < 0) {
547         eptr = std::make_exception_ptr(
548             std::runtime_error(ucc_status_string(work->request_->status)));
549         std::string err_log = c10::str(
550             "Failed to progress communication: ", // TODO: report exact op
551                                                   // type or id?
552             ucc_status_string(work->request_->status));
553         TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log);
554       }
555     } catch (...) {
556       eptr = std::current_exception();
557     }
558     work->finalize(eptr);
559     work = nullptr;
560     collective_inprogress = false;
561     queue_consume_cv.notify_one();
562     lock.lock();
563   }
564 }
565 
566 ProcessGroupUCC::ProcessGroupUCC(
567     const c10::intrusive_ptr<Store>& store,
568     int rank,
569     int size,
570     std::chrono::duration<float> timeout)
571     : Backend(rank, size), timeout_(timeout) {
572   c10::call_once(torch_ucc_config.flag, read_config);
573   oob = std::make_shared<torch_ucc_oob_coll_info_t>();
574   oob->rank = rank;
575   oob->size = size;
576   oob->store = store;
577   comm = nullptr;
578   cuda_ee = nullptr;
579   static uint32_t id = 0;
580   uint32_t pg_id = id++;
581 
582   logger = c10::make_intrusive<ProcessGroupUCCLogger>(
583       c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
584       TORCH_UCC_INIT);
585   TORCH_UCC_LOG_INFO(
586       TORCH_UCC_INIT,
587       c10::str(
588           "Created ProcessGroupUCC with ",
589           size,
590           " ranks, with timeout ",
591           timeout_.count(),
592           " secs"));
593   std::string envs = "";
594   for (auto& torch_ucc_env : torch_ucc_envs_map) {
595     envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second);
596   }
597   TORCH_UCC_LOG_INFO(
598       TORCH_UCC_INIT,
599       c10::str(
600           "Successfully read and set ProcessGroupUCC env. variables as follows",
601           envs));
602 
603   if (torch_ucc_config.enable_health_check) {
604     // Perform health check by initializing dummy communicators and destroying
605     // them. This will help indicate any UCC/UCX-related issues prior to the
606     // first collective. Run it in a separate thread and wait on CV to handle
607     // timeouts so that if there are hangs, the main thread can still run
608     // correctly.
609     runHealthCheck();
610   }
611   if (torch_ucc_config.enable_comms_logger) {
612     logger->initCommsTracer();
613   }
614 }
615 
616 ProcessGroupUCC::~ProcessGroupUCC() {
617   if (torch_ucc_config.enable_comms_logger) {
618     logger->flushComms(this->getRank(), this->getSize());
619   }
620   if (comm) {
621     logger->setPhase(TORCH_UCC_FINALIZE);
622     comm->ucc_destroy_team(team);
623     TORCH_UCC_LOG_INFO(
624         TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
625     try {
626       if (cuda_ee) {
627         ucc_ee_destroy(cuda_ee);
628         ucc_ee_destroy(cuda_ee_p2p[0]);
629         ucc_ee_destroy(cuda_ee_p2p[1]);
630       }
631     } catch (std::exception& ex) {
632       TORCH_UCC_LOG_INFO(
633           TORCH_UCC_FINALIZE,
634           c10::str(
635               "(~ProcessGroupUCC) Caught error in Store Operation .. ",
636               "[",
637               ex.what(),
638               "]"));
639     }
640     comm = nullptr;
641   }
642 }
643 
644 #ifdef USE_CUDA
645 // Return CUDA device with ordinal given by input rank.
646 c10::Device getCUDADeviceForRank(int rank) {
647   TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
648   auto numGPUs = at::cuda::getNumGPUs();
649   auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs);
650   return c10::Device(c10::DeviceType::CUDA, deviceIdx);
651 }
652 #endif
653 
654 void ProcessGroupUCC::runHealthCheck() {
655   // Run health check in a separate thread and wait on CV to handle timeouts.
656   // This design allows us to handle hangs.
657 
658   // When size_ is 1, there is no need to do any communication at all.
659   if (size_ == 1)
660     return;
661 
662   struct HealthCheckData {
663     std::mutex healthCheckMutex;
664     std::condition_variable healthCheckCv;
665     bool uccHealthCheckSuccess = false;
666     std::exception_ptr healthCheckException;
667   } healthCheckData;
668 
669   auto t = std::thread([&healthCheckData, this]() {
670     std::list<c10::Device> devices{c10::kCPU};
671 #ifdef USE_CUDA
672     c10::cuda::OptionalCUDAGuard gpuGuard;
673     if (at::cuda::is_available()) {
674       devices.emplace_front(getCUDADeviceForRank(rank_));
675     }
676 #endif
677     for (auto device : devices) {
678       bool is_last_device = (device == devices.back());
679       try {
680         auto oob = std::make_shared<torch_ucc_oob_coll_info_t>();
681         oob->rank = this->oob->rank;
682         oob->size = this->oob->size;
683         oob->store = this->oob->store;
684         ucc_team_h team = nullptr;
685         uint32_t comm_id;
686 #ifdef USE_CUDA
687         if (device.is_cuda()) {
688           gpuGuard.set_index(device.index());
689         }
690 #endif
691         auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
692         comm->ucc_create_team(team, oob);
693         comm->ucc_destroy_team(team);
694         TORCH_UCC_LOG_INFO(
695             TORCH_UCC_HEALTH_CHECK,
696             c10::str(
697                 "UCC library health check succeeded for device ",
698                 c10::DeviceTypeName(device.type())));
699         // Mark ucc health check as complete.
700         if (is_last_device) {
701           std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
702           healthCheckData.uccHealthCheckSuccess = true;
703         }
704 
705         comm = nullptr;
706         oob = nullptr;
707         // Notify main thread the health check is complete.
708         if (is_last_device) {
709           healthCheckData.healthCheckCv.notify_one();
710         }
711       } catch (const std::exception&) {
712         // Populate exception ptr.
713         healthCheckData.healthCheckException = std::current_exception();
714         // Unblock waiting main thread which will report exception.
715         healthCheckData.healthCheckCv.notify_one();
716       } // Unknown exceptions will just cause the program to terminate.
717     }
718   });
719   // We don't need to join the thread, just need to verify health check via the
720   // CV. Hence we detach the thread here.
721   t.detach(); // NOLINT
722   TORCH_UCC_LOG_INFO(
723       TORCH_UCC_HEALTH_CHECK,
724       c10::str(
725           "will wait up to ",
726           timeout_.count(),
727           " sec for UCC health check to complete."));
728   std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
729   healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
730     return healthCheckData.uccHealthCheckSuccess;
731   });
732 
733   if (healthCheckData.healthCheckException) {
734     std::rethrow_exception(healthCheckData.healthCheckException);
735   }
736   // If there is no exception, the likely culprit is a timeout/hang
737   TORCH_CHECK(
738       healthCheckData.uccHealthCheckSuccess,
739       "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
740       rank_);
741 }
742 
743 void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) {
744   args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
745   args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT;
746   args.timeout = timeout_.count();
747 }
748 
749 #ifdef USE_CUDA
750 std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() {
751   std::unique_ptr<at::cuda::CUDAEvent> ev;
752   std::lock_guard<std::mutex> lock(ep.event_pool_mutex);
753   if (ep.event_pool.empty()) {
754     ev = std::make_unique<at::cuda::CUDAEvent>();
755   } else {
756     ev = std::move(ep.event_pool.front());
757     ep.event_pool.pop();
758   }
759   return ev;
760 }
761 #endif
762 
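// Common post path for all collectives: bumps the sequence number, applies the
// timeout, records tracing info, and dispatches either to the CPU progress
// queue or, for CUDA tensors, to the event-fenced triggered-post path on an
// internal stream.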
763 template <typename PreProcess, typename PostProcess>
764 c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
765     OpType opType,
766     PreProcess preproc,
767     PostProcess postproc,
768     ucc_coll_args_t& coll,
769     std::unique_ptr<ProcessGroupUCC::WorkData> data,
770     c10::Device dev,
771     std::vector<at::Tensor>& inputTensors,
772     std::vector<at::Tensor>& outputTensors,
773     const char* prof_title) {
774   seq_++;
775   set_timeout(coll);
776   auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
777       opType, seq_, prof_title, inputTensors, logger);
778 
779   if (opType == OpType::RECV) {
780     work->sourceRank_ = coll.root;
781   }
782 
783   RECORD_COMMS_TRACE(
784       logger->trace_generator,
785       work,
786       opType,
787       this->getRank(),
788       this->getSize(),
789       inputTensors,
790       outputTensors);
791 
792   // Store references to outputs to be used by result
793   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);
794   switch (dev.type()) {
795     case c10::DeviceType::CPU: {
796       if (torch_ucc_config.use_future) {
797         work->future_ = c10::make_intrusive<at::ivalue::Future>(
798             c10::ListType::create(c10::TensorType::get()));
799       }
800       preproc();
801       comm->enqueue_collective(std::move(data), work, coll, team);
802       postproc();
803       return work;
804     }
805 #ifdef USE_CUDA
806     case c10::DeviceType::CUDA: {
807       auto cuda_ev = getPooledEvent();
808       at::cuda::CUDAStream* op_stream;
809       ucc_ee_h* op_ee;
810       if (opType == OpType::SEND) {
811         op_stream = stream_p2p[0].get();
812         op_ee = &cuda_ee_p2p[0];
813       } else if (opType == OpType::RECV) {
814         op_stream = stream_p2p[1].get();
815         op_ee = &cuda_ee_p2p[1];
816       } else {
817         op_stream = stream.get();
818         op_ee = &cuda_ee;
819       }
820 
821       cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
822       cuda_ev->block(*op_stream);
823       at::cuda::CUDAStreamGuard guard(*op_stream);
824       preproc();
825       comm->enqueue_cuda_collective(std::move(data), work, coll, team, *op_ee);
826       postproc();
827       cuda_ev->record(*op_stream);
828       work->fence = std::move(cuda_ev);
829       work->ep = &ep;
830       if (torch_ucc_config.use_future) {
831         c10::cuda::CUDAMultiStreamGuard streamGuard(*op_stream);
832         std::vector<c10::Device> devList{dev};
833         work->future_ = c10::make_intrusive<at::ivalue::Future>(
834             c10::ListType::create(c10::TensorType::get()), devList);
835         // Add a callback that runs profiling end callbacks
836         if (work->recordFunctionEndCallback_) {
837           work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
838             work->recordFunctionEndCallback_();
839           });
840         }
841 
842         work->future_->markCompleted(c10::IValue(outputTensors));
843       }
844       return work;
845     }
846 #endif // #ifdef USE_CUDA
847     default: {
848       TORCH_UCC_LOG_ERROR(
849           TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
850       throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
851     }
852   }
853 }
854 
855 c10::intrusive_ptr<Work> ProcessGroupUCC::allgather(
856     std::vector<std::vector<at::Tensor>>& outputTensors,
857     std::vector<at::Tensor>& inputTensors,
858     const AllgatherOptions& /* unused */) {
859   auto& tensor = inputTensors[0];
860   check_device(tensor.device(), outputTensors[0][0].device());
861   initComm(tensor.device());
862 
863   if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
864     AllgathervWorkData* data = new AllgathervWorkData(size_);
865     for (int i = 0; i < size_; i++) {
866       data->recv_lengths[i] = tensor.element_size() * tensor.numel();
867       data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
868     }
869     ucc_coll_args_t coll;
870     coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
871     coll.flags =
872         UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
873     coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
874     coll.src.info.buffer = tensor.data_ptr();
875     coll.src.info.count = tensor.element_size() * tensor.numel();
876     coll.src.info.datatype = UCC_DT_UINT8;
877     coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
878     coll.dst.info_v.buffer = nullptr;
879     coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
880     coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
881     coll.dst.info_v.datatype = UCC_DT_UINT8;
882     coll.dst.info_v.mem_type =
883         to_ucc_memType(outputTensors[0][0].device().type());
884     SAVE_TENSORS(inputTensors, data->src);
885     SAVE_TENSORS(outputTensors[0], data->dst);
886 
887     return collective_post(
888         OpType::ALLGATHER,
889         []() {},
890         []() {},
891         coll,
892         std::unique_ptr<WorkData>(data),
893         tensor.device(),
894         inputTensors,
895         outputTensors[0],
896         "ucc:all_gather");
897   } else {
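    // Fallback path (taken for CUDA tensors unless TORCH_UCC_USE_ALLGATHERV=1):
    // flatten the outputs into one contiguous buffer per input, run a plain
    // allgather, then scatter the results back in copy_from_flat below.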
898     WorkData* data = new WorkData();
899     std::vector<at::Tensor> flat_output(outputTensors.size());
900     for (size_t i = 0; i < outputTensors.size(); i++) {
901       TORCH_CHECK(
902           outputTensors[i].size() == outputTensors.size() * size_,
903           "Tensor output list is not valid for the number of participants");
904       flat_output[i] = c10d::newLikeFlat(outputTensors, i);
905     }
906     SAVE_TENSORS(flat_output, data->flat);
907     ucc_coll_args_t coll;
908     coll.mask = 0;
909     coll.flags = 0;
910     coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
911     coll.src.info.buffer = tensor.data_ptr();
912     coll.src.info.count = tensor.numel();
913     coll.src.info.datatype = to_ucc_dType(tensor);
914     coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
915     coll.dst.info.buffer = flat_output[0].data_ptr();
916     coll.dst.info.count = flat_output[0].numel();
917     coll.dst.info.datatype = to_ucc_dType(flat_output[0]);
918     coll.dst.info.mem_type =
919         to_ucc_memType(outputTensors[0][0].device().type());
920 
921     auto copy_from_flat = [&] {
922       bool asyncCopy = false;
923 #ifdef USE_CUDA
924       bool isCuda = outputTensors[0][0].device().is_cuda();
925
926 #endif
927       for (size_t i = 0; i < outputTensors.size(); i++) {
928         auto inumel = inputTensors[i].numel();
929         for (size_t j = 0; j < outputTensors[i].size(); j++) {
930           TORCH_CHECK(
931               (outputTensors[i][j].numel() == inumel),
932               "Tensor operand counts must be same");
933 #ifdef USE_CUDA
934           if (isCuda) {
935             c10::cuda::CUDACachingAllocator::recordStream(
936                 outputTensors[i][j].storage().data_ptr(), (*stream));
937             asyncCopy = true;
938           }
939 #endif
940           outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
941         }
942       }
943     };
944     return collective_post(
945         OpType::ALLGATHER,
946         []() {},
947         copy_from_flat,
948         coll,
949         std::unique_ptr<WorkData>(data),
950         tensor.device(),
951         inputTensors,
952         outputTensors[0],
953         "ucc:all_gather");
954   }
955 }
956 
957 c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base(
958     at::Tensor& outputTensor,
959     at::Tensor& inputTensor,
960     const AllgatherOptions& opts) {
961   check_tensor({outputTensor});
962   check_tensor({inputTensor});
963   initComm(outputTensor.device());
964 
965   WorkData* data = new WorkData();
966 
967   ucc_coll_args_t coll;
968   coll.mask = 0;
969   coll.flags = 0;
970   coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
971   coll.src.info.buffer = inputTensor.data_ptr();
972   coll.src.info.count = inputTensor.numel();
973   coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
974   coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
975   coll.dst.info.buffer = outputTensor.data_ptr();
976   coll.dst.info.count = outputTensor.numel();
977   coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
978   coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
979 
980   std::vector<at::Tensor> inputTensors = {inputTensor};
981   std::vector<at::Tensor> outputTensors = {outputTensor};
982   SAVE_TENSORS(inputTensors, data->src);
983   SAVE_TENSORS(outputTensors, data->dst);
984 
985   return collective_post(
986       OpType::_ALLGATHER_BASE,
987       []() {},
988       []() {},
989       coll,
990       std::unique_ptr<WorkData>(data),
991       outputTensor.device(),
992       inputTensors,
993       outputTensors,
994       "ucc:allgather_base");
995 }
996 
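// The allreduce is posted in place: UCC_COLL_ARGS_FLAG_IN_PLACE is set and only
// dst describes the tensor, while src.info.buffer stays nullptr.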
997 c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
998     std::vector<at::Tensor>& tensors,
999     const AllreduceOptions& opts) {
1000   check_tensor(tensors);
1001   auto& tensor = tensors[0];
1002   initComm(tensor.device());
1003   WorkData* data = new WorkData();
1004 
1005   ucc_coll_args_t coll;
1006   coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1007   coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
1008   coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
1009   coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type());
1010   coll.src.info.buffer = nullptr;
1011   coll.src.info.count = tensor.numel();
1012   coll.src.info.datatype = to_ucc_dType(tensor);
1013   coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1014   coll.dst.info.buffer = tensor.data_ptr();
1015   coll.dst.info.count = tensor.numel();
1016   coll.dst.info.datatype = to_ucc_dType(tensor);
1017   coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1018   SAVE_TENSORS(tensors, data->dst);
1019   return collective_post(
1020       OpType::ALLREDUCE,
1021       []() {},
1022       []() {},
1023       coll,
1024       std::unique_ptr<WorkData>(data),
1025       tensor.device(),
1026       tensors,
1027       tensors,
1028       "ucc:all_reduce");
1029 }
1030 
1031 c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
1032     std::vector<at::Tensor>& /* unused */,
1033     const AllreduceCoalescedOptions& /* unused */) {
1034   throw std::invalid_argument(
1035       "ProcessGroupUCC does not support allreduce_coalesced");
1036 }
1037 
1038 c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall(
1039     std::vector<at::Tensor>& outputTensors,
1040     std::vector<at::Tensor>& inputTensors,
1041     const AllToAllOptions& /* unused */) {
1042   auto device = outputTensors[0].device();
1043   for (const auto r : c10::irange(outputTensors.size())) {
1044     TORCH_CHECK(
1045         device == outputTensors[r].device() &&
1046             device == inputTensors[r].device(),
1047         "Tensors must be on the same device")
1048   }
1049 
1050   initComm(device);
1051   ucc_coll_args_t coll;
1052   AlltoallWorkData* data;
1053   data = new AlltoallWorkData(size_);
1054 
1055   /* To avoid flattening the tensors, we use alltoallv to implement Alltoall as
1056      follows:
1057       1. store the address of each tensor directly in displacements, keeping the
1058      buffer at nullptr, i.e., 0
1059       2. convert the datatype to UINT8, which is always 1 byte, to avoid wrong
1060      size calculation in the UCC layer
1061       3. post the Alltoallv
1062   */
1063   for (const auto i : c10::irange(size_)) {
1064     data->send_lengths[i] =
1065         (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel());
1066     data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr();
1067     data->recv_lengths[i] =
1068         (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel());
1069     data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr();
1070   }
1071 
1072   coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1073   coll.flags =
1074       UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1075   coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
1076   coll.src.info_v.buffer = 0;
1077   coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1078   coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1079   coll.src.info_v.datatype = UCC_DT_UINT8;
1080   coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type());
1081   coll.dst.info_v.buffer = 0;
1082   coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1083   coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1084   coll.dst.info_v.datatype = UCC_DT_UINT8;
1085   coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type());
1086 
1087   SAVE_TENSORS(inputTensors, data->src);
1088   SAVE_TENSORS(outputTensors, data->dst);
1089 
1090   return collective_post(
1091       OpType::ALLTOALL,
1092       []() {},
1093       []() {},
1094       coll,
1095       std::unique_ptr<WorkData>(data),
1096       device,
1097       inputTensors,
1098       outputTensors,
1099       "ucc:alltoall");
1100 }
1101 
1102 c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base(
1103     at::Tensor& outputTensor,
1104     at::Tensor& inputTensor,
1105     std::vector<int64_t>& outputSplitSizes,
1106     std::vector<int64_t>& inputSplitSizes,
1107     const AllToAllOptions& /* unused */) {
1108   check_device(inputTensor.device(), outputTensor.device());
1109   initComm(inputTensor.device());
1110   ucc_coll_args_t coll;
1111   AlltoallWorkData* data;
1112 
1113   if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) {
1114     data = new AlltoallWorkData(0);
1115     TORCH_CHECK(
1116         (outputTensor.size(0) % size_ == 0) &&
1117             (inputTensor.size(0) % size_ == 0),
1118         "Tensor's dim 0 does not divide equally across group size");
1119     coll.mask = 0;
1120     coll.flags = 0;
1121     coll.coll_type = UCC_COLL_TYPE_ALLTOALL;
1122     coll.src.info.buffer = inputTensor.data_ptr();
1123     coll.src.info.count = inputTensor.element_size() * inputTensor.numel();
1124     coll.src.info.datatype = UCC_DT_UINT8;
1125     coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
1126     coll.dst.info.buffer = outputTensor.data_ptr();
1127     coll.dst.info.count = outputTensor.element_size() * outputTensor.numel();
1128     coll.dst.info.datatype = UCC_DT_UINT8;
1129     coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
1130     coll.flags = 0;
1131   } else {
1132     data = new AlltoallWorkData(size_);
1133     c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
1134     c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
1135     computeLengthsAndOffsets(
1136         outputSplitSizes,
1137         outputTensor,
1138         &data->recv_lengths,
1139         &data->recv_offsets);
1140     computeLengthsAndOffsets(
1141         inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets);
1142     coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1143     coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
1144     coll.src.info_v.buffer = inputTensor.data_ptr();
1145     coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1146     coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1147     coll.src.info_v.datatype = to_ucc_dType(inputTensor);
1148     coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type());
1149     coll.dst.info_v.buffer = outputTensor.data_ptr();
1150     coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1151     coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1152     coll.dst.info_v.datatype = to_ucc_dType(outputTensor);
1153     coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type());
1154     coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
1155         UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT |
1156         UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1157 
1158     if (torch_ucc_config.enable_comms_logger) {
1159       logger->trace_generator->recordOptionalInfo(
1160           outputSplitSizes, inputSplitSizes);
1161     }
1162   }
1163   std::vector<at::Tensor> inputTensors = {inputTensor};
1164   std::vector<at::Tensor> outputTensors = {outputTensor};
1165   SAVE_TENSORS(inputTensors, data->src);
1166   SAVE_TENSORS(outputTensors, data->dst);
1167 
1168   return collective_post(
1169       OpType::ALLTOALL_BASE,
1170       []() {},
1171       []() {},
1172       coll,
1173       std::unique_ptr<WorkData>(data),
1174       inputTensor.device(),
1175       inputTensors,
1176       outputTensors,
1177       "ucc:alltoall");
1178 }
1179 
1180 c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) {
1181   c10::Device device = c10::Device(c10::DeviceType::CPU);
1182 #ifdef USE_CUDA
1183   auto numGPUs = c10::cuda::device_count();
1184   if (!opts.device_ids.empty()) {
1185     device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front());
1186   } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) {
1187     device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index);
1188   } else if (numGPUs > 0) {
1189     int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device());
1190     // if current device is 0, likely the device is not set, use the best guess
1191     if (0 == (int)deviceIdx) {
1192       deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs);
1193     }
1194     TORCH_UCC_LOG_INFO(
1195         TORCH_UCC_COLL_POST,
1196         c10::str(
1197             "post barrier before specifying any GPU while there are ",
1198             numGPUs,
1199             " GPUs available. ",
1200             "Not clear if GPU barrier is required, using GPU ",
1201             (int)deviceIdx,
1202             " to perform barrier. ",
1203             "Specify device_ids option in barrier() to force ",
1204             "use of a particular device"));
1205     device = c10::Device(c10::DeviceType::CUDA, deviceIdx);
1206   }
1207 #endif
1208   initComm(device);
1209 
1210   ucc_coll_args_t coll;
1211   coll.mask = 0;
1212   coll.flags = 0;
1213   coll.coll_type = UCC_COLL_TYPE_BARRIER;
1214   auto dummy_tensor = std::vector<at::Tensor>();
1215   return collective_post(
1216       OpType::BARRIER,
1217       []() {},
1218       []() {},
1219       coll,
1220       nullptr,
1221       device,
1222       dummy_tensor,
1223       dummy_tensor,
1224       "ucc:barrier");
1225 }
1226 
1227 c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast(
1228     std::vector<at::Tensor>& tensors,
1229     const BroadcastOptions& opts) {
1230   check_tensor(tensors);
1231   auto& tensor = tensors[0];
1232   initComm(tensor.device());
1233   WorkData* data = new WorkData();
1234 
1235   ucc_coll_args_t coll;
1236   coll.mask = 0;
1237   coll.flags = 0;
1238   coll.coll_type = UCC_COLL_TYPE_BCAST;
1239   coll.src.info.buffer = tensor.data_ptr();
1240   coll.src.info.count = tensor.numel();
1241   coll.src.info.datatype = to_ucc_dType(tensor);
1242   coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1243   coll.root = opts.rootRank;
1244   SAVE_TENSORS(tensors, data->dst);
1245 
1246   if (torch_ucc_config.enable_comms_logger) {
1247     logger->trace_generator->recordOptionalInfo(opts.rootRank);
1248   }
1249 
1250   return collective_post(
1251       OpType::BROADCAST,
1252       []() {},
1253       []() {},
1254       coll,
1255       std::unique_ptr<WorkData>(data),
1256       tensor.device(),
1257       tensors,
1258       tensors,
1259       "ucc:broadcast");
1260 }
1261 
1262 c10::intrusive_ptr<Work> ProcessGroupUCC::gather(
1263     std::vector<std::vector<at::Tensor>>& outputTensors,
1264     std::vector<at::Tensor>& inputTensors,
1265     const GatherOptions& opts) {
1266   std::vector<at::Tensor> outputs;
1267   auto& input = inputTensors[0];
1268   initComm(input.device());
1269 
1270   AllgathervWorkData* data = new AllgathervWorkData(size_);
1271   ucc_coll_args_t coll;
1272   coll.root = opts.rootRank;
1273   coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1274   coll.flags =
1275       UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1276   coll.coll_type = UCC_COLL_TYPE_GATHERV;
1277 
1278   /* for non-root ranks, only src is valid */
1279   coll.src.info.buffer = input.data_ptr();
1280   coll.src.info.count = (uint64_t)(input.element_size() * input.numel());
1281   coll.src.info.datatype = UCC_DT_UINT8;
1282   coll.src.info.mem_type = to_ucc_memType(input.device().type());
1283 
1284   if (getRank() == opts.rootRank) {
1285     if (outputTensors.size() != 1) {
1286       TORCH_UCC_LOG_ERROR(
1287           TORCH_UCC_COLL_POST,
1288           c10::str(
1289               "gather requires a single-element output list containing a list with ",
1290               getSize(),
1291               " tensors."));
1292     } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
1293       TORCH_UCC_LOG_ERROR(
1294           TORCH_UCC_COLL_POST,
1295           c10::str(
1296               "Incorrect output list size ",
1297               outputTensors[0].size(),
1298               ". Output list size should be ",
1299               getSize(),
1300               ", same as size of the process group."));
1301     }
1302     outputs = outputTensors[0];
1303 
1304     for (int i = 0; i < size_; i++) {
1305       data->recv_lengths[i] =
1306           (uint64_t)(outputs[i].element_size() * outputs[i].numel());
1307       data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr();
1308     }
1309     /* use gatherv and store non-contiguous addresses in displacements to avoid
1310      * flattening outputTensors */
1311     coll.dst.info_v.buffer = nullptr;
1312     coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1313     coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1314     coll.dst.info_v.datatype = UCC_DT_UINT8;
1315     coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type());
1316 
1317     SAVE_TENSORS(outputs, data->dst);
1318   } else {
1319     // for non-root ranks, outputTensors should be an empty list
1320     if (outputTensors.size() != 0) {
1321       TORCH_UCC_LOG_ERROR(
1322           TORCH_UCC_COLL_POST, "requires empty output on non-root");
1323     }
1324     outputs = {};
1325     // append an empty tensor to the list to be used by the future mark
1326     outputs.emplace_back();
1327   }
1328 
1329   SAVE_TENSORS(inputTensors, data->src);
1330 
1331   return collective_post(
1332       OpType::GATHER,
1333       []() {},
1334       []() {},
1335       coll,
1336       std::unique_ptr<WorkData>(data),
1337       input.device(),
1338       inputTensors,
1339       outputs,
1340       "ucc:gather");
1341 }
1342 
1343 c10::intrusive_ptr<Work> ProcessGroupUCC::reduce(
1344     std::vector<at::Tensor>& tensors,
1345     const ReduceOptions& opts) {
1346   check_tensor(tensors);
1347   auto& tensor = tensors[0];
1348   initComm(tensor.device());
1349   WorkData* data = new WorkData();
1350 
1351   ucc_coll_args_t coll;
1352   coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1353   coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
1354   coll.coll_type = UCC_COLL_TYPE_REDUCE;
1355   coll.op = ucc_op_map.at(opts.reduceOp);
1356   coll.root = opts.rootRank;
1357   coll.src.info.buffer = tensor.data_ptr();
1358   coll.src.info.count = tensor.numel();
1359   coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
1360   coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1361   coll.dst.info.buffer = tensor.data_ptr();
1362   coll.dst.info.count = tensor.numel();
1363   coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
1364   coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1365   SAVE_TENSORS(tensors, data->dst);
1366   return collective_post(
1367       OpType::REDUCE,
1368       []() {},
1369       []() {},
1370       coll,
1371       std::unique_ptr<WorkData>(data),
1372       tensor.device(),
1373       tensors,
1374       tensors,
1375       "ucc:reduce");
1376 }
1377 
1378 c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter(
1379     std::vector<at::Tensor>& outputTensors,
1380     std::vector<std::vector<at::Tensor>>& inputTensors,
1381     const ReduceScatterOptions& opts) {
1382   TORCH_CHECK(
1383       (outputTensors.size() == inputTensors.size()),
1384       "Tensor input/output list for reduce_scatter must have same size");
1385   check_tensor(outputTensors);
1386   check_device(inputTensors[0][0].device(), outputTensors[0].device());
1387   initComm(inputTensors[0][0].device());
1388   auto data = std::make_unique<WorkData>();
1389   std::vector<at::Tensor> flat_input(inputTensors.size());
1390   for (size_t i = 0; i < inputTensors.size(); i++) {
1391     TORCH_CHECK(
1392         inputTensors[i].size() == inputTensors.size() * size_,
1393         "Tensor input list is not valid for the number of participants");
1394     flat_input[i] = c10d::newLikeFlat(inputTensors, i);
1395   }
1396   SAVE_TENSORS(flat_input, data->flat);
1397   check_tensor(flat_input);
1398   ucc_coll_args_t coll;
1399   coll.mask = 0;
1400   coll.flags = 0;
1401   coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
1402   coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type());
1403 
1404   coll.src.info.buffer = flat_input[0].data_ptr();
1405   coll.src.info.count = flat_input[0].numel();
1406   coll.src.info.datatype = to_ucc_dType(flat_input[0]);
1407   coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type());
1408   coll.dst.info.buffer = outputTensors[0].data_ptr();
1409   coll.dst.info.count = outputTensors[0].numel();
1410   coll.dst.info.datatype = to_ucc_dType(outputTensors[0]);
1411   coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());
1412 
1413   SAVE_TENSORS(inputTensors[0], data->src);
1414   SAVE_TENSORS(outputTensors, data->dst);
1415 
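  // copy_to_flat runs as the pre-processing step of collective_post: it packs
  // the (possibly non-contiguous) input tensors into the flat buffers before
  // the reduce-scatter is posted, using async copies on CUDA.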
1416   auto copy_to_flat = [&] {
1417     bool asyncCopy = false;
1418     auto isize = inputTensors.size();
1419 #ifdef USE_CUDA
1420     bool isCuda = inputTensors[0][0].device().is_cuda();
1421 #endif
1422     for (size_t i = 0; i < isize; i++) {
1423       auto onumel = outputTensors[i].numel();
1424       for (size_t j = 0; j < inputTensors[i].size(); j++) {
1425         TORCH_CHECK(
1426             (inputTensors[i][j].numel() == onumel),
1427             "Tensor operand counts must be the same");
1428 #ifdef USE_CUDA
1429         if (isCuda) {
1430           c10::cuda::CUDACachingAllocator::recordStream(
1431               inputTensors[i][j].storage().data_ptr(), (*stream));
1432           asyncCopy = true;
1433         }
1434 #endif
1435         flat_input[i][j].copy_(inputTensors[i][j], asyncCopy);
1436       }
1437     }
1438   };
1439 
1440   return collective_post(
1441       OpType::REDUCE_SCATTER,
1442       copy_to_flat,
1443       []() {},
1444       coll,
1445       std::move(data),
1446       inputTensors[0][0].device(),
1447       inputTensors[0],
1448       outputTensors,
1449       "ucc:reduce_scatter");
1450 }
1451 
1452 c10::intrusive_ptr<Work> ProcessGroupUCC::scatter(
1453     std::vector<at::Tensor>& outputTensors,
1454     std::vector<std::vector<at::Tensor>>& inputTensors,
1455     const ScatterOptions& opts) {
1456   auto& tensor = outputTensors[0];
1457   initComm(tensor.device());
1458 
1459   ScattervWorkData* data = new ScattervWorkData(size_);
1460   ucc_coll_args_t coll;
1461   coll.root = opts.rootRank;
1462   coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
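  // Counts and displacements are expressed in bytes, and the displacements
  // carry raw tensor addresses (see below), so both are posted as 64-bit
  // values.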
1463   coll.flags =
1464       UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1465   coll.coll_type = UCC_COLL_TYPE_SCATTERV;
1466 
1467   if (getRank() == opts.rootRank) {
1468     /* src is only valid at the root rank */
1469     if (inputTensors.size() != 1) {
1470       TORCH_UCC_LOG_ERROR(
1471           TORCH_UCC_COLL_POST,
1472           c10::str(
1473               "scatter requires a single-element input list containing a list with ",
1474               getSize(),
1475               " tensors."));
1476     } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
1477       TORCH_UCC_LOG_ERROR(
1478           TORCH_UCC_COLL_POST,
1479           c10::str(
1480               "Incorrect input list size ",
1481               inputTensors[0].size(),
1482               ". Input list size should be ",
1483               getSize(),
1484               ", same as size of the process group."));
1485     }
1486 
1487     for (int i = 0; i < size_; i++) {
1488       data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel();
1489       data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr();
1490     }
1491     /* use scatterv and store the non-contiguous tensor addresses in the
1492      * displacements to avoid flattening inputTensors */
1493     coll.src.info_v.buffer = nullptr;
1494     coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1495     coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1496     coll.src.info_v.datatype = UCC_DT_UINT8;
1497     coll.src.info_v.mem_type =
1498         to_ucc_memType(inputTensors[0][0].device().type());
1499 
1500     SAVE_TENSORS(inputTensors[0], data->src);
1501   } else {
1502     // for non-root ranks, inputTensors should be an empty list
1503     if (inputTensors.size() != 0) {
1504       TORCH_UCC_LOG_ERROR(
1505           TORCH_UCC_COLL_POST, "scatter requires an empty input list on non-root ranks");
1506     }
1507   }
1508 
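  // The destination is likewise described in bytes (UCC_DT_UINT8) so that it
  // matches the byte-based send lengths recorded above.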
1509   coll.dst.info.buffer = tensor.data_ptr();
1510   coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel();
1511   coll.dst.info.datatype = UCC_DT_UINT8;
1512   coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1513   SAVE_TENSORS(outputTensors, data->dst);
1514 
1515   return collective_post(
1516       OpType::SCATTER,
1517       []() {},
1518       []() {},
1519       coll,
1520       std::unique_ptr<WorkData>(data),
1521       tensor.device(),
1522       (getRank() == opts.rootRank) ? inputTensors[0] : outputTensors,
1523       outputTensors,
1524       "ucc:scatter");
1525 }
1526 
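// UCC point-to-point is modeled as a two-rank broadcast over an "active set":
// the sender is the broadcast root, the set starts at the sender's rank, and
// the stride selects the peer; the tag disambiguates concurrent send/recv
// pairs.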
1527 c10::intrusive_ptr<Work> ProcessGroupUCC::send(
1528     std::vector<at::Tensor>& tensors,
1529     int dstRank,
1530     int tag) {
1531   check_tensor(tensors);
1532   auto& tensor = tensors[0];
1533   initComm(tensor.device());
1534 
1535   WorkData* data = new WorkData();
1536   ucc_coll_args_t coll;
1537   coll.tag = tag;
1538   coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
1539   coll.flags = 0;
1540   coll.coll_type = UCC_COLL_TYPE_BCAST;
1541   coll.src.info.buffer = tensor.data_ptr();
1542   coll.src.info.count = tensor.numel();
1543   coll.src.info.datatype = to_ucc_dType(tensor);
1544   coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1545   coll.root = getRank();
1546 
1547   coll.active_set.size = 2;
1548   coll.active_set.start = getRank();
1549   coll.active_set.stride = dstRank - getRank();
1550   SAVE_TENSORS(tensors, data->dst);
1551 
1552   return collective_post(
1553       OpType::SEND,
1554       []() {},
1555       []() {},
1556       coll,
1557       std::unique_ptr<WorkData>(data),
1558       tensor.device(),
1559       tensors,
1560       tensors,
1561       "ucc:send");
1562 }
1563 
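// recv mirrors send: the source rank is the broadcast root and this rank is
// the only other member of the active set.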
1564 c10::intrusive_ptr<Work> ProcessGroupUCC::recv(
1565     std::vector<at::Tensor>& tensors,
1566     int srcRank,
1567     int tag) {
1568   check_tensor(tensors);
1569   auto& tensor = tensors[0];
1570   initComm(tensor.device());
1571 
1572   WorkData* data = new WorkData();
1573   ucc_coll_args_t coll;
1574   coll.tag = tag;
1575   coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
1576   coll.flags = 0;
1577   coll.coll_type = UCC_COLL_TYPE_BCAST;
1578   coll.src.info.buffer = tensor.data_ptr();
1579   coll.src.info.count = tensor.numel();
1580   coll.src.info.datatype = to_ucc_dType(tensor);
1581   coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1582   coll.root = srcRank;
1583 
1584   coll.active_set.size = 2;
1585   coll.active_set.start = srcRank;
1586   coll.active_set.stride = getRank() - srcRank;
1587   SAVE_TENSORS(tensors, data->dst);
1588 
1589   return collective_post(
1590       OpType::RECV,
1591       []() {},
1592       []() {},
1593       coll,
1594       std::unique_ptr<WorkData>(data),
1595       tensor.device(),
1596       tensors,
1597       tensors,
1598       "ucc:recv");
1599 }
1600 
1601 void ProcessGroupUCC::setSequenceNumberForGroup() {}
1602 
1603 uint64_t ProcessGroupUCC::getSequenceNumberForGroup() {
1604   return seq_;
1605 }
1606 
1607 c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC(
1608     const c10::intrusive_ptr<::c10d::Store>& store,
1609     int rank,
1610     int size,
1611     const std::chrono::duration<float>& timeout) {
1612   return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout);
1613 }
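
// Illustrative usage sketch (assumed caller, not part of this file): backend
// registration code constructs UCC process groups through this factory,
// roughly as follows:
//   auto pg = ProcessGroupUCC::createProcessGroupUCC(
//       store, /*rank=*/0, /*size=*/2, std::chrono::seconds(60));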
1614 
1615 void ProcessGroupUCC::initComm(c10::Device dev) {
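  // Lazily create the UCC communicator and team on first use for this device.
  // Later calls only validate the device: a single communicator cannot serve
  // more than one CUDA device in the same process.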
1616   if (!comm) {
1617 #ifdef USE_CUDA
1618     if (dev.is_cuda()) {
1619       c10::cuda::set_device(dev.index());
1620     }
1621 #endif
1622     comm = Comm::get_comm(comm_id, dev, oob, logger);
1623     TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC communicator");
1624     comm->ucc_create_team(team, oob);
1625     TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
1626     logger->setPhase(TORCH_UCC_READY);
1627   } else {
1628     if (dev.is_cuda()) {
1629       if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
1630           (comm->cuda_device_index != dev.index())) {
1631         TORCH_UCC_LOG_ERROR(
1632             TORCH_UCC_INIT,
1633             "ucc communicator was initialized with a different cuda device; "
1634             "multiple devices are not supported");
1635         throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
1636       }
1637       comm->cuda_device_index = dev.index();
1638     }
1639   }
1640 #ifdef USE_CUDA
1641   // Create UCC execution engine.
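  // The engine wraps a dedicated CUDA stream taken from the stream pool so
  // UCC can launch collectives on it asynchronously; two additional
  // stream/engine pairs are created below for the active-set send/recv path.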
1642   if (!cuda_ee && dev.is_cuda()) {
1643     stream = std::make_unique<at::cuda::CUDAStream>(
1644         at::cuda::getStreamFromPool(true, dev.index()));
1645     ucc_ee_params_t params;
1646     params.ee_type = UCC_EE_CUDA_STREAM;
1647     params.ee_context = (void*)stream->stream();
1648     params.ee_context_size = sizeof(cudaStream_t);
1649     TORCH_UCC_CHECK(
1650         ucc_ee_create(team, &params, &cuda_ee),
1651         "failed to create UCC execution engine");
1652     for (int i = 0; i < 2; i++) {
1653       stream_p2p[i] = std::make_unique<at::cuda::CUDAStream>(
1654           at::cuda::getStreamFromPool(true, dev.index()));
1655       ucc_ee_params_t params;
1656       params.ee_type = UCC_EE_CUDA_STREAM;
1657       params.ee_context = (void*)stream_p2p[i]->stream();
1658       params.ee_context_size = sizeof(cudaStream_t);
1659       TORCH_UCC_CHECK(
1660           ucc_ee_create(team, &params, &cuda_ee_p2p[i]),
1661           "failed to create UCC P2P execution engine");
1662     }
1663   }
1664 #endif
1665 }
1666 
1667 } // namespace c10d
1668 
1669 #endif // USE_C10D_UCC
1670