#pragma once

#ifdef USE_C10D_GLOO

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <vector>

#include <gloo/algorithm.h>
#include <gloo/common/error.h>
#include <gloo/context.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/device.h>

#include <c10/util/hash.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

constexpr const char* GLOO_BACKEND_NAME = "gloo";

// ProcessGroupGloo implements Gloo bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way that we
// can guarantee to match up the same calls across processes. For
// multi-threaded usage of process groups, consider using multiple
// process group instances.
//
// The Gloo algorithms that this class calls into are cached by their
// signature (see description of AlgorithmKey above). This cache works
// as follows: every function call instantiates an AlgorithmKey and
// looks in the cache for existing entries. If there is one, it is
// removed from the cache and returned to the caller. If there are
// none, a new entry is created and returned. If an entry was created
// before, but is still in use, the call will block and wait until the
// entry is returned to the cache.
//
// In the future, we hope to extend this to allow multiple entries per
// key, to enable parallelism for a single key. The number of entries
// per key must always be identical for all processes. This maximum
// number can be automatically tuned, but only if we let a single
// process take charge, and have it broadcast the limits.
//
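// A minimal usage sketch (illustrative only; assumes a c10d::Store such as a
// TCPStore named `store` has been created elsewhere, and that `rank` and
// `size` describe this process within the group):
//
//   auto options = ProcessGroupGloo::Options::create();
//   options->devices.push_back(ProcessGroupGloo::createDefaultDevice());
//   auto pg =
//       c10::make_intrusive<ProcessGroupGloo>(store, rank, size, options);
//
//   std::vector<at::Tensor> tensors = {at::ones({4})};
//   auto work = pg->allreduce(tensors); // same call order on every rank
//   work->wait();                       // tensors now hold the reduced result
//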
class TORCH_API ProcessGroupGloo : public Backend {
 public:
  // AsyncWork is the Gloo specific superclass for asynchronous work items.
  // We can split asynchronous work into 3 phases:
  // 1) Sanity checks and prepare input (e.g. memcpy)
  // 2) Run operation on background thread
  // 3) Synchronize with completion on foreground thread
  //
  // There is state to be shared between these 3 phases and all of this state
  // is captured in the AsyncWork class and its derivatives.
  //
  // Note: while we are porting operations to use new style collectives, there
  // is a split between operations using the existing caching approach and
  // operations using the new AsyncWork base class. Over time we will port
  // all operations and perform needed cleanup.
  //
  // FIXME: This probably should be called WorkGloo since the work is executed
  // in sync mode by a background thread.
  class TORCH_API AsyncWork : public Work {
   public:
    explicit AsyncWork(
        std::vector<std::vector<at::Tensor>> outputTensors,
        OpType opType,
        uint64_t seq,
        const char* profilingTitle = nullptr,
        const std::optional<std::vector<at::Tensor>>& inputTensors =
            std::nullopt);

    ~AsyncWork() override = default;

    static void execute(const c10::intrusive_ptr<AsyncWork>& work);

    virtual void run() = 0;

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
    uint64_t getSequencenumber() const override;

   protected:
    friend class ProcessGroupGloo;

   private:
    void finishWorkGloo();
    void finishWorkGlooError(const std::exception_ptr& eptr);
    inline void recordAsyncWorkProfilingInfo(
        const char* profilingTitle,
        const std::optional<std::vector<at::Tensor>>& inputTensors);

    const std::vector<std::vector<at::Tensor>> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
    std::function<void()> recordFunctionBeforeCallback_;
    const uint64_t seq_;
  };

  // Wrap c10d store as Gloo store
  class TORCH_API GlooStore : public ::gloo::rendezvous::Store {
   public:
    GlooStore(const c10::intrusive_ptr<::c10d::Store>& store)
        : store_(store) {}

    void setUint(const std::string& key, const std::vector<uint8_t>& value) {
      store_->set(key, value);
    }

    void set(const std::string& key, const std::vector<char>& value) override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      store_->set(key, tmp);
    }

    std::vector<uint8_t> getUint(const std::string& key) {
      auto value = store_->get(key);
      return value;
    }

    std::vector<char> get(const std::string& key) override {
      auto value = store_->get(key);
      return std::vector<char>(value.begin(), value.end());
    }

    void wait(const std::vector<std::string>& keys) override {
      store_->wait(keys, ::c10d::Store::kDefaultTimeout);
    }

    void wait(
        const std::vector<std::string>& keys,
        const std::chrono::milliseconds& timeout) override {
      store_->wait(keys, timeout);
    }

#ifdef GLOO_STORE_HAS_STORE_V2
    bool has_v2_support() override {
      return store_->hasExtendedApi();
    }

    std::vector<std::vector<char>> multi_get(
        const std::vector<std::string>& keys) override {
      std::vector<std::vector<char>> res;
      for (auto& value : store_->multiGet(keys)) {
        res.emplace_back(value.begin(), value.end());
      }
      return res;
    }

    void multi_set(
        const std::vector<std::string>& keys,
        const std::vector<std::vector<char>>& values) override {
      std::vector<std::vector<uint8_t>> u_values;
      u_values.reserve(values.size());
      for (auto& value : values) {
        u_values.emplace_back(value.begin(), value.end());
      }
      store_->multiSet(keys, u_values);
    }

    void append(const std::string& key, const std::vector<char>& value)
        override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      return store_->append(key, tmp);
    }

    int64_t add(const std::string& key, int64_t value) override {
      return store_->add(key, value);
    }
#endif

   protected:
    c10::intrusive_ptr<::c10d::Store> store_;
  };

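  // Illustrative sketch of the adapter above (not part of the public API):
  // GlooStore only forwards to the wrapped c10d::Store, so Gloo rendezvous
  // code can treat it like any other ::gloo::rendezvous::Store. Here
  // `c10dStore` is an assumed c10::intrusive_ptr<::c10d::Store> created
  // elsewhere:
  //
  //   GlooStore glooStore(c10dStore);
  //   glooStore.set("key", std::vector<char>{'4', '2'});
  //   auto value = glooStore.get("key"); // round-trips through c10dStore
  //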
  // For send and recv operations there is no need to pass them to the
  // thread pool as they are entirely completed by the device thread.
  // This work object is used to synchronize completion of the send or
  // recv operation. It keeps a reference to the tensor it is
  // operating on to prevent it from being deallocated while the
  // operation is still in flight.
  class TORCH_API SendWork : public Work {
   public:
    explicit SendWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        uint64_t seq);

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

    uint64_t getSequencenumber() const override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    const uint64_t seq_;
  };

  class TORCH_API RecvWork : public Work {
   public:
    explicit RecvWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        OpType opType,
        uint64_t seq,
        const char* profilingTitle = nullptr);

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

    uint64_t getSequencenumber() const override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    int srcRank_;
    const uint64_t seq_;
  };

  struct TORCH_API Options : public Backend::Options {
    explicit Options(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout);

    // return intrusive_ptr of the object
    static c10::intrusive_ptr<Options> create(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout) {
      return c10::make_intrusive<Options>(timeout);
    }

    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    int threads;
  };

  const std::string getBackendName() const override {
    return std::string(GLOO_BACKEND_NAME);
  }

  // Helper functions to create a new device object.
  // They are static functions on this class to keep them logically
  // separate from the rest of the code base (e.g. torch/csrc/distributed).

  // Create new device instance for specific interface.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
      const std::string& interface);

  // Create new device instance for specific hostname or address.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
      const std::string& hostname);

  // Create new device instance.
  // It tries to resolve this machine's hostname and bind to that address.
  // If that fails (i.e. the hostname doesn't resolve to an address), it
  // falls back to binding to the loopback address.
  static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();

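  // Options sketch (illustrative only; "eth0" is a placeholder interface
  // name). The resulting options object is what gets passed to the
  // constructor below:
  //
  //   auto options = ProcessGroupGloo::Options::create();
  //   options->devices.push_back(
  //       ProcessGroupGloo::createDeviceForInterface("eth0"));
  //   options->threads = 2; // worker thread count; see getNumThreads() below
  //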
  // Create ProcessGroupGloo instance.
  static c10::intrusive_ptr<ProcessGroupGloo> createProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      std::chrono::milliseconds timeout);

  explicit ProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options = Options::create());

  ~ProcessGroupGloo() override;

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& output_tensor,
      at::Tensor& input_tensor,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& output_lists,
      std::vector<at::Tensor>& input_list,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputCounts,
      std::vector<int64_t>& inputCounts,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

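  // Point-to-point sketch (illustrative only; `pg`, `peerRank`, and
  // `iAmSender` are assumed to exist, and both sides must agree on the tag):
  //
  //   std::vector<at::Tensor> buf = {at::ones({4})};
  //   auto work = iAmSender ? pg->send(buf, peerRank, /*tag=*/0)
  //                         : pg->recv(buf, peerRank, /*tag=*/0);
  //   work->wait();
  //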
  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  void enableCollectivesTiming() override;

  const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
    return store_;
  }

  // Similar to barrier(), but blocks rank 0 until all other ranks have
  // acknowledged that they are alive (through send/recv from rank 0). Rank 0
  // is able to report all failed ranks if waitAllRanks = true, otherwise
  // it reports the first rank it detected as failed.
  void monitoredBarrier(
      const BarrierOptions& opts = BarrierOptions(),
      bool waitAllRanks = false) override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;

  int getNumThreads() {
    return options_->threads;
  }

 protected:
  std::unique_ptr<::gloo::rendezvous::Store> store_;
  const c10::intrusive_ptr<Options> options_;

  // Every Gloo context represents a set of connections to its peers.
  // In order to use more than one device (or allow for parallelism on
  // a single device), you need multiple contexts.
  std::vector<std::shared_ptr<::gloo::Context>> contexts_;
  std::vector<std::thread> threads_;
  bool stop_;

  // Incremented for every collective we kick off.
  // The value is used as a tag for collective operations. Collectives are
  // kicked off in identical order across processes. Therefore the tag can be
  // used to match up operations during concurrent execution.
  uint32_t collectiveCounter_;

  // Returns next collective tag to use (uses collectiveCounter_).
  uint32_t nextTag();

  // Returns the context to use for the specified tag.
  // With `nextTag` returning an increasing number, this should lead
  // to contexts being used in a round-robin fashion.
  std::shared_ptr<::gloo::Context> getContext(uint32_t tag);

  // Entrypoint for worker threads.
  void runLoop(int workerIndex);

  // Queue work to run on a worker thread.
  void enqueue(c10::intrusive_ptr<AsyncWork> work);

  // Keep both a queue of pending work and a vector of in-progress work.
  // Both of these can only be mutated while holding the queue lock.
  // We keep both around, instead of just the queue, so we can grab a weak_ptr
  // to all in-progress and pending work when executing a barrier.
  // When executing a barrier, we need to ensure that all prior work
  // has completed before the barrier itself can complete.
  std::deque<c10::intrusive_ptr<AsyncWork>> workQueue_;
  std::vector<c10::intrusive_ptr<AsyncWork>> workInProgress_;
  std::mutex workMutex_;
  std::condition_variable workProduceCV_;
  std::condition_variable workConsumeCV_;
  uint64_t seq_{0};
};

} // namespace c10d

#endif // USE_C10D_GLOO