#pragma once

#ifdef USE_C10D_MPI

#include <condition_variable>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/ivalue_inl.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <c10/util/CallOnce.h>

#include <mpi.h>

namespace c10d {

constexpr const char* MPI_BACKEND_NAME = "mpi";

// WorkEntry is the state associated with a single MPI run instance.
// It includes the source Tensor list and destination Tensor list, as well as
// the actual run function that will operate on src, dst, or both.
// (See the illustrative sketch after this struct.)
struct WorkEntry {
  explicit WorkEntry(
      std::vector<at::Tensor>* srcPtr,
      std::vector<at::Tensor>* dstPtr,
      std::function<void(std::unique_ptr<WorkEntry>&)> run)
      : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()), run(std::move(run)) {
    if (srcPtr) {
      src = *srcPtr;
    }
  }

  // Not copyable
  WorkEntry(const WorkEntry&) = delete;
  // Not copy assignable
  WorkEntry& operator=(const WorkEntry&) = delete;

  // For input and output tensors (in-place), we will always use src
  std::vector<at::Tensor> src;

  // Copy of user provided outputs.
  const std::vector<at::Tensor> dst;

  // src rank returned, for recv only
  int* srcRank = nullptr;
  std::function<void(std::unique_ptr<WorkEntry>&)> run;
};
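
// An illustrative sketch of how a collective implementation might build a
// WorkEntry. Only the constructor signature above is real; the tensor values
// and the lambda body below are hypothetical and not taken from
// ProcessGroupMPI.cpp:
//
//   std::vector<at::Tensor> tensors = {at::zeros({4})};
//   auto entry = std::make_unique<WorkEntry>(
//       &tensors, &tensors, [](std::unique_ptr<WorkEntry>& e) {
//         // perform the MPI call on e->src / e->dst here
//       });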

// ProcessGroupMPI implements MPI bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way we
// can guarantee to match up the same calls across processes.
//
// All MPI functions provided by this class are asynchronously scheduled on a
// worker thread. Therefore, ProcessGroupMPI requires the MPI implementation
// in use to have a minimum thread support value of MPI_THREAD_SERIALIZED.
// That is, the process may be multi-threaded, and multiple threads may make
// MPI calls, but only one at a time: MPI calls are not made concurrently from
// two distinct threads (all MPI calls are serialized). However, with
// MPI_THREAD_SERIALIZED, ProcessGroupMPI only supports a single process
// group. In other words, no more than one process group can be created
// globally.
//
// If you would like to use multiple ProcessGroupMPI instances, your MPI
// implementation must have a thread support value of MPI_THREAD_MULTIPLE;
// that is, multiple threads may call MPI with no restrictions.
//
// Also note that ProcessGroupMPI only supports single-tensor operations. In
// other words, the size of the input tensor vector should always be 1.
//
// CUDA tensors are supported if the MPI implementation is CUDA-aware, and
// ProcessGroupMPI will automatically detect this support.
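//
// A minimal usage sketch (illustrative only, not a complete program; it
// assumes the process was launched under mpirun and that the default
// AllreduceOptions are acceptable):
//
//   auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
//   std::vector<at::Tensor> tensors = {at::ones({4})};  // single tensor only
//   auto work = pg->allreduce(tensors);
//   work->wait();  // tensors[0] now holds the reduction result across ranks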
class TORCH_API ProcessGroupMPI : public Backend {
 public:
  class WorkMPI : public Work {
   public:
    explicit WorkMPI(
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const std::optional<std::vector<at::Tensor>>& inputTensors =
            std::nullopt)
        : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
          outputTensors_(std::move(outputTensors)),
          future_(c10::make_intrusive<at::ivalue::Future>(
              c10::ListType::create(c10::TensorType::get()))) {}

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupMPI;

   private:
    void finishWorkMPI();
    void finishWorkMPIError(const std::exception_ptr& eptr);

    std::vector<at::Tensor> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
  };

  class AsyncWork : public Work {
   public:
    AsyncWork(
        MPI_Request request,
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const std::optional<std::vector<at::Tensor>>& inputTensors =
            std::nullopt);

    ~AsyncWork() override;

    bool isCompleted() override;

    bool isSuccess() const override;

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;

    void abort() override;

    std::vector<at::Tensor> result() override;

   protected:
    void populateException();

   private:
    const std::vector<at::Tensor> outputTensors_;
    MPI_Request request_;
    MPI_Status status_{};
  };

  // The constructor spawns the worker thread loop
  explicit ProcessGroupMPI(int rank, int size, MPI_Comm pgComm);

  ~ProcessGroupMPI() override;

  // Abort the MPI program; needs to be called when an exception is detected
  void abort();

  const std::string getBackendName() const override {
    return std::string(MPI_BACKEND_NAME);
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputbuffer,
      at::Tensor& inputbuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensor,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  // Create a new ProcessGroupMPI; this will initialize MPI if it has not
  // already been initialized.
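  //
  // For illustration only (the two-rank subset below is hypothetical; the
  // default empty `ranks` list leaves the group spanning all MPI ranks):
  //
  //   auto world = ProcessGroupMPI::createProcessGroupMPI();
  //   auto pair = ProcessGroupMPI::createProcessGroupMPI({0, 1});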
  static c10::intrusive_ptr<ProcessGroupMPI> createProcessGroupMPI(
      std::vector<int> ranks = {});

 protected:
  using WorkType =
      std::tuple<std::unique_ptr<WorkEntry>, c10::intrusive_ptr<WorkMPI>>;
  // Worker thread loop
  void runLoop();
  // Helper function that is called by the destructor
  void destroy();

  c10::intrusive_ptr<Work> enqueue(
      std::unique_ptr<WorkEntry> entry,
      const char* profilingTitle = nullptr,
      const std::optional<std::vector<at::Tensor>>& inputTensors =
          std::nullopt);

  bool stop_;

  std::mutex pgMutex_;
  std::thread workerThread_;

  std::deque<WorkType> queue_;
  std::condition_variable queueProduceCV_;
  std::condition_variable queueConsumeCV_;

  // Global states
  static void initMPIOnce();
  static void mpiExit();
  static c10::once_flag onceFlagInitMPI;

  static std::mutex pgGlobalMutex_;
  static int mpiThreadSupport_;

  MPI_Comm pgComm_;
};

} // namespace c10d

#endif // USE_C10D_MPI