// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

#ifdef USE_C10D_MPI

#include <iostream>
#include <map>

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
#endif

namespace c10d {

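// MPI_CHECK wraps an MPI call and turns any return code other than
// MPI_SUCCESS into a TORCH_CHECK failure that reports the file, line, and
// numeric error code of the failing call.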
#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      TORCH_CHECK(false, err);                                           \
    }                                                                    \
  } while (0)

namespace {

// Op mapping
std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};
// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};
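
// Note: the collectives below look these maps up with std::map::at(), so an
// unsupported ReduceOp or scalar type surfaces as a std::out_of_range
// exception propagated through the returned Work object.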

// Check for CUDA-aware MPI support. Currently we only support CUDA-aware
// MPI ops through Open MPI.
bool cudaAwareMpiCheck() {
// Run-time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Check the input tensor's validity
void checkSingleTensorHelper(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
  }
  if (tensor.is_cuda() && !cudaAwareMpiCheck()) {
    TORCH_CHECK(
        false,
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    TORCH_CHECK(
        false, "MPI process group does not support multi-GPU collectives");
  }
  checkSingleTensorHelper(tensors[0]);
}

void checkSameSizeAndType(
    const at::Tensor& t_in,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    if ((tensor.numel() != t_in.numel()) ||
        (tensor.scalar_type() != t_in.scalar_type())) {
      TORCH_CHECK(false, "Tensors are not equal in size or data type");
    }
    checkSingleTensorHelper(tensor);
  }
}

} // namespace
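
// Two Work flavors back this process group: WorkMPI wraps collectives that
// are executed serially on the internal worker thread, while AsyncWork wraps
// point-to-point operations driven directly by a raw MPI_Request.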

std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() {
  return outputTensors_;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() {
  return future_;
}

void ProcessGroupMPI::WorkMPI::finishWorkMPIError(
    const std::exception_ptr& eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

ProcessGroupMPI::AsyncWork::AsyncWork(
    MPI_Request request,
    std::vector<at::Tensor> outputTensors,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
      outputTensors_(std::move(outputTensors)),
      request_(request) {
  memset(&status_, 0, sizeof(status_));
}

ProcessGroupMPI::AsyncWork::~AsyncWork() {
  if (request_ != MPI_REQUEST_NULL) {
    std::cerr
        << "Attempted destruction of AsyncWork before work has completed, "
        << "terminating the program." << '\n';
    std::terminate();
  }
}

bool ProcessGroupMPI::AsyncWork::isCompleted() {
  if (request_ == MPI_REQUEST_NULL) {
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  int flag = 0;
  MPI_CHECK(MPI_Test(&request_, &flag, &status_));
  if (request_ != MPI_REQUEST_NULL) {
    return false;
  }

  // request_ == MPI_REQUEST_NULL; the work has completed
  // Populate exception if request was not successful
  if (status_.MPI_ERROR != MPI_SUCCESS) {
    populateException();
  }

  return true;
}

bool ProcessGroupMPI::AsyncWork::isSuccess() const {
  if (request_ != MPI_REQUEST_NULL) {
    TORCH_CHECK(
        false,
        "Invalid call to AsyncWork::isSuccess before work has completed");
  }

  return status_.MPI_ERROR == MPI_SUCCESS;
}

int ProcessGroupMPI::AsyncWork::sourceRank() const {
  return status_.MPI_SOURCE;
}

bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
  if (request_ == MPI_REQUEST_NULL) {
    // AsyncWork needs to manually call profiling end callbacks if they are set,
    // since it does not call ProcessGroup::finish().
    if (Work::recordFunctionEndCallback_) {
      Work::recordFunctionEndCallback_();
      Work::recordFunctionEndCallback_ = nullptr;
    }
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Wait(&request_, &status_));
  auto ok = (status_.MPI_ERROR == MPI_SUCCESS);

  // AsyncWork needs to manually call profiling end callbacks if they are set,
  // since it does not call ProcessGroup::finish().
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }

  if (!ok) {
    populateException();
    std::rethrow_exception(exception_);
  }
  // Always return true, because the abort API is not implemented.
  return true;
}

void ProcessGroupMPI::AsyncWork::abort() {
  TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.");
}

std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() {
  return outputTensors_;
}

void ProcessGroupMPI::AsyncWork::populateException() {
  std::array<char, MPI_MAX_ERROR_STRING> buf{};
  int len = buf.size();
  MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len));
  exception_ =
      std::make_exception_ptr(std::runtime_error(std::string(buf.data(), len)));
}

// Static global states
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
c10::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::mpiExit() {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Finalize());
}

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  c10::call_once(onceFlagInitMPI, []() {
    int mpi_was_initialized = 0;
    MPI_CHECK(MPI_Initialized(&mpi_was_initialized));
    if (mpi_was_initialized == 0) {
      MPI_CHECK(MPI_Init_thread(
          nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
      if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
        TORCH_CHECK(
            false,
            "Used MPI implementation doesn't have the "
            "minimum level of threading support: "
            "MPI_THREAD_SERIALIZED. This is required by "
            "the c10d package");
      }
      if (std::atexit(ProcessGroupMPI::mpiExit)) {
        TORCH_CHECK(false, "Failed to register the MPI exit handler");
      }
    } else {
      TORCH_WARN_ONCE("MPI was previously initialized.");
    }
  });
}

c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // One-time initialization
  initMPIOnce();

  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup{};
      MPI_Group ranksGroup{};
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
      // `MPI_Comm_create` can be flaky in certain cases.
      // See: https://github.com/pytorch/pytorch/issues/53899
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_Barrier(MPI_COMM_WORLD);
      for (const auto i : c10::irange(kMaxNumRetries)) {
        (void)i;
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)) {
          groupComm_updated = true;
          break;
        }
      }
      MPI_CHECK(groupComm_updated);
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));

      if (rank < 0 || size < 0) {
        TORCH_CHECK(false, "Failed to get the world_size / rank");
      }
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>();
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}
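
// Illustrative sketch (not part of this file): creating the default group,
// which spans MPI_COMM_WORLD, and issuing a broadcast from every rank.
// Assumes the binary is launched under mpirun with USE_C10D_MPI enabled;
// a rank left out of a sub-group receives a null intrusive_ptr instead.
//
//   auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI({});
//   std::vector<at::Tensor> tensors{at::ones({4})};
//   pg->broadcast(tensors, c10d::BroadcastOptions{})->wait();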

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : Backend(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);

  init();
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);
  queueConsumeCV_.wait(lock, [&] { return queue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();
  queueProduceCV_.notify_all();

  // Join the single worker thread
  workerThread_.join();
}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(pgComm_, EXIT_FAILURE);
}

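// Worker loop: entries enqueued by the collectives below are executed one at
// a time on this single thread, and each MPI call is additionally serialized
// by pgGlobalMutex_, which is why MPI_THREAD_SERIALIZED support suffices.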
void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI();
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}

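// Each collective is packaged as a WorkEntry (a run closure together with the
// source and destination tensor lists) and queued for the worker thread; the
// returned WorkMPI completes its future once the entry has run.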
c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors) {
  auto work =
      c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.emplace_back(std::move(entry), work);
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}

c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:broadcast",
      std::optional<std::vector<at::Tensor>>(tensors));
}

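// All-reduce runs in place: the same tensor vector is registered as both the
// source and the destination, and MPI_IN_PLACE is passed as the send buffer.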
c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        auto dataPtr = (entry->src)[0].data_ptr();
        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce(
            sendbuf,
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

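// All-gather receives into a single flattened buffer (newLikeFlat) and then
// copies each rank's slice back into the caller-provided output tensors,
// since MPI_Allgather expects one contiguous receive region.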
c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  checkSingleTensor(inputTensors);
  if (outputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "All gather: number of output tensors should equal "
        "the world size");
  }

  checkSameSizeAndType(inputTensors[0], outputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        std::vector<at::Tensor> outputDataVec = entry->dst;
        auto flatOutputTensor = newLikeFlat(outputDataVec);

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            flatOutputTensor.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            pgComm_));

        for (const auto i : c10::irange(outputDataVec.size())) {
          outputDataVec[i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors[0], std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_gather",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced");
}

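// Gather mirrors all-gather, but only the root materializes the flattened
// receive buffer; non-root ranks pass a null recvbuf and provide no output
// tensors.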
c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (rank_ != opts.rootRank) {
    if (!outputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      TORCH_CHECK(false, "Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should equal "
          "the world size");
    }
    checkSameSizeAndType(inputTensors[0], outputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        void* recvbuf = nullptr;
        at::Tensor flatOutputTensor;

        std::vector<at::Tensor> dstdata = entry->dst;
        if (rank_ == opts.rootRank) {
          flatOutputTensor = newLikeFlat(dstdata);
          recvbuf = flatOutputTensor.data_ptr();
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Gather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));

        if (rank_ == opts.rootRank) {
          const std::vector<at::Tensor>& outputDataVec = entry->dst;
          // copy the flattened output tensors to the outputs
          for (const auto i : c10::irange(outputDataVec.size())) {
            outputDataVec.at(i).copy_(
                flatOutputTensor[static_cast<int64_t>(i)]);
          }
        }
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors[0], std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    auto entry =
        std::make_unique<WorkEntry>(&inputTensors, nullptr, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

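// Scatter is the inverse: the root packs its input tensors into a flattened
// send buffer before MPI_Scatter, and every rank receives into its single
// output tensor.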
c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);

  if (rank_ != opts.rootRank) {
    if (!inputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should equal "
          "the world size");
    }
    checkSameSizeAndType(outputTensors[0], inputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        void* sendbuf = nullptr;
        at::Tensor flatInputTensor;

        if (rank_ == opts.rootRank) {
          std::vector<at::Tensor>& inputDataVec = entry->src;
          flatInputTensor = newLikeFlat(inputDataVec);
          sendbuf = flatInputTensor.data_ptr();

          // copy the input tensors to the flattened large send buffer
          for (const auto i : c10::irange(inputDataVec.size())) {
            flatInputTensor[static_cast<int64_t>(i)].copy_(inputDataVec.at(i));
          }
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Scatter(
            sendbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors[0], &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  } else {
    auto entry = std::make_unique<WorkEntry>(
        nullptr, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
}

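// alltoall_base uses MPI_Alltoall when no split sizes are given (equal splits
// along dim 0) and falls back to MPI_Alltoallv with explicit per-rank lengths
// and offsets otherwise.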
c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);

  if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
    // We can use alltoall
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel() &&
            outputTensor.type() == inputTensor.type(),
        "Tensors are not equal in size or data type");
    TORCH_CHECK(
        outputTensor.size(0) % size_ == 0,
        "Tensor's dim 0 does not divide equally across group size");

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this](std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoall(
              srcdata.data_ptr(),
              srcdata.numel() / size_,
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              dstdata.numel() / size_,
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    // Need alltoallv
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this, inputSplitSizes, outputSplitSizes](
            std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          std::vector<int> send_lengths(size_);
          std::vector<int> recv_lengths(size_);
          std::vector<int> send_offsets(size_);
          std::vector<int> recv_offsets(size_);
          c10d::computeLengthsAndOffsets(
              inputSplitSizes, srcdata, &send_lengths, &send_offsets);
          c10d::computeLengthsAndOffsets(
              outputSplitSizes, dstdata, &recv_lengths, &recv_offsets);
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoallv(
              srcdata.data_ptr(),
              send_lengths.data(),
              send_offsets.data(),
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              recv_lengths.data(),
              recv_offsets.data(),
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

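// The tensor-list alltoall flattens the per-rank inputs into one contiguous
// buffer, issues a single MPI_Alltoallv, and then splits the received flat
// buffer back into the caller's output tensors.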
c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  TORCH_CHECK(
      inputTensors.size() == static_cast<size_t>(size_),
      "Number of input tensors is not equal to group size");
  TORCH_CHECK(
      outputTensors.size() == static_cast<size_t>(size_),
      "Number of output tensors is not equal to group size");
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::vector<int> send_lengths(size_);
        std::vector<int> recv_lengths(size_);
        std::vector<int> send_offsets(size_);
        std::vector<int> recv_offsets(size_);
        auto srcdata = entry->src;
        auto dstdata = entry->dst;
        auto src_len = c10d::computeLengthsAndOffsets(
            srcdata, &send_lengths, &send_offsets);
        auto dst_len = c10d::computeLengthsAndOffsets(
            dstdata, &recv_lengths, &recv_offsets);
        std::vector<int64_t> send_lengthsL(
            send_lengths.begin(), send_lengths.end());
        std::vector<int64_t> recv_lengthsL(
            recv_lengths.begin(), recv_lengths.end());
        at::Tensor srcFlatData =
            at::empty({static_cast<int64_t>(src_len)}, srcdata[0].options());
        at::Tensor dstFlatData =
            at::empty({static_cast<int64_t>(dst_len)}, dstdata[0].options());
        auto srcFlatDataSplits =
            srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          srcFlatDataSplits[i].copy_(srcdata[i].view({-1}));
        }
        c10::DeviceGuard guard1(srcdata[0].device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Alltoallv(
            srcFlatData.data_ptr(),
            send_lengths.data(),
            send_offsets.data(),
            mpiDatatype.at(srcdata[0].scalar_type()),
            dstFlatData.data_ptr(),
            recv_lengths.data(),
            recv_offsets.data(),
            mpiDatatype.at(dstdata[0].scalar_type()),
            pgComm_));

        auto dstFlatDataSplits =
            dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_to_all",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

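// Illustrative sketch (not part of this file): point-to-point transfers return
// an AsyncWork that is backed by a raw MPI_Request, so wait() blocks in
// MPI_Wait rather than going through the worker thread. Assumes a
// ProcessGroupMPI instance `pg` with at least two ranks and a CPU tensor.
//
//   std::vector<at::Tensor> buf{at::zeros({8})};
//   if (pg->getRank() == 0) {
//     pg->send(buf, /*dstRank=*/1, /*tag=*/0)->wait();
//   } else if (pg->getRank() == 1) {
//     pg->recv(buf, /*srcRank=*/0, /*tag=*/0)->wait();
//   }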
c10::intrusive_ptr<Work> ProcessGroupMPI::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Isend(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        dstRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      std::vector<at::Tensor>(),
      "mpi:send",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        srcRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recv",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        MPI_ANY_SOURCE,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recvAnySource",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Barrier(pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc));
  return enqueue(std::move(entry), "mpi:barrier", std::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& /* unused */,
    at::Tensor& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "no support for _allgather_base in MPI process group");
}

} // namespace c10d

#endif // USE_C10D_MPI