#pragma once

#ifdef USE_C10D_MPI

#include <condition_variable>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/ivalue_inl.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <c10/util/CallOnce.h>

#include <mpi.h>

namespace c10d {

constexpr const char* MPI_BACKEND_NAME = "mpi";

// WorkEntry is the state associated with a single MPI run instance.
// It includes the source Tensor list and the destination Tensor list, as well
// as the actual run function that will operate on src, dst, or both.
struct WorkEntry {
  explicit WorkEntry(
      std::vector<at::Tensor>* srcPtr,
      std::vector<at::Tensor>* dstPtr,
      std::function<void(std::unique_ptr<WorkEntry>&)> run)
      : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()), run(std::move(run)) {
    if (srcPtr) {
      src = *srcPtr;
    }
  }

  // Not copyable
  WorkEntry(const WorkEntry&) = delete;
  // Not copy assignable
  WorkEntry& operator=(const WorkEntry&) = delete;

  // For input and output tensors (in-place), we will always use src
  std::vector<at::Tensor> src;

  // Copy of user-provided outputs.
  const std::vector<at::Tensor> dst;

  // Source rank of the received message; used by recv only.
  int* srcRank = nullptr;
  std::function<void(std::unique_ptr<WorkEntry>&)> run;
};

// ProcessGroupMPI implements MPI bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way we can
// guarantee to match up the same calls across processes.
//
// All MPI functions provided by this class are asynchronously scheduled on a
// worker thread. Therefore, ProcessGroupMPI requires the MPI implementation
// in use to have a minimum thread support value of MPI_THREAD_SERIALIZED.
// That is, the process may be multi-threaded, and multiple threads may make
// MPI calls, but only one at a time: MPI calls are not made concurrently from
// two distinct threads (all MPI calls are serialized). However, with
// MPI_THREAD_SERIALIZED, ProcessGroupMPI supports only a single process
// group. In other words, no more than one process group can be created
// globally.
//
// If you would like to use multiple ProcessGroupMPI instances, your MPI
// implementation must have a thread support value of MPI_THREAD_MULTIPLE,
// that is, multiple threads may call MPI with no restriction.
//
// Also note that ProcessGroupMPI only supports a single Tensor operation. In
// other words, the size of the input Tensor vector should always be 1.
//
// CUDA tensors can be supported if the MPI implementation is CUDA-aware, and
// ProcessGroupMPI will automatically detect this support.
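//
// A minimal usage sketch (an illustration, not part of this header; it
// assumes the program is built with USE_C10D_MPI and launched via mpirun):
//
//   auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
//   std::vector<at::Tensor> tensors = {at::ones({2, 2})};
//   auto work = pg->allreduce(tensors);  // defaults to a sum reduction
//   work->wait();
//   // tensors[0] now holds the element-wise sum across all ranks.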
class TORCH_API ProcessGroupMPI : public Backend {
 public:
  class WorkMPI : public Work {
   public:
    explicit WorkMPI(
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const std::optional<std::vector<at::Tensor>>& inputTensors =
            std::nullopt)
        : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
          outputTensors_(std::move(outputTensors)),
          future_(c10::make_intrusive<at::ivalue::Future>(
              c10::ListType::create(c10::TensorType::get()))) {}

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupMPI;

   private:
    void finishWorkMPI();
    void finishWorkMPIError(const std::exception_ptr& eptr);

    std::vector<at::Tensor> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
  };

  class AsyncWork : public Work {
   public:
    AsyncWork(
        MPI_Request request,
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const std::optional<std::vector<at::Tensor>>& inputTensors =
            std::nullopt);

    ~AsyncWork() override;

    bool isCompleted() override;

    bool isSuccess() const override;

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;

    void abort() override;

    std::vector<at::Tensor> result() override;

   protected:
    void populateException();

   private:
    const std::vector<at::Tensor> outputTensors_;
    MPI_Request request_;
    MPI_Status status_{};
  };

  // The constructor spawns the worker thread loop.
  explicit ProcessGroupMPI(int rank, int size, MPI_Comm pgComm);

  ~ProcessGroupMPI() override;

  // Aborts the MPI program; needs to be called when an exception is detected.
  void abort();

  const std::string getBackendName() const override {
    return std::string(MPI_BACKEND_NAME);
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputbuffer,
      at::Tensor& inputbuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensor,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  // Creates a new ProcessGroupMPI, initializing MPI first if it has not been
  // initialized yet.
  static c10::intrusive_ptr<ProcessGroupMPI> createProcessGroupMPI(
      std::vector<int> ranks = {});
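  // For example (a sketch; the rank values are illustrative), a group
  // spanning only ranks 0 and 2 of MPI_COMM_WORLD could be created with:
  //
  //   auto subPg = c10d::ProcessGroupMPI::createProcessGroupMPI({0, 2});
  //
  // Communicator creation is collective in MPI, so every process in
  // MPI_COMM_WORLD is expected to make this call.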
 protected:
  using WorkType =
      std::tuple<std::unique_ptr<WorkEntry>, c10::intrusive_ptr<WorkMPI>>;
  // Worker thread loop
  void runLoop();
  // Helper function that is called by the destructor
  void destroy();

  c10::intrusive_ptr<Work> enqueue(
      std::unique_ptr<WorkEntry> entry,
      const char* profilingTitle = nullptr,
      const std::optional<std::vector<at::Tensor>>& inputTensors =
          std::nullopt);

  bool stop_;

  std::mutex pgMutex_;
  std::thread workerThread_;

  std::deque<WorkType> queue_;
  std::condition_variable queueProduceCV_;
  std::condition_variable queueConsumeCV_;

  // Global states
  static void initMPIOnce();
  static void mpiExit();
  static c10::once_flag onceFlagInitMPI;

  static std::mutex pgGlobalMutex_;
  static int mpiThreadSupport_;

  MPI_Comm pgComm_;
};

} // namespace c10d

#endif // USE_C10D_MPI