#pragma once

#ifdef USE_C10D_GLOO

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <vector>

#include <gloo/algorithm.h>
#include <gloo/common/error.h>
#include <gloo/context.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/device.h>

#include <c10/util/hash.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

constexpr const char* GLOO_BACKEND_NAME = "gloo";

// ProcessGroupGloo implements Gloo bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way we can
// guarantee to match up the same calls across processes. For
// multi-threaded usage of process groups, consider using multiple
// process group instances.
//
// The Gloo algorithms that this class calls into are cached by their
// signature (see the description of AlgorithmKey). This cache works
// as follows: every function call instantiates an AlgorithmKey and
// looks in the cache for existing entries. If there is one, it is
// removed from the cache and returned to the caller. If there are
// none, a new entry is created and returned. If an entry was created
// before, but is still in use, the call will block and wait until the
// entry is returned to the cache.
//
// In the future, we hope to extend this to allow multiple entries per
// key, to enable parallelism for a single key. The number of entries
// per key must always be identical for all processes. This maximum
// number can be automatically tuned, but only if we let a single
// process take charge, and have it broadcast the limits.
//
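// A minimal usage sketch (illustrative only): it assumes `store` is an
// already constructed c10::intrusive_ptr<c10d::Store> shared by all ranks
// (e.g. a TCPStore), and that `rank`/`size` describe this process within
// the group.
//
//   auto options = c10d::ProcessGroupGloo::Options::create();
//   options->devices.push_back(
//       c10d::ProcessGroupGloo::createDefaultDevice());
//   auto pg = c10::make_intrusive<c10d::ProcessGroupGloo>(
//       store, rank, size, options);
//
//   std::vector<at::Tensor> tensors = {at::ones({4})};
//   auto work = pg->allreduce(tensors); // every rank must issue the same
//   work->wait();                       // calls in the same order
//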
class TORCH_API ProcessGroupGloo : public Backend {
 public:
  // AsyncWork is the Gloo specific superclass for asynchronous work items.
  // We can split asynchronous work into 3 phases:
  // 1) Sanity checks and prepare input (e.g. memcpy)
  // 2) Run operation on background thread
  // 3) Synchronize with completion on foreground thread
  //
  // There is state to be shared between these 3 phases and all of this state
  // is captured in the AsyncWork class and its derivatives.
  //
  // Note: while we are porting operations to use new style collectives, there
  // is a split between operations using the existing caching approach and
  // operations using the new AsyncWork base class. Over time we will port
  // all operations and perform needed cleanup.
  //
  // FIXME: This probably should be called WorkGloo since the work is executed
  // in sync mode by a background thread.
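  //
  // A schematic sketch of a derived work item (the type and member names
  // below are illustrative, not the concrete classes defined in
  // ProcessGroupGloo.cpp):
  //
  //   class AsyncExampleWork : public ProcessGroupGloo::AsyncWork {
  //    public:
  //     AsyncExampleWork(
  //         std::shared_ptr<::gloo::Context> context,
  //         std::vector<at::Tensor> tensors,
  //         uint64_t seq)
  //         : AsyncWork({tensors}, OpType::ALLREDUCE, seq, "gloo:example"),
  //           context_(std::move(context)),
  //           tensors_(std::move(tensors)) {} // phase 1: check/prepare input
  //
  //     void run() override {
  //       // phase 2: executed on a ProcessGroupGloo worker thread
  //     }
  //
  //    private:
  //     std::shared_ptr<::gloo::Context> context_;
  //     std::vector<at::Tensor> tensors_;
  //   };
  //
  // The process group enqueues the work on its worker threads and the caller
  // performs phase 3 by waiting on the returned Work handle (wait() or the
  // future obtained from getFuture()).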
  class TORCH_API AsyncWork : public Work {
   public:
    explicit AsyncWork(
        std::vector<std::vector<at::Tensor>> outputTensors,
        OpType opType,
        uint64_t seq,
        const char* profilingTitle = nullptr,
        const std::optional<std::vector<at::Tensor>>& inputTensors =
            std::nullopt);

    ~AsyncWork() override = default;

    static void execute(const c10::intrusive_ptr<AsyncWork>& work);

    virtual void run() = 0;

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
    uint64_t getSequencenumber() const override;

   protected:
    friend class ProcessGroupGloo;

   private:
    void finishWorkGloo();
    void finishWorkGlooError(const std::exception_ptr& eptr);
    inline void recordAsyncWorkProfilingInfo(
        const char* profilingTitle,
        const std::optional<std::vector<at::Tensor>>& inputTensors);

    const std::vector<std::vector<at::Tensor>> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
    std::function<void()> recordFunctionBeforeCallback_;
    const uint64_t seq_;
  };

  // Wrap c10d store as Gloo store
  class TORCH_API GlooStore : public ::gloo::rendezvous::Store {
   public:
    GlooStore(const c10::intrusive_ptr<::c10d::Store>& store) : store_(store) {}

    void setUint(const std::string& key, const std::vector<uint8_t>& value) {
      store_->set(key, value);
    }

    void set(const std::string& key, const std::vector<char>& value) override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      store_->set(key, tmp);
    }

    std::vector<uint8_t> getUint(const std::string& key) {
      auto value = store_->get(key);
      return value;
    }

    std::vector<char> get(const std::string& key) override {
      auto value = store_->get(key);
      return std::vector<char>(value.begin(), value.end());
    }

    void wait(const std::vector<std::string>& keys) override {
      store_->wait(keys, ::c10d::Store::kDefaultTimeout);
    }

    void wait(
        const std::vector<std::string>& keys,
        const std::chrono::milliseconds& timeout) override {
      store_->wait(keys, timeout);
    }

#ifdef GLOO_STORE_HAS_STORE_V2
    bool has_v2_support() override {
      return store_->hasExtendedApi();
    }

    std::vector<std::vector<char>> multi_get(
        const std::vector<std::string>& keys) override {
      std::vector<std::vector<char>> res;
      for (auto& value : store_->multiGet(keys)) {
        res.emplace_back(value.begin(), value.end());
      }
      return res;
    }

    void multi_set(
        const std::vector<std::string>& keys,
        const std::vector<std::vector<char>>& values) override {
      std::vector<std::vector<uint8_t>> u_values;
      u_values.reserve(values.size());
      for (auto& value : values) {
        u_values.emplace_back(value.begin(), value.end());
      }
      store_->multiSet(keys, u_values);
    }

    void append(const std::string& key, const std::vector<char>& value)
        override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      return store_->append(key, tmp);
    }

    int64_t add(const std::string& key, int64_t value) override {
      return store_->add(key, value);
    }
#endif

   protected:
    c10::intrusive_ptr<::c10d::Store> store_;
  };
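
  // A rough sketch of how this adapter is typically consumed during
  // rendezvous (illustrative; the exact gloo::rendezvous::Context wiring is
  // an assumption here and lives in ProcessGroupGloo.cpp, not in this file):
  //
  //   auto glooStore = std::make_shared<GlooStore>(store);
  //   auto context = std::make_shared<::gloo::rendezvous::Context>(rank, size);
  //   context->connectFullMesh(*glooStore, device); // drives set()/get()/wait()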

  // For send and recv operations there is no need to pass them to the
  // thread pool as they are entirely completed by the device thread.
  // This work object is used to synchronize completion of the send or
  // recv operation. It keeps a reference to the tensor it is
  // operating on to prevent it from being deallocated while the
  // operation is still in flight.
  class TORCH_API SendWork : public Work {
   public:
    explicit SendWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        uint64_t seq);

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

    uint64_t getSequencenumber() const override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    const uint64_t seq_;
  };

  class TORCH_API RecvWork : public Work {
   public:
    explicit RecvWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        OpType opType,
        uint64_t seq,
        const char* profilingTitle = nullptr);

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

    uint64_t getSequencenumber() const override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    int srcRank_;
    const uint64_t seq_;
  };
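
  // Point-to-point usage sketch for the send/recv path described above
  // (illustrative; assumes `pg` is a constructed ProcessGroupGloo and that
  // the peer rank issues the matching recv/send with the same tag):
  //
  //   std::vector<at::Tensor> bufs = {at::empty({4})};
  //   auto work = pg->send(bufs, /*dstRank=*/1, /*tag=*/0);
  //   work->wait(); // returns once the device thread has completed the send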

  struct TORCH_API Options : public Backend::Options {
    explicit Options(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout);

    // Returns an intrusive_ptr to a new Options instance.
    static c10::intrusive_ptr<Options> create(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout) {
      return c10::make_intrusive<Options>(timeout);
    }

    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    int threads;
  };

  const std::string getBackendName() const override {
    return std::string(GLOO_BACKEND_NAME);
  }

  // Helper functions to create a new device object.
  // They are static functions on this class to keep them logically
  // separate from the rest of the code base (e.g. torch/csrc/distributed).

  // Create new device instance for specific interface.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
      const std::string& interface);

  // Create new device instance for specific hostname or address.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
      const std::string& hostname);

  // Create new device instance.
  // It tries to resolve this machine's hostname and bind to that address.
  // If that fails (i.e. the hostname doesn't resolve to an address), it
  // falls back to binding to the loopback address.
  static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
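
  // Device selection sketch (illustrative; "eth0" and the hostname below are
  // placeholders, not values this class requires):
  //
  //   auto options = Options::create();
  //   options->devices.push_back(createDeviceForInterface("eth0"));
  //   // ...or bind by hostname:  createDeviceForHostname("node0.example.com")
  //   // ...or take the fallback: createDefaultDevice()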

  // Create ProcessGroupGloo instance.
  static c10::intrusive_ptr<ProcessGroupGloo> createProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      std::chrono::milliseconds timeout);

  explicit ProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options = Options::create());

  ~ProcessGroupGloo() override;

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& output_tensor,
      at::Tensor& input_tensor,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& output_lists,
      std::vector<at::Tensor>& input_list,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputCounts,
      std::vector<int64_t>& inputCounts,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  void enableCollectivesTiming() override;

  const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
    return store_;
  }

  // Similar to barrier(), but blocks rank 0 until all other ranks have
  // acknowledged that they are alive (through send/recv from rank 0). Rank 0
  // is able to report all failed ranks if waitAllRanks = true, otherwise
  // reports the first rank it detected as failed.
  void monitoredBarrier(
      const BarrierOptions& opts = BarrierOptions(),
      bool waitAllRanks = false) override;
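
  // Usage sketch: with waitAllRanks = true, rank 0 collects acknowledgements
  // from every other rank and can report all unresponsive ranks at once:
  //
  //   pg->monitoredBarrier(BarrierOptions(), /*waitAllRanks=*/true);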

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;

  int getNumThreads() {
    return options_->threads;
  }

 protected:
  std::unique_ptr<::gloo::rendezvous::Store> store_;
  const c10::intrusive_ptr<Options> options_;

  // Every Gloo context represents a set of connections to its peers.
  // In order to use more than one device (or allow for parallelism on
  // a single device), you need multiple contexts.
  std::vector<std::shared_ptr<::gloo::Context>> contexts_;
  std::vector<std::thread> threads_;
  bool stop_;

  // Incremented for every collective we kick off.
  // The value is used as tag for collective operations. Collectives are kicked
  // off in identical order across processes. Therefore the tag can be used
  // to match up operations during concurrent execution.
  uint32_t collectiveCounter_;

  // Returns next collective tag to use (uses collectiveCounter_).
  uint32_t nextTag();

  // Returns the context to use for the specified tag.
  // With `nextTag` returning an increasing number, this should lead
  // to contexts being used in a round-robin fashion.
  std::shared_ptr<::gloo::Context> getContext(uint32_t tag);
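  // (In practice this is expected to be a simple round-robin mapping such as
  // contexts_[tag % contexts_.size()]; that is an implementation detail of
  // ProcessGroupGloo.cpp and only an illustration here.)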

  // Entrypoint for worker threads.
  void runLoop(int workerIndex);
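  //
  // Conceptually the loop looks roughly like the following (a simplified
  // sketch, not the exact implementation):
  //
  //   while (!stop_) {
  //     wait on workProduceCV_ until workQueue_ has an entry;
  //     move the front entry into workInProgress_[workerIndex];
  //     AsyncWork::execute(work);              // runs and finishes the work
  //     clear the slot and notify workConsumeCV_;
  //   }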

  // Queue work to run on worker thread.
  void enqueue(c10::intrusive_ptr<AsyncWork> work);

  // Keep both a queue of pending work, and a vector with in progress work.
  // Both of these can only be mutated when holding the queue lock.
  // We keep both around instead of just the queue, so we can grab a weak_ptr
  // to all in progress and pending work when executing a barrier.
  // When executing a barrier, we need to ensure that all prior work
  // has completed before the barrier itself completes.
  std::deque<c10::intrusive_ptr<AsyncWork>> workQueue_;
  std::vector<c10::intrusive_ptr<AsyncWork>> workInProgress_;
  std::mutex workMutex_;
  std::condition_variable workProduceCV_;
  std::condition_variable workConsumeCV_;
  uint64_t seq_{0};
};

} // namespace c10d

#endif // USE_C10D_GLOO