1 #ifdef USE_C10D_UCC
2
3 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
4 #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
5 #include <torch/csrc/distributed/c10d/UCCTracing.hpp>
6 #include <torch/csrc/distributed/c10d/UCCUtils.hpp>
7 #include <list>
8 #include <memory>
9 #include <unordered_map>
10 #include <unordered_set>
11
12 namespace c10d {
13
14 namespace {
15
16 const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
17 {c10::kCPU, UCC_MEMORY_TYPE_HOST},
18 {c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
19 };
20
ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) {
22 if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end())
23 return ucc_mtype_map.at(_c10_type);
24 else
25 return UCC_MEMORY_TYPE_UNKNOWN;
26 }
27
28 const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = {
29 {at::kByte, UCC_DT_UINT8},
30 {at::kChar, UCC_DT_INT8},
31 {at::kHalf, UCC_DT_FLOAT16},
32 {at::kBFloat16, UCC_DT_BFLOAT16},
33 {at::kDouble, UCC_DT_FLOAT64},
34 {at::kFloat, UCC_DT_FLOAT32},
35 {at::kInt, UCC_DT_INT32},
36 {at::kLong, UCC_DT_INT64},
37 {at::kBool, UCC_DT_UINT8},
38 };
39
ucc_datatype_t to_ucc_dType(at::Tensor _tensor) {
41 if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) {
42 TORCH_CHECK(
43 false, "Size of Boolean type larger than 1 is not supported in UCC");
44 }
45 try {
46 return ucc_dtype_map.at(_tensor.scalar_type());
47 } catch (const std::out_of_range&) {
48 TORCH_CHECK(false, "Not supported data type for UCC");
49 }
50 }
51
52 const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
53 {ReduceOp::SUM, UCC_OP_SUM},
54 {ReduceOp::PRODUCT, UCC_OP_PROD},
55 {ReduceOp::MIN, UCC_OP_MIN},
56 {ReduceOp::MAX, UCC_OP_MAX},
57 {ReduceOp::BAND, UCC_OP_BAND},
58 {ReduceOp::BOR, UCC_OP_BOR},
59 {ReduceOp::BXOR, UCC_OP_BXOR},
60 {ReduceOp::AVG, UCC_OP_AVG},
61 };
62
ucc_reduction_op_t to_ucc_reduceOp(
    const ReduceOp _op,
    const at::ScalarType _dt) {
66 if (_dt == at::kBool) {
67 if (_op == ReduceOp::SUM) {
68 // bitwise or
69 return UCC_OP_MAX;
70 } else if (_op == ReduceOp::PRODUCT) {
71 // bitwise and
72 return UCC_OP_MIN;
73 } else if (_op == ReduceOp::AVG) {
74 TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
75 }
76 }
77
78 try {
79 return ucc_op_map.at(_op);
80 } catch (const std::out_of_range&) {
81 TORCH_CHECK(false, "Not supported ReduceOp for UCC");
82 }
83 }
84
85 struct torch_ucc_config_t {
86 c10::once_flag flag;
87 std::array<bool, 32> blocking_wait;
88 bool enable_comms_logger;
89 bool use_future;
  // Sharing a UCC communicator among multiple PGs to save resources.
91 bool shared_comm;
92 // Using allgatherv to achieve allgather, without flattening the list of
93 // (potentially non-contiguous) tensors.
94 bool use_allgatherv;
95 bool enable_health_check;
96 } torch_ucc_config;
97
98 std::unordered_map<std::string, std::string> torch_ucc_envs_map = {
99 // TORCH_UCC_BLOCKING_WAIT allowed syntax:
100 // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled
101 // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled
102 // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled
103 // on selected operations
104 // Supported operations:
105 // [allgather,allgather_base,allreduce,alltoall,broadcast,
106 // gather,reduce,reduce_scatter,scatter,send,recv]
107 {"TORCH_UCC_BLOCKING_WAIT", "none"},
108
109 {"TORCH_UCC_USE_FUTURE", "1"},
110 {"TORCH_UCC_PROFILING_ENABLE", "0"},
111 {"TORCH_UCC_SHARED_COMM", "1"},
112 {"TORCH_UCC_USE_ALLGATHERV", "0"},
113 {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
114 {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
115 };
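// Illustrative note (the concrete values are just examples, not taken from
// this file): the entries above are only defaults. read_config() below
// overwrites each one from the process environment, so a launch script might
// export, e.g.,
//   TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv
//   TORCH_UCC_SHARED_COMM=1
// and any variable left unset keeps the default listed above.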
116
std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
118 const static std::unordered_map<std::string, OpType> str2op = {
119 {"allgather", OpType::ALLGATHER},
120 {"allgather_base", OpType::_ALLGATHER_BASE},
121 {"allreduce", OpType::ALLREDUCE},
122 {"alltoall_base", OpType::ALLTOALL_BASE},
123 {"broadcast", OpType::BROADCAST},
124 {"gather", OpType::GATHER},
125 {"reduce", OpType::REDUCE},
126 {"reduce_scatter", OpType::REDUCE_SCATTER},
127 {"scatter", OpType::SCATTER},
128 {"send", OpType::SEND},
129 {"recv", OpType::RECV},
130 };
131 auto op_list = parse_list(op_list_string);
132 if (op_list == std::vector<std::string>{"none"}) {
133 return {};
134 }
135 std::vector<OpType> result;
136 if (op_list == std::vector<std::string>{"all"}) {
137 for (auto entry : str2op) {
138 result.push_back(entry.second);
139 }
140 } else {
141 for (auto op_string : op_list) {
142 result.push_back(str2op.at(op_string));
143 }
144 }
145 return result;
146 }
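// Note: an operation name that is not in str2op falls through to
// str2op.at() above and throws std::out_of_range, so a typo in
// TORCH_UCC_BLOCKING_WAIT is surfaced when the config is read rather than
// being silently ignored.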
147
148 } // namespace
149
void read_config() {
151 // default configuration
152 torch_ucc_config.blocking_wait.fill(false);
153 torch_ucc_config.use_future = true;
154 torch_ucc_config.shared_comm = false;
155 torch_ucc_config.use_allgatherv = false;
156 torch_ucc_config.enable_health_check = false;
157 torch_ucc_config.enable_comms_logger = false;
158
159 // read all torch_ucc env. variables and update the map
160 char* env;
161 for (auto& torch_ucc_env : torch_ucc_envs_map) {
162 env = std::getenv(torch_ucc_env.first.c_str());
163 if (env) {
164 torch_ucc_envs_map[torch_ucc_env.first] = std::string(env);
165 }
166 }
167
168 auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
169 for (auto op : parse_blocking_wait(blocking_wait_str)) {
170 torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
171 }
172 // barrier is always blocking
173 torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true;
174
175 torch_ucc_config.use_future =
176 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
177 torch_ucc_config.shared_comm =
178 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
179 torch_ucc_config.use_allgatherv =
180 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
181 torch_ucc_config.enable_health_check =
182 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
183 torch_ucc_config.enable_comms_logger =
184 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
185 }
186
void check_device(c10::Device dev1, c10::Device dev2) {
188 if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
189 throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
190 }
191 }
192
void check_tensor(const std::vector<at::Tensor>& tensors) {
194 if (tensors.size() != 1) {
195 throw std::invalid_argument(
196 "ProcessGroupUCC takes 1 tensor. Got " +
197 std::to_string(tensors.size()) + ". ");
198 }
199 if (!tensors[0].is_contiguous()) {
200 throw std::invalid_argument(
201 "ProcessGroupUCC input tensor has to be contiguous");
202 }
203 if (tensors[0].is_sparse()) {
204 throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense");
205 }
206 // TODO: check cuda case
207 }
208
ProcessGroupUCC::WorkUCC::~WorkUCC() {
210 #ifdef USE_CUDA
211 if (fence && ep) {
212 std::lock_guard<std::mutex> lock(ep->event_pool_mutex);
213 ep->event_pool.push(std::move(fence));
214 }
215 #endif
216 }
217
void ProcessGroupUCC::WorkUCC::setException() {
219 if (exception() || !entry_) {
220 return;
221 }
222 exception_ = entry_->eptr_;
223 }
224
void ProcessGroupUCC::WorkUCC::setAndThrowException() {
226 setException();
227 if (exception()) {
228 std::rethrow_exception(exception());
229 }
230 }
231
bool ProcessGroupUCC::WorkUCC::isCompleted() {
233 if (!entry_) {
234 return true;
235 }
236 setException();
237 // status_ <= 0 to avoid listing all possible status codes. The main thread
238 // needs to be unblocked when UCC (in progress thread) returns success (== 0)
239 // or any error code (< 0).
240 return exception() || entry_->status_ <= 0;
241 }
242
bool ProcessGroupUCC::WorkUCC::isSuccess() const {
244 if (!entry_) {
245 return true;
246 }
247 return !exception() && entry_->status_ == 0;
248 }
249
bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
251 if (torch_ucc_config.enable_comms_logger && logger_) {
252 logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_);
253 }
254 #ifdef USE_CUDA
255 if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
256 // block user stream
257 setAndThrowException();
258 fence->block(at::cuda::getCurrentCUDAStream());
259 return true;
260 }
261 #endif
  // Wait for completion. In the blocking case, the main thread is blocked in
  // this loop until the progress thread changes the status of this request.
  // If a timeout occurs, UCC returns UCC_ERR_TIMEOUT as the status, and the
  // main thread then throws the exception. There is currently no "abort"
  // function in UCC.
267 while (!isCompleted())
268 ;
269 setAndThrowException();
270 // manually call profiling end callbacks if they are set,
271 // since progress thread does not own WorkUCC
272 if (Work::recordFunctionEndCallback_) {
273 Work::recordFunctionEndCallback_();
274 Work::recordFunctionEndCallback_ = nullptr;
275 }
276 return true;
277 }
278
c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
280 return future_;
281 }
282
int ProcessGroupUCC::WorkUCC::sourceRank() const {
284 if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
285 // Throw an error
286 return Work::sourceRank();
287 }
288 return sourceRank_;
289 }
290
std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
292 return *outputs_;
293 }
294
void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) {
296 ucc_status_t status = UCC_OK;
297
298 if (request_ != nullptr) {
299 status = request_->status;
300 comm_->free_request(request_);
301 }
302 if (eptr) {
303 eptr_ = eptr;
304 } else {
305 status_ = status;
306 }
307 if (future_) {
308 if (eptr) {
309 future_->setError(eptr);
310 } else {
311 future_->markCompleted(
312 c10::IValue(data ? data->dst : std::vector<at::Tensor>()));
313 }
314 }
315 }
316
Comm::Comm(
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob_,
    c10::Device dev,
    bool is_health_check)
322 : logger(logger_),
323 oob(oob_),
324 ucc_comm(oob, logger),
325 finalize_phase(
326 is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
327 cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) {
328 if (dev.is_cuda()) {
329 cuda_device_index = dev.index();
330 }
331 stop_progress_loop = false;
332 collective_inprogress = false;
333 progress_thread = std::thread(&Comm::progress_loop, this);
334 #ifdef _GNU_SOURCE
335 pthread_setname_np(progress_thread.native_handle(), "ucc-progress");
336 #endif
337 }
338
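// The destructor waits for the progress queue to drain and for any in-flight
// collective to finish before signaling the progress thread to stop and
// joining it.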
Comm::~Comm() {
340 std::unique_lock<std::mutex> lock(mutex);
341 queue_consume_cv.wait(
342 lock, [&] { return progress_queue.empty() && !collective_inprogress; });
343 stop_progress_loop = true;
344 lock.unlock();
345 queue_produce_cv.notify_all();
346 progress_thread.join();
347 }
348
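// get_comm negotiates a communicator id through the c10d store: each non-zero
// rank publishes its candidate id under "<group_id><rank>", rank 0 takes the
// maximum of all candidates and publishes the agreed value under
// "<group_id>0", and every rank then adopts that value as oob->comm_id. When
// TORCH_UCC_SHARED_COMM is enabled, a single Comm instance (held through a
// weak_ptr) is reused across process groups in this process.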
std::shared_ptr<Comm> Comm::get_comm(
    uint32_t& id,
    c10::Device dev,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
    bool is_health_check) {
355 static std::mutex m;
356 static std::weak_ptr<Comm> comm;
357 static uint32_t comm_id;
358
359 std::lock_guard<std::mutex> lock(m);
360 id = comm_id;
361
362 std::string group_id = "group_id";
363 if (is_health_check) {
364 group_id = c10::str(dev.type()) + "/" + group_id;
365 }
366
367 std::vector<uint8_t> remote_comm_id;
368 oob->store->deleteKey(group_id + std::to_string(0));
369 if (oob->rank != 0) {
370 std::vector<uint8_t> val = std::vector<uint8_t>(
371 reinterpret_cast<uint8_t*>(&id),
372 reinterpret_cast<uint8_t*>(&id) + sizeof(id));
373 oob->store->set(group_id + std::to_string(oob->rank), val);
374 } else {
375 for (int i = 1; i < oob->size; i++) {
376 remote_comm_id = oob->store->get(group_id + std::to_string(i));
377 oob->store->deleteKey(group_id + std::to_string(i));
378 // Find the highest id.
379 id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data())));
380 }
381 std::vector<uint8_t> val = std::vector<uint8_t>(
382 reinterpret_cast<uint8_t*>(&id),
383 reinterpret_cast<uint8_t*>(&id) + sizeof(id));
384 oob->store->set(group_id + std::to_string(oob->rank), val);
385 }
386 remote_comm_id = oob->store->get(group_id + std::to_string(0));
387 oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data()));
  // Advance comm_id (the static variable) to the next id.
389 comm_id = oob->comm_id + 1;
390
391 if (torch_ucc_config.shared_comm) {
392 std::shared_ptr<Comm> shared_comm = comm.lock();
393 if (!shared_comm) {
394 shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check);
395 comm = shared_comm;
396 } else {
397 if (dev.is_cuda() && !is_health_check) {
398 if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
399 (shared_comm->cuda_device_index != dev.index())) {
400 TORCH_UCC_LOG_ERROR(
401 is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
402 "ucc communicator was initialized with different cuda device,"
403 "multi device is not supported");
404 throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
405 }
406 shared_comm->cuda_device_index = dev.index();
407 }
408 }
409 return shared_comm;
410 } else {
411 return std::make_shared<Comm>(logger, oob, dev, is_health_check);
412 }
413 }
414
void Comm::ucc_create_team(
    ucc_team_h& team,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
418 ucc_status_t st;
419 ucc_team_params_t team_params;
420 team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE |
421 UCC_TEAM_PARAM_FIELD_OOB;
422 team_params.oob.allgather = oob_allgather;
423 team_params.oob.req_test = oob_allgather_test;
424 team_params.oob.req_free = oob_allgather_free;
425 team_params.oob.coll_info = oob.get();
426 team_params.oob.n_oob_eps = oob->size;
427 team_params.oob.oob_ep = oob->rank;
428 team_params.ep = oob->rank;
429 team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
430 TORCH_UCC_CHECK(
431 ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team),
432 "failed to post team create");
433 do {
434 st = ucc_team_create_test(team);
435 ucc_context_progress(ucc_comm.context);
436 } while (st == UCC_INPROGRESS);
437 TORCH_UCC_CHECK(st, "failed to create UCC team");
438 }
439
void Comm::ucc_destroy_team(ucc_team_h& team) {
441 std::unique_lock<std::mutex> lock(mutex);
442 queue_consume_cv.wait(
443 lock, [&] { return progress_queue.empty() && !collective_inprogress; });
444
  ucc_status_t status;
  while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) {
    // keep polling until the team destroy completes or fails
  }
  if (UCC_OK != status) {
    TORCH_UCC_LOG_ERROR(
        finalize_phase,
        c10::str("ucc team destroy error: ", ucc_status_string(status)));
  }
454
455 lock.unlock();
456 }
457
void Comm::enqueue_collective(
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
    ucc_coll_args_t& coll,
    ucc_team_h team) {
463 ucc_coll_req_h request;
464 TORCH_UCC_CHECK(
465 ucc_collective_init(&coll, &request, team), "failed to init collective");
466 TORCH_UCC_CHECK_REQUEST(
467 request, ucc_collective_post(request), "failed to post collective");
468
469 auto entry =
470 std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
471 entry->data = std::move(data);
472 entry->future_ = work->getFuture();
473 work->entry_ = entry;
474 std::unique_lock<std::mutex> lock(mutex);
475 progress_queue.push_back(entry);
476 lock.unlock();
477 queue_produce_cv.notify_one();
478 }
479
480 #ifdef USE_CUDA
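// The CUDA path uses a triggered post: the collective is initialized as
// usual, then handed to the UCC execution engine `ee` via a
// UCC_EVENT_COMPUTE_COMPLETE event, so its launch is tied to the CUDA stream
// backing that execution engine (set up in initComm) rather than being posted
// immediately from the calling thread.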
void Comm::enqueue_cuda_collective(
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
    ucc_coll_args_t& coll,
    ucc_team_h team,
    ucc_ee_h ee) {
487 ucc_coll_req_h request;
488 TORCH_UCC_CHECK(
489 ucc_collective_init(&coll, &request, team),
490 "failed to init cuda collective");
491 ucc_ev_t comp_ev, *post_ev;
492 comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
493 comp_ev.ev_context = nullptr;
494 comp_ev.ev_context_size = 0;
495 comp_ev.req = request;
496 TORCH_UCC_CHECK_REQUEST(
497 request,
498 ucc_collective_triggered_post(ee, &comp_ev),
499 "failed to post triggered collective");
500 ucc_status_t st = ucc_ee_get_event(ee, &post_ev);
501 TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
502 ucc_ee_ack_event(ee, post_ev);
503 auto entry =
504 std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
505 entry->data = std::move(data);
506 work->entry_ = entry;
507 std::unique_lock<std::mutex> lock(mutex);
508 progress_queue.push_back(entry);
509 lock.unlock();
510 queue_produce_cv.notify_one();
511 }
512 #endif
513
void Comm::progress_loop() {
515 std::unique_lock<std::mutex> lock(mutex);
516 #ifdef USE_CUDA
517 bool device_set = false;
518 #endif
519 while (!stop_progress_loop) {
520 if (progress_queue.empty()) {
521 queue_produce_cv.wait(lock);
522 continue;
523 }
524 collective_inprogress = true;
525 auto work = progress_queue.front();
526 progress_queue.pop_front();
527 lock.unlock();
528 #ifdef USE_CUDA
529 if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
530 c10::cuda::set_device(cuda_device_index);
531 CUcontext pctx = nullptr;
532 at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
533 if (C10_UNLIKELY(!pctx)) {
534 at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(
535 &pctx, cuda_device_index);
536 at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
537 }
538 device_set = true;
539 }
540 #endif
541 std::exception_ptr eptr;
542 try {
543 while (work->request_->status > 0) {
544 ucc_comm.progress();
545 }
546 if (work->request_->status < 0) {
547 eptr = std::make_exception_ptr(
548 std::runtime_error(ucc_status_string(work->request_->status)));
549 std::string err_log = c10::str(
550 "Failed to progress communication", // TODO: report exact op type or
551 // id?
552 ucc_status_string(work->request_->status));
553 TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log);
554 }
555 } catch (...) {
556 eptr = std::current_exception();
557 }
558 work->finalize(eptr);
559 work = nullptr;
560 collective_inprogress = false;
561 queue_consume_cv.notify_one();
562 lock.lock();
563 }
564 }
565
ProcessGroupUCC::ProcessGroupUCC(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    std::chrono::duration<float> timeout)
571 : Backend(rank, size), timeout_(timeout) {
572 c10::call_once(torch_ucc_config.flag, read_config);
573 oob = std::make_shared<torch_ucc_oob_coll_info_t>();
574 oob->rank = rank;
575 oob->size = size;
576 oob->store = store;
577 comm = nullptr;
578 cuda_ee = nullptr;
579 static uint32_t id = 0;
580 uint32_t pg_id = id++;
581
582 logger = c10::make_intrusive<ProcessGroupUCCLogger>(
583 c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
584 TORCH_UCC_INIT);
585 TORCH_UCC_LOG_INFO(
586 TORCH_UCC_INIT,
587 c10::str(
588 "Created ProcessGroupUCC with ",
589 size,
590 " ranks, with timeout ",
591 timeout_.count(),
592 " secs"));
593 std::string envs = "";
594 for (auto& torch_ucc_env : torch_ucc_envs_map) {
595 envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second);
596 }
597 TORCH_UCC_LOG_INFO(
598 TORCH_UCC_INIT,
599 c10::str(
600 "Successfully read and set ProcessGroupUCC env. variables as followings",
601 envs));
602
603 if (torch_ucc_config.enable_health_check) {
604 // Perform health check by initializing dummy communicators and destroying
605 // them. This will help indicate any UCC/UCX-related issues prior to the
606 // first collective. Run it in a separate thread and wait on CV to handle
607 // timeouts so that if there are hangs, the main thread can still run
608 // correctly.
609 runHealthCheck();
610 }
611 if (torch_ucc_config.enable_comms_logger) {
612 logger->initCommsTracer();
613 }
614 }
615
ProcessGroupUCC::~ProcessGroupUCC() {
617 if (torch_ucc_config.enable_comms_logger) {
618 logger->flushComms(this->getRank(), this->getSize());
619 }
620 if (comm) {
621 logger->setPhase(TORCH_UCC_FINALIZE);
622 comm->ucc_destroy_team(team);
623 TORCH_UCC_LOG_INFO(
624 TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
625 try {
626 if (cuda_ee) {
627 ucc_ee_destroy(cuda_ee);
628 ucc_ee_destroy(cuda_ee_p2p[0]);
629 ucc_ee_destroy(cuda_ee_p2p[1]);
630 }
631 } catch (std::exception& ex) {
632 TORCH_UCC_LOG_INFO(
633 TORCH_UCC_FINALIZE,
634 c10::str(
635 "(~ProcessGroupUCC) Caught error in Store Operation .. ",
636 "[",
637 ex.what(),
638 "]"));
639 }
640 comm = nullptr;
641 }
642 }
643
644 #ifdef USE_CUDA
645 // Return CUDA device with ordinal given by input rank.
c10::Device getCUDADeviceForRank(int rank) {
647 TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
648 auto numGPUs = at::cuda::getNumGPUs();
649 auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs);
650 return c10::Device(c10::DeviceType::CUDA, deviceIdx);
651 }
652 #endif
653
void ProcessGroupUCC::runHealthCheck() {
655 // Run health check in a separate thread and wait on CV to handle timeouts.
656 // This design allows us to handle hangs.
657
658 // When size_ is 1, there is no need to do any communication at all.
659 if (size_ == 1)
660 return;
661
662 struct HealthCheckData {
663 std::mutex healthCheckMutex;
664 std::condition_variable healthCheckCv;
665 bool uccHealthCheckSuccess = false;
666 std::exception_ptr healthCheckException;
667 } healthCheckData;
668
669 auto t = std::thread([&healthCheckData, this]() {
670 std::list<c10::Device> devices{c10::kCPU};
671 #ifdef USE_CUDA
672 c10::cuda::OptionalCUDAGuard gpuGuard;
673 if (at::cuda::is_available()) {
674 devices.emplace_front(getCUDADeviceForRank(rank_));
675 }
676 #endif
677 for (auto device : devices) {
678 bool is_last_device = (device == devices.back());
679 try {
680 auto oob = std::make_shared<torch_ucc_oob_coll_info_t>();
681 oob->rank = this->oob->rank;
682 oob->size = this->oob->size;
683 oob->store = this->oob->store;
684 ucc_team_h team = nullptr;
685 uint32_t comm_id;
686 #ifdef USE_CUDA
687 if (device.is_cuda()) {
688 gpuGuard.set_index(device.index());
689 }
690 #endif
691 auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
692 comm->ucc_create_team(team, oob);
693 comm->ucc_destroy_team(team);
694 TORCH_UCC_LOG_INFO(
695 TORCH_UCC_HEALTH_CHECK,
696 c10::str(
697 "UCC library health check succeed for device ",
698 c10::DeviceTypeName(device.type())));
699 // Mark ucc health check as complete.
700 if (is_last_device) {
701 std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
702 healthCheckData.uccHealthCheckSuccess = true;
703 }
704
705 comm = nullptr;
706 oob = nullptr;
707 // Notify main thread the health check is complete.
708 if (is_last_device) {
709 healthCheckData.healthCheckCv.notify_one();
710 }
711 } catch (const std::exception&) {
712 // Populate exception ptr.
713 healthCheckData.healthCheckException = std::current_exception();
714 // Unblock waiting main thread which will report exception.
715 healthCheckData.healthCheckCv.notify_one();
716 } // Unknown exceptions will just cause the program to terminate.
717 }
718 });
719 // We don't need to join the thread, just need to verify health check via the
720 // CV. Hence we detach the thread here.
721 t.detach(); // NOLINT
722 TORCH_UCC_LOG_INFO(
723 TORCH_UCC_HEALTH_CHECK,
724 c10::str(
725 "will wait up to ",
726 timeout_.count(),
727 " msec for UCC health check to complete."));
728 std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
729 healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
730 return healthCheckData.uccHealthCheckSuccess;
731 });
732
733 if (healthCheckData.healthCheckException) {
734 std::rethrow_exception(healthCheckData.healthCheckException);
735 }
736 // If there is no exception, the likely culprit is a timeout/hang
737 TORCH_CHECK(
738 healthCheckData.uccHealthCheckSuccess,
739 "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
740 rank_);
741 }
742
void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) {
744 args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
745 args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT;
746 args.timeout = timeout_.count();
747 }
748
749 #ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() {
751 std::unique_ptr<at::cuda::CUDAEvent> ev;
752 std::lock_guard<std::mutex> lock(ep.event_pool_mutex);
753 if (ep.event_pool.empty()) {
754 ev = std::make_unique<at::cuda::CUDAEvent>();
755 } else {
756 ev = std::move(ep.event_pool.front());
757 ep.event_pool.pop();
758 }
759 return ev;
760 }
761 #endif
762
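// collective_post is the single funnel for every operation in this file: it
// stamps a sequence number and timeout onto the UCC args, wraps the call in a
// WorkUCC, and then either enqueues directly (CPU) or brackets a triggered
// post with CUDA events (CUDA). A minimal illustration of how the public ops
// below invoke it, mirroring the allreduce implementation further down, with
// the two lambdas acting as pre-/post-processing hooks:
//
//   return collective_post(
//       OpType::ALLREDUCE,
//       []() {},                        // preproc: nothing to stage
//       []() {},                        // postproc: nothing to copy back
//       coll,
//       std::unique_ptr<WorkData>(data),
//       tensor.device(),
//       tensors,
//       tensors,
//       "ucc:all_reduce");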
763 template <typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
    OpType opType,
    PreProcess preproc,
    PostProcess postproc,
    ucc_coll_args_t& coll,
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::Device dev,
    std::vector<at::Tensor>& inputTensors,
    std::vector<at::Tensor>& outputTensors,
    const char* prof_title) {
774 seq_++;
775 set_timeout(coll);
776 auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
777 opType, seq_, prof_title, inputTensors, logger);
778
779 if (opType == OpType::RECV) {
780 work->sourceRank_ = coll.root;
781 }
782
783 RECORD_COMMS_TRACE(
784 logger->trace_generator,
785 work,
786 opType,
787 this->getRank(),
788 this->getSize(),
789 inputTensors,
790 outputTensors);
791
792 // Store references to outputs to be used by result
793 work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);
794 switch (dev.type()) {
795 case c10::DeviceType::CPU: {
796 if (torch_ucc_config.use_future) {
797 work->future_ = c10::make_intrusive<at::ivalue::Future>(
798 c10::ListType::create(c10::TensorType::get()));
799 }
800 preproc();
801 comm->enqueue_collective(std::move(data), work, coll, team);
802 postproc();
803 return work;
804 }
805 #ifdef USE_CUDA
806 case c10::DeviceType::CUDA: {
807 auto cuda_ev = getPooledEvent();
808 at::cuda::CUDAStream* op_stream;
809 ucc_ee_h* op_ee;
810 if (opType == OpType::SEND) {
811 op_stream = stream_p2p[0].get();
812 op_ee = &cuda_ee_p2p[0];
813 } else if (opType == OpType::RECV) {
814 op_stream = stream_p2p[1].get();
815 op_ee = &cuda_ee_p2p[1];
816 } else {
817 op_stream = stream.get();
818 op_ee = &cuda_ee;
819 }
820
821 cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
822 cuda_ev->block(*op_stream);
823 at::cuda::CUDAStreamGuard guard(*op_stream);
824 preproc();
825 comm->enqueue_cuda_collective(std::move(data), work, coll, team, *op_ee);
826 postproc();
827 cuda_ev->record(*op_stream);
828 work->fence = std::move(cuda_ev);
829 work->ep = &ep;
830 if (torch_ucc_config.use_future) {
831 c10::cuda::CUDAMultiStreamGuard streamGuard(*op_stream);
832 std::vector<c10::Device> devList{dev};
833 work->future_ = c10::make_intrusive<at::ivalue::Future>(
834 c10::ListType::create(c10::TensorType::get()), devList);
835 // Add a callback that runs profiling end callbacks
836 if (work->recordFunctionEndCallback_) {
837 work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
838 work->recordFunctionEndCallback_();
839 });
840 }
841
842 work->future_->markCompleted(c10::IValue(outputTensors));
843 }
844 return work;
845 }
846 #endif // #ifdef USE_CUDA
847 default: {
848 TORCH_UCC_LOG_ERROR(
849 TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
850 throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
851 }
852 }
853 }
854
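// allgather chooses between two strategies: on CPU, or when
// TORCH_UCC_USE_ALLGATHERV is set, it posts a UCC allgatherv with each output
// tensor's address stored directly in the displacement array, so nothing has
// to be flattened; otherwise it gathers into flat staging tensors and the
// postprocess lambda (copy_from_flat) copies the results back into the
// individual output tensors.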
c10::intrusive_ptr<Work> ProcessGroupUCC::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& /* unused */) {
859 auto& tensor = inputTensors[0];
860 check_device(tensor.device(), outputTensors[0][0].device());
861 initComm(tensor.device());
862
863 if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
864 AllgathervWorkData* data = new AllgathervWorkData(size_);
865 for (int i = 0; i < size_; i++) {
866 data->recv_lengths[i] = tensor.element_size() * tensor.numel();
867 data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
868 }
869 ucc_coll_args_t coll;
870 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
871 coll.flags =
872 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
873 coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
874 coll.src.info.buffer = tensor.data_ptr();
875 coll.src.info.count = tensor.element_size() * tensor.numel();
876 coll.src.info.datatype = UCC_DT_UINT8;
877 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
878 coll.dst.info_v.buffer = nullptr;
879 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
880 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
881 coll.dst.info_v.datatype = UCC_DT_UINT8;
882 coll.dst.info_v.mem_type =
883 to_ucc_memType(outputTensors[0][0].device().type());
884 SAVE_TENSORS(inputTensors, data->src);
885 SAVE_TENSORS(outputTensors[0], data->dst);
886
887 return collective_post(
888 OpType::ALLGATHER,
889 []() {},
890 []() {},
891 coll,
892 std::unique_ptr<WorkData>(data),
893 tensor.device(),
894 inputTensors,
895 outputTensors[0],
896 "ucc:all_gather");
897 } else {
898 WorkData* data = new WorkData();
899 std::vector<at::Tensor> flat_output(outputTensors.size());
900 for (size_t i = 0; i < outputTensors.size(); i++) {
901 TORCH_CHECK(
902 outputTensors[i].size() == outputTensors.size() * size_,
903 "Tensor output list is not valid for the number of participants");
904 flat_output[i] = c10d::newLikeFlat(outputTensors, i);
905 }
906 SAVE_TENSORS(flat_output, data->flat);
907 ucc_coll_args_t coll;
908 coll.mask = 0;
909 coll.flags = 0;
910 coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
911 coll.src.info.buffer = tensor.data_ptr();
912 coll.src.info.count = tensor.numel();
913 coll.src.info.datatype = to_ucc_dType(tensor);
914 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
915 coll.dst.info.buffer = flat_output[0].data_ptr();
916 coll.dst.info.count = flat_output[0].numel();
917 coll.dst.info.datatype = to_ucc_dType(flat_output[0]);
918 coll.dst.info.mem_type =
919 to_ucc_memType(outputTensors[0][0].device().type());
920
921 auto copy_from_flat = [&] {
922 bool asyncCopy = false;
923 #ifdef USE_CUDA
      bool isCuda = outputTensors[0][0].device().is_cuda();
926 #endif
927 for (size_t i = 0; i < outputTensors.size(); i++) {
928 auto inumel = inputTensors[i].numel();
929 for (size_t j = 0; j < outputTensors[i].size(); j++) {
930 TORCH_CHECK(
931 (outputTensors[i][j].numel() == inumel),
932 "Tensor operand counts must be same");
933 #ifdef USE_CUDA
934 if (isCuda) {
935 c10::cuda::CUDACachingAllocator::recordStream(
936 outputTensors[i][j].storage().data_ptr(), (*stream));
937 asyncCopy = true;
938 }
939 #endif
940 outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
941 }
942 }
943 };
944 return collective_post(
945 OpType::ALLGATHER,
946 []() {},
947 copy_from_flat,
948 coll,
949 std::unique_ptr<WorkData>(data),
950 tensor.device(),
951 inputTensors,
952 outputTensors[0],
953 "ucc:all_gather");
954 }
955 }
956
c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const AllgatherOptions& opts) {
961 check_tensor({outputTensor});
962 check_tensor({inputTensor});
963 initComm(outputTensor.device());
964
965 WorkData* data = new WorkData();
966
967 ucc_coll_args_t coll;
968 coll.mask = 0;
969 coll.flags = 0;
970 coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
971 coll.src.info.buffer = inputTensor.data_ptr();
972 coll.src.info.count = inputTensor.numel();
973 coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
974 coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
975 coll.dst.info.buffer = outputTensor.data_ptr();
976 coll.dst.info.count = outputTensor.numel();
977 coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
978 coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
979
980 std::vector<at::Tensor> inputTensors = {inputTensor};
981 std::vector<at::Tensor> outputTensors = {outputTensor};
982 SAVE_TENSORS(inputTensors, data->src);
983 SAVE_TENSORS(outputTensors, data->dst);
984
985 return collective_post(
986 OpType::_ALLGATHER_BASE,
987 []() {},
988 []() {},
989 coll,
990 std::unique_ptr<WorkData>(data),
991 outputTensor.device(),
992 inputTensors,
993 outputTensors,
994 "ucc:allgather_base");
995 }
996
c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
1000 check_tensor(tensors);
1001 auto& tensor = tensors[0];
1002 initComm(tensor.device());
1003 WorkData* data = new WorkData();
1004
1005 ucc_coll_args_t coll;
1006 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1007 coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
1008 coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
1009 coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type());
1010 coll.src.info.buffer = nullptr;
1011 coll.src.info.count = tensor.numel();
1012 coll.src.info.datatype = to_ucc_dType(tensor);
1013 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1014 coll.dst.info.buffer = tensor.data_ptr();
1015 coll.dst.info.count = tensor.numel();
1016 coll.dst.info.datatype = to_ucc_dType(tensor);
1017 coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1018 SAVE_TENSORS(tensors, data->dst);
1019 return collective_post(
1020 OpType::ALLREDUCE,
1021 []() {},
1022 []() {},
1023 coll,
1024 std::unique_ptr<WorkData>(data),
1025 tensor.device(),
1026 tensors,
1027 tensors,
1028 "ucc:all_reduce");
1029 }
1030
c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
    std::vector<at::Tensor>& /* unused */,
    const AllreduceCoalescedOptions& /* unused */) {
1034 throw std::invalid_argument(
1035 "ProcessGroupUCC does not support allreduce_coalesced");
1036 }
1037
c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& /* unused */) {
1042 auto device = outputTensors[0].device();
1043 for (const auto r : c10::irange(outputTensors.size())) {
1044 TORCH_CHECK(
1045 device == outputTensors[r].device() &&
1046 device == inputTensors[r].device(),
1047 "Tensors must be on the same device")
1048 }
1049
1050 initComm(device);
1051 ucc_coll_args_t coll;
1052 AlltoallWorkData* data;
1053 data = new AlltoallWorkData(size_);
1054
  /* To avoid flattening the tensors, we use alltoallv to achieve Alltoall as
     follows:
      1. store the address of each tensor directly in displacements, and keep
         buffer set to nullptr, i.e., 0
      2. convert datatype to UINT8, which is always 1 byte, to avoid wrong
         size calculation in the UCC layer
      3. post Alltoallv
  */
1063 for (const auto i : c10::irange(size_)) {
1064 data->send_lengths[i] =
1065 (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel());
1066 data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr();
1067 data->recv_lengths[i] =
1068 (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel());
1069 data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr();
1070 }
1071
1072 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1073 coll.flags =
1074 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1075 coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
1076 coll.src.info_v.buffer = 0;
1077 coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1078 coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1079 coll.src.info_v.datatype = UCC_DT_UINT8;
1080 coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type());
1081 coll.dst.info_v.buffer = 0;
1082 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1083 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1084 coll.dst.info_v.datatype = UCC_DT_UINT8;
1085 coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type());
1086
1087 SAVE_TENSORS(inputTensors, data->src);
1088 SAVE_TENSORS(outputTensors, data->dst);
1089
1090 return collective_post(
1091 OpType::ALLTOALL,
1092 []() {},
1093 []() {},
1094 coll,
1095 std::unique_ptr<WorkData>(data),
1096 device,
1097 inputTensors,
1098 outputTensors,
1099 "ucc:alltoall");
1100 }
1101
c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& /* unused */) {
1108 check_device(inputTensor.device(), outputTensor.device());
1109 initComm(inputTensor.device());
1110 ucc_coll_args_t coll;
1111 AlltoallWorkData* data;
1112
1113 if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) {
1114 data = new AlltoallWorkData(0);
1115 TORCH_CHECK(
1116 (outputTensor.size(0) % size_ == 0) &&
1117 (inputTensor.size(0) % size_ == 0),
1118 "Tensor's dim 0 does not divide equally across group size");
1119 coll.mask = 0;
1120 coll.flags = 0;
1121 coll.coll_type = UCC_COLL_TYPE_ALLTOALL;
1122 coll.src.info.buffer = inputTensor.data_ptr();
1123 coll.src.info.count = inputTensor.element_size() * inputTensor.numel();
1124 coll.src.info.datatype = UCC_DT_UINT8;
1125 coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
1126 coll.dst.info.buffer = outputTensor.data_ptr();
1127 coll.dst.info.count = outputTensor.element_size() * outputTensor.numel();
1128 coll.dst.info.datatype = UCC_DT_UINT8;
1129 coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
1130 coll.flags = 0;
1131 } else {
1132 data = new AlltoallWorkData(size_);
1133 c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
1134 c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
1135 computeLengthsAndOffsets(
1136 outputSplitSizes,
1137 outputTensor,
1138 &data->recv_lengths,
1139 &data->recv_offsets);
1140 computeLengthsAndOffsets(
1141 inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets);
1142 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1143 coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
1144 coll.src.info_v.buffer = inputTensor.data_ptr();
1145 coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1146 coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1147 coll.src.info_v.datatype = to_ucc_dType(inputTensor);
1148 coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type());
1149 coll.dst.info_v.buffer = outputTensor.data_ptr();
1150 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1151 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1152 coll.dst.info_v.datatype = to_ucc_dType(outputTensor);
1153 coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type());
1154 coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
1155 UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT |
1156 UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1157
1158 if (torch_ucc_config.enable_comms_logger) {
1159 logger->trace_generator->recordOptionalInfo(
1160 outputSplitSizes, inputSplitSizes);
1161 }
1162 }
1163 std::vector<at::Tensor> inputTensors = {inputTensor};
1164 std::vector<at::Tensor> outputTensors = {outputTensor};
1165 SAVE_TENSORS(inputTensors, data->src);
1166 SAVE_TENSORS(outputTensors, data->dst);
1167
1168 return collective_post(
1169 OpType::ALLTOALL_BASE,
1170 []() {},
1171 []() {},
1172 coll,
1173 std::unique_ptr<WorkData>(data),
1174 inputTensor.device(),
1175 inputTensors,
1176 outputTensors,
1177 "ucc:alltoall");
1178 }
1179
c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) {
1181 c10::Device device = c10::Device(c10::DeviceType::CPU);
1182 #ifdef USE_CUDA
1183 auto numGPUs = c10::cuda::device_count();
1184 if (!opts.device_ids.empty()) {
1185 device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front());
1186 } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) {
1187 device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index);
1188 } else if (numGPUs > 0) {
1189 int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device());
1190 // if current device is 0, likely the device is not set, use the best guess
1191 if (0 == (int)deviceIdx) {
1192 deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs);
1193 }
1194 TORCH_UCC_LOG_INFO(
1195 TORCH_UCC_COLL_POST,
1196 c10::str(
1197 "post barrier before specifying any GPU while there are ",
1198 numGPUs,
1199 " GPUs available. ",
1200 "Not clear if GPU barrier is required, using GPU ",
1201 (int)deviceIdx,
1202 " to perform barrier. ",
1203 "Specify device_ids option in barrier() to force ",
1204 "use of a particular device"));
1205 device = c10::Device(c10::DeviceType::CUDA, deviceIdx);
1206 }
1207 #endif
1208 initComm(device);
1209
1210 ucc_coll_args_t coll;
1211 coll.mask = 0;
1212 coll.flags = 0;
1213 coll.coll_type = UCC_COLL_TYPE_BARRIER;
1214 auto dummy_tensor = std::vector<at::Tensor>();
1215 return collective_post(
1216 OpType::BARRIER,
1217 []() {},
1218 []() {},
1219 coll,
1220 nullptr,
1221 device,
1222 dummy_tensor,
1223 dummy_tensor,
1224 "ucc:barrier");
1225 }
1226
c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
1230 check_tensor(tensors);
1231 auto& tensor = tensors[0];
1232 initComm(tensor.device());
1233 WorkData* data = new WorkData();
1234
1235 ucc_coll_args_t coll;
1236 coll.mask = 0;
1237 coll.flags = 0;
1238 coll.coll_type = UCC_COLL_TYPE_BCAST;
1239 coll.src.info.buffer = tensor.data_ptr();
1240 coll.src.info.count = tensor.numel();
1241 coll.src.info.datatype = to_ucc_dType(tensor);
1242 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1243 coll.root = opts.rootRank;
1244 SAVE_TENSORS(tensors, data->dst);
1245
1246 if (torch_ucc_config.enable_comms_logger) {
1247 logger->trace_generator->recordOptionalInfo(opts.rootRank);
1248 }
1249
1250 return collective_post(
1251 OpType::BROADCAST,
1252 []() {},
1253 []() {},
1254 coll,
1255 std::unique_ptr<WorkData>(data),
1256 tensor.device(),
1257 tensors,
1258 tensors,
1259 "ucc:broadcast");
1260 }
1261
c10::intrusive_ptr<Work> ProcessGroupUCC::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
1266 std::vector<at::Tensor> outputs;
1267 auto& input = inputTensors[0];
1268 initComm(input.device());
1269
1270 AllgathervWorkData* data = new AllgathervWorkData(size_);
1271 ucc_coll_args_t coll;
1272 coll.root = opts.rootRank;
1273 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1274 coll.flags =
1275 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1276 coll.coll_type = UCC_COLL_TYPE_GATHERV;
1277
1278 /* for non-root ranks, only src is valid */
1279 coll.src.info.buffer = input.data_ptr();
1280 coll.src.info.count = (uint64_t)(input.element_size() * input.numel());
1281 coll.src.info.datatype = UCC_DT_UINT8;
1282 coll.src.info.mem_type = to_ucc_memType(input.device().type());
1283
1284 if (getRank() == opts.rootRank) {
1285 if (outputTensors.size() != 1) {
1286 TORCH_UCC_LOG_ERROR(
1287 TORCH_UCC_COLL_POST,
1288 c10::str(
1289 "gather requires a single-element output list containing a list with ",
1290 getSize(),
1291 " tensors."));
1292 } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
1293 TORCH_UCC_LOG_ERROR(
1294 TORCH_UCC_COLL_POST,
1295 c10::str(
1296 "Incorrect output list size ",
1297 outputTensors[0].size(),
1298 ". Output list size should be ",
1299 getSize(),
1300 ", same as size of the process group."));
1301 }
1302 outputs = outputTensors[0];
1303
1304 for (int i = 0; i < size_; i++) {
1305 data->recv_lengths[i] =
1306 (uint64_t)(outputs[i].element_size() * outputs[i].numel());
1307 data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr();
1308 }
    /* use gatherv and store non-contiguous addresses in displacements to
     * avoid flattening outputTensors */
1311 coll.dst.info_v.buffer = nullptr;
1312 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1313 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1314 coll.dst.info_v.datatype = UCC_DT_UINT8;
1315 coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type());
1316
1317 SAVE_TENSORS(outputs, data->dst);
1318 } else {
1319 // for non-root ranks, outputTensors should be an empty list
1320 if (outputTensors.size() != 0) {
1321 TORCH_UCC_LOG_ERROR(
1322 TORCH_UCC_COLL_POST, "requires empty output on non-root");
1323 }
1324 outputs = {};
    // append an empty tensor to the list to be used by future mark
1326 outputs.emplace_back();
1327 }
1328
1329 SAVE_TENSORS(inputTensors, data->src);
1330
1331 return collective_post(
1332 OpType::GATHER,
1333 []() {},
1334 []() {},
1335 coll,
1336 std::unique_ptr<WorkData>(data),
1337 input.device(),
1338 inputTensors,
1339 outputs,
1340 "ucc:gather");
1341 }
1342
c10::intrusive_ptr<Work> ProcessGroupUCC::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
1346 check_tensor(tensors);
1347 auto& tensor = tensors[0];
1348 initComm(tensor.device());
1349 WorkData* data = new WorkData();
1350
1351 ucc_coll_args_t coll;
1352 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1353 coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
1354 coll.coll_type = UCC_COLL_TYPE_REDUCE;
1355 coll.op = ucc_op_map.at(opts.reduceOp);
1356 coll.root = opts.rootRank;
1357 coll.src.info.buffer = tensor.data_ptr();
1358 coll.src.info.count = tensor.numel();
1359 coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
1360 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1361 coll.dst.info.buffer = tensor.data_ptr();
1362 coll.dst.info.count = tensor.numel();
1363 coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
1364 coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1365 SAVE_TENSORS(tensors, data->dst);
1366 return collective_post(
1367 OpType::REDUCE,
1368 []() {},
1369 []() {},
1370 coll,
1371 std::unique_ptr<WorkData>(data),
1372 tensor.device(),
1373 tensors,
1374 tensors,
1375 "ucc:reduce");
1376 }
1377
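// reduce_scatter stages the (possibly non-contiguous) input lists into flat
// tensors inside the preprocess lambda (copy_to_flat) before posting the UCC
// reduce-scatter; the rank's output tensor is used directly as the
// destination buffer.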
c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
1382 TORCH_CHECK(
1383 (outputTensors.size() == inputTensors.size()),
1384 "Tensor input/output list for reduce_scatter must have same size");
1385 check_tensor(outputTensors);
1386 check_device(inputTensors[0][0].device(), outputTensors[0].device());
1387 initComm(inputTensors[0][0].device());
1388 auto data = std::make_unique<WorkData>();
1389 std::vector<at::Tensor> flat_input(inputTensors.size());
1390 for (size_t i = 0; i < inputTensors.size(); i++) {
1391 TORCH_CHECK(
1392 inputTensors[i].size() == inputTensors.size() * size_,
1393 "Tensor input list is not valid for the number of participants");
1394 flat_input[i] = c10d::newLikeFlat(inputTensors, i);
1395 }
1396 SAVE_TENSORS(flat_input, data->flat);
1397 check_tensor(flat_input);
1398 ucc_coll_args_t coll;
1399 coll.mask = 0;
1400 coll.flags = 0;
1401 coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
1402 coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type());
1403
1404 coll.src.info.buffer = flat_input[0].data_ptr();
1405 coll.src.info.count = flat_input[0].numel();
1406 coll.src.info.datatype = to_ucc_dType(flat_input[0]);
1407 coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type());
1408 coll.dst.info.buffer = outputTensors[0].data_ptr();
1409 coll.dst.info.count = outputTensors[0].numel();
1410 coll.dst.info.datatype = to_ucc_dType(outputTensors[0]);
1411 coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());
1412
1413 SAVE_TENSORS(inputTensors[0], data->src);
1414 SAVE_TENSORS(outputTensors, data->dst);
1415
1416 auto copy_to_flat = [&] {
1417 bool asyncCopy = false;
1418 auto isize = inputTensors.size();
1419 #ifdef USE_CUDA
1420 bool isCuda = inputTensors[0][0].device().is_cuda();
1421 #endif
1422 for (size_t i = 0; i < isize; i++) {
1423 auto onumel = outputTensors[i].numel();
1424 for (size_t j = 0; j < inputTensors[i].size(); j++) {
1425 TORCH_CHECK(
1426 (inputTensors[i][j].numel() == onumel),
1427 "Tensor operand counts must be same");
1428 #ifdef USE_CUDA
1429 if (isCuda) {
1430 c10::cuda::CUDACachingAllocator::recordStream(
1431 inputTensors[i][j].storage().data_ptr(), (*stream));
1432 asyncCopy = true;
1433 }
1434 #endif
1435 flat_input[i][j].copy_(inputTensors[i][j], asyncCopy);
1436 }
1437 }
1438 };
1439
1440 return collective_post(
1441 OpType::REDUCE_SCATTER,
1442 copy_to_flat,
1443 []() {},
1444 coll,
1445 std::move(data),
1446 inputTensors[0][0].device(),
1447 inputTensors[0],
1448 outputTensors,
1449 "ucc:reduce_scatter");
1450 }
1451
c10::intrusive_ptr<Work> ProcessGroupUCC::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
1456 auto& tensor = outputTensors[0];
1457 initComm(tensor.device());
1458
1459 ScattervWorkData* data = new ScattervWorkData(size_);
1460 ucc_coll_args_t coll;
1461 coll.root = opts.rootRank;
1462 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1463 coll.flags =
1464 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1465 coll.coll_type = UCC_COLL_TYPE_SCATTERV;
1466
1467 if (getRank() == opts.rootRank) {
    /* src is only valid at the root rank */
    if (inputTensors.size() != 1) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "scatter requires a single-element input list containing a list with ",
              getSize(),
              " tensors."));
    } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "Incorrect input list size ",
              inputTensors[0].size(),
              ". Input list size should be ",
              getSize(),
              ", same as size of the process group."));
1485 }
1486
1487 for (int i = 0; i < size_; i++) {
1488 data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel();
1489 data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr();
1490 }
    /* use scatterv and store non-contiguous addresses in displacements to
     * avoid flattening inputTensors */
1493 coll.src.info_v.buffer = nullptr;
1494 coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1495 coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1496 coll.src.info_v.datatype = UCC_DT_UINT8;
1497 coll.src.info_v.mem_type =
1498 to_ucc_memType(inputTensors[0][0].device().type());
1499
1500 SAVE_TENSORS(inputTensors[0], data->src);
1501 } else {
1502 // for non-root ranks, inputTensors should be an empty list
1503 if (inputTensors.size() != 0) {
1504 TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, "requires empty input on non-root");
1506 }
1507 }
1508
1509 coll.dst.info.buffer = tensor.data_ptr();
1510 coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel();
1511 coll.dst.info.datatype = UCC_DT_UINT8;
1512 coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1513 SAVE_TENSORS(outputTensors, data->dst);
1514
1515 return collective_post(
1516 OpType::SCATTER,
1517 []() {},
1518 []() {},
1519 coll,
1520 std::unique_ptr<WorkData>(data),
1521 tensor.device(),
1522 (getRank() == opts.rootRank) ? inputTensors[0] : outputTensors,
1523 outputTensors,
1524 "ucc:scatter");
1525 }
1526
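// Point-to-point send/recv are implemented on top of UCC broadcast restricted
// to an active set of two ranks: the sender is the broadcast root, and
// active_set.start/stride select exactly the {sender, receiver} pair. recv()
// below mirrors this with the source rank as root.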
c10::intrusive_ptr<Work> ProcessGroupUCC::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
1531 check_tensor(tensors);
1532 auto& tensor = tensors[0];
1533 initComm(tensor.device());
1534
1535 WorkData* data = new WorkData();
1536 ucc_coll_args_t coll;
1537 coll.tag = tag;
1538 coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
1539 coll.flags = 0;
1540 coll.coll_type = UCC_COLL_TYPE_BCAST;
1541 coll.src.info.buffer = tensor.data_ptr();
1542 coll.src.info.count = tensor.numel();
1543 coll.src.info.datatype = to_ucc_dType(tensor);
1544 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1545 coll.root = getRank();
1546
1547 coll.active_set.size = 2;
1548 coll.active_set.start = getRank();
1549 coll.active_set.stride = dstRank - getRank();
1550 SAVE_TENSORS(tensors, data->dst);
1551
1552 return collective_post(
1553 OpType::SEND,
1554 []() {},
1555 []() {},
1556 coll,
1557 std::unique_ptr<WorkData>(data),
1558 tensor.device(),
1559 tensors,
1560 tensors,
1561 "ucc:send");
1562 }
1563
c10::intrusive_ptr<Work> ProcessGroupUCC::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
1568 check_tensor(tensors);
1569 auto& tensor = tensors[0];
1570 initComm(tensor.device());
1571
1572 WorkData* data = new WorkData();
1573 ucc_coll_args_t coll;
1574 coll.tag = tag;
1575 coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
1576 coll.flags = 0;
1577 coll.coll_type = UCC_COLL_TYPE_BCAST;
1578 coll.src.info.buffer = tensor.data_ptr();
1579 coll.src.info.count = tensor.numel();
1580 coll.src.info.datatype = to_ucc_dType(tensor);
1581 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1582 coll.root = srcRank;
1583
1584 coll.active_set.size = 2;
1585 coll.active_set.start = srcRank;
1586 coll.active_set.stride = getRank() - srcRank;
1587 SAVE_TENSORS(tensors, data->dst);
1588
1589 return collective_post(
1590 OpType::RECV,
1591 []() {},
1592 []() {},
1593 coll,
1594 std::unique_ptr<WorkData>(data),
1595 tensor.device(),
1596 tensors,
1597 tensors,
1598 "ucc:recv");
1599 }
1600
void ProcessGroupUCC::setSequenceNumberForGroup() {}
1602
uint64_t ProcessGroupUCC::getSequenceNumberForGroup() {
1604 return seq_;
1605 }
1606
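// A minimal usage sketch for the factory below (illustrative only: the store
// construction is an assumption, and backend registration happens outside
// this file):
//
//   auto store = /* any c10d Store, e.g. a TCPStore */;
//   auto pg = ProcessGroupUCC::createProcessGroupUCC(
//       store, /*rank=*/0, /*size=*/2, std::chrono::seconds(60));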
c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC(
    const c10::intrusive_ptr<::c10d::Store>& store,
    int rank,
    int size,
    const std::chrono::duration<float>& timeout) {
1612 return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout);
1613 }
1614
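// initComm lazily creates the shared Comm and UCC team on first use. For CUDA
// tensors it also creates one execution engine bound to a dedicated
// collective stream plus two more for the send/recv streams; subsequent calls
// only verify that the same CUDA device keeps being used.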
void ProcessGroupUCC::initComm(c10::Device dev) {
1616 if (!comm) {
1617 #ifdef USE_CUDA
1618 if (dev.is_cuda()) {
1619 c10::cuda::set_device(dev.index());
1620 }
1621 #endif
1622 comm = Comm::get_comm(comm_id, dev, oob, logger);
1623 TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
1624 comm->ucc_create_team(team, oob);
1625 TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
1626 logger->setPhase(TORCH_UCC_READY);
1627 } else {
1628 if (dev.is_cuda()) {
1629 if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
1630 (comm->cuda_device_index != dev.index())) {
1631 TORCH_UCC_LOG_ERROR(
1632 TORCH_UCC_INIT,
1633 "ucc communicator was initialized with different cuda device,"
1634 "multi device is not supported");
1635 throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
1636 }
1637 comm->cuda_device_index = dev.index();
1638 }
1639 }
1640 #ifdef USE_CUDA
1641 // Create UCC execution engine.
1642 if (!cuda_ee && dev.is_cuda()) {
1643 stream = std::make_unique<at::cuda::CUDAStream>(
1644 at::cuda::getStreamFromPool(true, dev.index()));
1645 ucc_ee_params_t params;
1646 params.ee_type = UCC_EE_CUDA_STREAM;
1647 params.ee_context = (void*)stream->stream();
1648 params.ee_context_size = sizeof(cudaStream_t);
1649 TORCH_UCC_CHECK(
1650 ucc_ee_create(team, ¶ms, &cuda_ee),
1651 "failed to create UCC execution engine");
1652 for (int i = 0; i < 2; i++) {
1653 stream_p2p[i] = std::make_unique<at::cuda::CUDAStream>(
1654 at::cuda::getStreamFromPool(true, dev.index()));
1655 ucc_ee_params_t params;
1656 params.ee_type = UCC_EE_CUDA_STREAM;
1657 params.ee_context = (void*)stream_p2p[i]->stream();
1658 params.ee_context_size = sizeof(cudaStream_t);
1659 TORCH_UCC_CHECK(
1660 ucc_ee_create(team, ¶ms, &cuda_ee_p2p[i]),
1661 "failed to create UCC P2P execution engine");
1662 }
1663 }
1664 #endif
1665 }
1666
1667 } // namespace c10d
1668
1669 #endif // USE_C10D_UCC
1670