#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

#ifdef USE_C10D_MPI

#include <iostream>
#include <map>

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
#endif

namespace c10d {

#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      TORCH_CHECK(false, err);                                           \
    }                                                                    \
  } while (0)
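
// Illustrative sketch (not part of this translation unit): wrapping an MPI
// call in MPI_CHECK converts any non-MPI_SUCCESS return code into a
// TORCH_CHECK failure that reports the file, line, and error code, e.g.:
//
//   int worldSize = 0;
//   MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &worldSize));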

namespace {

// Op mapping
std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};
// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};
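
// Illustrative sketch (not part of this translation unit): the collectives
// below translate torch reduce ops and dtypes through these maps via .at(),
// which throws std::out_of_range for unsupported combinations, e.g.:
//
//   MPI_Op op = mpiOp.at(ReduceOp::SUM);          // MPI_SUM
//   MPI_Datatype dt = mpiDatatype.at(at::kFloat); // MPI_FLOAT
//   // mpiDatatype.at(at::kBool) would throw: no MPI mapping is registered.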

// Checking CUDA-aware MPI support, currently we only support CUDA aware
// MPI ops through Open MPI
bool cudaAwareMpiCheck() {
  // Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensorHelper(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
  }
  if (tensor.is_cuda() && !cudaAwareMpiCheck()) {
    TORCH_CHECK(
        false,
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    TORCH_CHECK(
        false, "MPI process group does not support multi-GPU collectives");
  }
  checkSingleTensorHelper(tensors[0]);
}

void checkSameSizeAndType(
    const at::Tensor& t_in,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    if ((tensor.numel() != t_in.numel()) ||
        (tensor.scalar_type() != t_in.scalar_type())) {
      TORCH_CHECK(false, "Tensors are not equal in size or data type");
    }
    checkSingleTensorHelper(tensor);
  }
}

} // namespace

std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() {
  return outputTensors_;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() {
  return future_;
}

void ProcessGroupMPI::WorkMPI::finishWorkMPIError(
    const std::exception_ptr& eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

ProcessGroupMPI::AsyncWork::AsyncWork(
    MPI_Request request,
    std::vector<at::Tensor> outputTensors,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
      outputTensors_(std::move(outputTensors)),
      request_(request) {
  memset(&status_, 0, sizeof(status_));
}

ProcessGroupMPI::AsyncWork::~AsyncWork() {
  if (request_ != MPI_REQUEST_NULL) {
    std::cerr
        << "Attempted destruction of AsyncWork before work has completed, "
        << "terminating the program." << '\n';
    std::terminate();
  }
}

bool ProcessGroupMPI::AsyncWork::isCompleted() {
  if (request_ == MPI_REQUEST_NULL) {
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  int flag = 0;
  MPI_CHECK(MPI_Test(&request_, &flag, &status_));
  if (request_ != MPI_REQUEST_NULL) {
    return false;
  }

  // request_ == MPI_REQUEST_NULL; the work has completed
  // Populate exception if request was not successful
  if (status_.MPI_ERROR != MPI_SUCCESS) {
    populateException();
  }

  return true;
}

bool ProcessGroupMPI::AsyncWork::isSuccess() const {
  if (request_ != MPI_REQUEST_NULL) {
    TORCH_CHECK(
        false,
        "Invalid call to AsyncWork::isSuccess before work has completed");
  }

  return status_.MPI_ERROR == MPI_SUCCESS;
}

int ProcessGroupMPI::AsyncWork::sourceRank() const {
  return status_.MPI_SOURCE;
}

bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
  if (request_ == MPI_REQUEST_NULL) {
    // AsyncWork needs to manually call profiling end callbacks if they are
    // set, since it does not call ProcessGroup::finish().
    if (Work::recordFunctionEndCallback_) {
      Work::recordFunctionEndCallback_();
      Work::recordFunctionEndCallback_ = nullptr;
    }
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Wait(&request_, &status_));
  auto ok = (status_.MPI_ERROR == MPI_SUCCESS);

  // AsyncWork needs to manually call profiling end callbacks if they are set,
  // since it does not call ProcessGroup::finish().
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }

  if (!ok) {
    populateException();
    std::rethrow_exception(exception_);
  }
  // Always return true, because abort API is not implemented.
  return true;
}
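
// Illustrative sketch (hypothetical caller, not part of this file): an
// AsyncWork returned by send()/recv() can either be polled or waited on;
// wait() rethrows the populated exception if the underlying MPI request
// failed.
//
//   auto work = pg->send(tensors, /*dstRank=*/1, /*tag=*/0);
//   while (!work->isCompleted()) {
//     /* do other useful work */
//   }
//   work->wait(); // no-op once the request is MPI_REQUEST_NULL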

void ProcessGroupMPI::AsyncWork::abort() {
  TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.");
}

std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() {
  return outputTensors_;
}

void ProcessGroupMPI::AsyncWork::populateException() {
  std::array<char, MPI_MAX_ERROR_STRING> buf{};
  int len = buf.size();
  MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len));
  exception_ = std::make_exception_ptr(
      std::runtime_error(std::string(buf.data(), len)));
}

// Static global states
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
c10::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::mpiExit() {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Finalize());
}

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  c10::call_once(onceFlagInitMPI, []() {
    int mpi_was_initialized = 0;
    MPI_CHECK(MPI_Initialized(&mpi_was_initialized));
    if (mpi_was_initialized == 0) {
      MPI_CHECK(MPI_Init_thread(
          nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
      if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
        TORCH_CHECK(
            false,
            "Used MPI implementation doesn't have the "
            "minimum level of threading support: "
            "MPI_THREAD_SERIALIZED. This is required by "
            "c10d package");
      }
      if (std::atexit(ProcessGroupMPI::mpiExit)) {
        TORCH_CHECK(false, "Failed to register the MPI exit handler");
      }
    } else {
      TORCH_WARN_ONCE("MPI was previously initialized.");
    }
  });
}
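
// Illustrative note (assumption about intended usage, not part of this file):
// initMPIOnce() is idempotent thanks to c10::call_once, so every
// createProcessGroupMPI() call can invoke it unconditionally:
//
//   ProcessGroupMPI::initMPIOnce(); // first call runs MPI_Init_thread
//   ProcessGroupMPI::initMPIOnce(); // subsequent calls are no-ops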

c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // Once initialization
  initMPIOnce();

  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup{};
      MPI_Group ranksGroup{};
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
      // `MPI_Comm_create` can be flaky in certain cases.
      // See: https://github.com/pytorch/pytorch/issues/53899
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
      for (const auto i : c10::irange(kMaxNumRetries)) {
        (void)i;
        // Retry on failure; stop as soon as the communicator is created.
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm) ==
            MPI_SUCCESS) {
          groupComm_updated = true;
          break;
        }
      }
      TORCH_CHECK(groupComm_updated, "MPI_Comm_create failed after retrying");
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));

      if (rank < 0 || size < 0) {
        TORCH_CHECK(false, "Failed to get the world_size / rank");
      }
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>();
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}
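
// Illustrative sketch (hypothetical usage, not part of this file): creating a
// subgroup over ranks {0, 1}. Processes outside the group get a null
// intrusive_ptr back and must not use it.
//
//   auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI({0, 1});
//   if (pg) {
//     // this rank is in the group; pg->getRank() / pg->getSize() are valid
//   }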

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : Backend(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);

  init();
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);
  queueConsumeCV_.wait(lock, [&] { return queue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();
  queueProduceCV_.notify_all();

  // Join the single worker thread
  workerThread_.join();
}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(pgComm_, EXIT_FAILURE);
}

void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI();
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}
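
// Illustrative sketch (not part of this file) of the producer/consumer
// protocol between enqueue() and runLoop():
//
//   caller thread                         worker thread (runLoop)
//   -------------                         ------------------------
//   lock pgMutex_                         wait on queueProduceCV_
//   queue_.emplace_back(entry, work)      pop (entry, work) from queue_
//   notify queueProduceCV_                notify queueConsumeCV_ (for destroy)
//   return Work handle                    entry->run(entry); finishWorkMPI()
//
// destroy() waits on queueConsumeCV_ until the queue drains, then sets stop_
// and joins the worker thread.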

c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors) {
  auto work =
      c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.emplace_back(std::move(entry), work);
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}

c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:broadcast",
      std::optional<std::vector<at::Tensor>>(tensors));
}
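
// Illustrative sketch (hypothetical caller, not part of this file): every
// collective below follows the same pattern -- validate tensors, wrap the MPI
// call in a WorkEntry, enqueue it, and let the caller block on the Work:
//
//   std::vector<at::Tensor> tensors = {at::ones({4})};
//   c10d::BroadcastOptions opts;
//   opts.rootRank = 0;
//   pg->broadcast(tensors, opts)->wait();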

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        auto dataPtr = (entry->src)[0].data_ptr();
        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce(
            sendbuf,
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  checkSingleTensor(inputTensors);
  if (outputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "All gather: number of output tensors should equal "
        "to the world size");
  }

  checkSameSizeAndType(inputTensors[0], outputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        std::vector<at::Tensor> outputDataVec = entry->dst;
        auto flatOutputTensor = newLikeFlat(outputDataVec);

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            flatOutputTensor.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            pgComm_));

        for (const auto i : c10::irange(outputDataVec.size())) {
          outputDataVec[i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors[0], std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_gather",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}
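
// Illustrative note (not part of this file): newLikeFlat builds one
// contiguous receive buffer shaped [world_size, ...input shape], so for
// world_size == 2 and a per-rank input of shape {3} the flow is roughly:
//
//   flatOutputTensor                              // shape {2, 3}, recv buffer
//   outputDataVec[0].copy_(flatOutputTensor[0]);  // rank 0's contribution
//   outputDataVec[1].copy_(flatOutputTensor[1]);  // rank 1's contribution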

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (rank_ != opts.rootRank) {
    if (!outputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      TORCH_CHECK(false, "Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(inputTensors[0], outputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        void* recvbuf = nullptr;
        at::Tensor flatOutputTensor;

        std::vector<at::Tensor> dstdata = entry->dst;
        if (rank_ == opts.rootRank) {
          flatOutputTensor = newLikeFlat(dstdata);
          recvbuf = flatOutputTensor.data_ptr();
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Gather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));

        if (rank_ == opts.rootRank) {
          const std::vector<at::Tensor>& outputDataVec = entry->dst;
          // copy the flattened output tensors to the outputs
          for (const auto i : c10::irange(outputDataVec.size())) {
            outputDataVec.at(i).copy_(
                flatOutputTensor[static_cast<int64_t>(i)]);
          }
        }
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors[0], std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    auto entry =
        std::make_unique<WorkEntry>(&inputTensors, nullptr, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);

  if (rank_ != opts.rootRank) {
    if (!inputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(outputTensors[0], inputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        void* sendbuf = nullptr;
        at::Tensor flatInputTensor;

        if (rank_ == opts.rootRank) {
          std::vector<at::Tensor>& inputDataVec = entry->src;
          flatInputTensor = newLikeFlat(inputDataVec);
          sendbuf = flatInputTensor.data_ptr();

          // copy the input tensors to the flatten large send buffer
          for (const auto i : c10::irange(inputDataVec.size())) {
            flatInputTensor[static_cast<int64_t>(i)].copy_(inputDataVec.at(i));
          }
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Scatter(
            sendbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors[0], &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  } else {
    auto entry = std::make_unique<WorkEntry>(
        nullptr, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);

  if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
    // We can use alltoall
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel() &&
            outputTensor.type() == inputTensor.type(),
        "Tensors are not equal in size or data type");
    TORCH_CHECK(
        outputTensor.size(0) % size_ == 0,
        "Tensor's dim 0 does not divide equally across group size");

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this](std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoall(
              srcdata.data_ptr(),
              srcdata.numel() / size_,
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              dstdata.numel() / size_,
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    // Need alltoallv
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this, inputSplitSizes, outputSplitSizes](
            std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          std::vector<int> send_lengths(size_);
          std::vector<int> recv_lengths(size_);
          std::vector<int> send_offsets(size_);
          std::vector<int> recv_offsets(size_);
          c10d::computeLengthsAndOffsets(
              inputSplitSizes, srcdata, &send_lengths, &send_offsets);
          c10d::computeLengthsAndOffsets(
              outputSplitSizes, dstdata, &recv_lengths, &recv_offsets);
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoallv(
              srcdata.data_ptr(),
              send_lengths.data(),
              send_offsets.data(),
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              recv_lengths.data(),
              recv_offsets.data(),
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}
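
// Illustrative example (not part of this file): for the alltoallv path, split
// sizes are converted into per-rank element counts and displacements. With
// size_ == 2 and inputSplitSizes == {3, 5} on a 1-D tensor of 8 elements:
//
//   send_lengths == {3, 5}
//   send_offsets == {0, 3}   // rank 1's chunk starts at element 3
//
// Counts are in elements of the MPI datatype, matching MPI_Alltoallv's
// sendcounts/sdispls arguments.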

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  TORCH_CHECK(
      inputTensors.size() == static_cast<size_t>(size_),
      "Number of input tensors are not equal to group size");
  TORCH_CHECK(
      outputTensors.size() == static_cast<size_t>(size_),
      "Number of output tensors are not equal to group size");
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::vector<int> send_lengths(size_);
        std::vector<int> recv_lengths(size_);
        std::vector<int> send_offsets(size_);
        std::vector<int> recv_offsets(size_);
        auto srcdata = entry->src;
        auto dstdata = entry->dst;
        auto src_len = c10d::computeLengthsAndOffsets(
            srcdata, &send_lengths, &send_offsets);
        auto dst_len = c10d::computeLengthsAndOffsets(
            dstdata, &recv_lengths, &recv_offsets);
        std::vector<int64_t> send_lengthsL(
            send_lengths.begin(), send_lengths.end());
        std::vector<int64_t> recv_lengthsL(
            recv_lengths.begin(), recv_lengths.end());
        at::Tensor srcFlatData =
            at::empty({static_cast<int64_t>(src_len)}, srcdata[0].options());
        at::Tensor dstFlatData =
            at::empty({static_cast<int64_t>(dst_len)}, dstdata[0].options());
        auto srcFlatDataSplits =
            srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          srcFlatDataSplits[i].copy_(srcdata[i].view({-1}));
        }
        c10::DeviceGuard guard1(srcdata[0].device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Alltoallv(
            srcFlatData.data_ptr(),
            send_lengths.data(),
            send_offsets.data(),
            mpiDatatype.at(srcdata[0].scalar_type()),
            dstFlatData.data_ptr(),
            recv_lengths.data(),
            recv_offsets.data(),
            mpiDatatype.at(dstdata[0].scalar_type()),
            pgComm_));

        auto dstFlatDataSplits =
            dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_to_all",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Isend(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        dstRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      std::vector<at::Tensor>(),
      "mpi:send",
      std::optional<std::vector<at::Tensor>>(tensors));
}
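
// Illustrative sketch (hypothetical caller, not part of this file): unlike the
// queued collectives above, send()/recv() post non-blocking MPI requests
// directly and hand back an AsyncWork that owns the MPI_Request:
//
//   // on rank 0                          // on rank 1
//   auto w = pg->send(tensors, 1, 0);     auto w = pg->recv(tensors, 0, 0);
//   w->wait();                            w->wait();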

c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        srcRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recv",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        MPI_ANY_SOURCE,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recvAnySource",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Barrier(pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc));
  return enqueue(std::move(entry), "mpi:barrier", std::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& /* unused */,
    at::Tensor& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "no support for _allgather_base in MPI process group");
}

} // namespace c10d

#endif // USE_C10D_MPI