#pragma once

#include <chrono>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/debug.h>

// Default timeout used when constructing a Backend (30 minutes).
constexpr auto kBackendDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

class TORCH_API Backend : public torch::CustomClassHolder {
 public:
  // Backend::Options is a base struct that defines the basic options used
  // when constructing a Backend. Each Backend subclass should extend this
  // struct and define its own options if it wants to provide more
  // configuration options (beyond the basic ones defined here) to the end
  // user.
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kBackendDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    const std::string backend;
  };
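
  // As an illustrative sketch only (the names below are hypothetical, not
  // part of this API): a custom backend could extend Options to expose its
  // own configuration knobs, e.g.:
  //
  //   struct MyBackendOptions : Backend::Options {
  //     explicit MyBackendOptions(
  //         std::chrono::milliseconds timeout = kBackendDefaultTimeout)
  //         : Backend::Options("my_backend", timeout) {}
  //     // Hypothetical backend-specific setting.
  //     bool useHighPriorityStream = false;
  //   };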

  explicit Backend(int rank, int size);
  ~Backend() override = 0;

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

  // Returns a unique opaque ID of this backend that can be used to correlate
  // with its collectives.
  int64_t getID() const {
    return reinterpret_cast<std::intptr_t>(this);
  }

  virtual bool supportsSplitting() const {
    return false;
  }

  virtual void startCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not implement startCoalescing"));
  }

  virtual c10::intrusive_ptr<Work> endCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not implement endCoalescing"));
  }

  // Subclasses must override this method to return the backend name.
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  }

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support broadcast"));
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allreduce"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce sparse"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support reduce"));
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allgather"));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of WORLD_SIZE chunks, each the
  // size of inputBuffer.
  // For implementers of the ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support _allgather_base"));
  }
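
  // Illustrative sketch only (`backend` is a hypothetical pointer to a
  // concrete Backend, not prescriptive of real call sites): the caller is
  // expected to size the output as getSize() copies of the input, e.g.:
  //
  //   const int64_t count = 4;
  //   at::Tensor input = at::ones({count}, at::kFloat);
  //   at::Tensor output = at::empty({backend->getSize() * count}, at::kFloat);
  //   auto work = backend->_allgather_base(output, input);
  //   work->wait();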

  // This function is deprecated and will be moved out of Backend to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your Backend, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_coalesced"));
  }

  // This function is a coalesced version of `allgather_into_tensor` (currently
  // still named `_allgather_base`). Each tensor in the vector corresponds to
  // an input/output of one `allgather_into_tensor` operation.
  virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_into_tensor_coalesced"));
  }
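
  // Illustrative sketch only (`backend`, `t0`, and `t1` are hypothetical):
  // each inputs[i] pairs with outputs[i], and outputs[i] is sized as getSize()
  // copies of inputs[i], e.g.:
  //
  //   std::vector<at::Tensor> inputs = {t0, t1};
  //   std::vector<at::Tensor> outputs = {
  //       at::empty({backend->getSize() * t0.numel()}, t0.options()),
  //       at::empty({backend->getSize() * t1.numel()}, t1.options())};
  //   auto work = backend->allgather_into_tensor_coalesced(outputs, inputs);
  //   work->wait();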

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support gather"));
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support scatter"));
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support _reduce_scatter_base"));
  }

  // This function is a coalesced version of `reduce_scatter_tensor` (currently
  // still named `_reduce_scatter_base`). Each tensor in the vector corresponds
  // to an input/output of one `reduce_scatter_tensor` operation.
  virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support reduce_scatter_tensor_coalesced"));
  }
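
  // Illustrative sketch only (`backend` is hypothetical): the inverse pairing
  // of the allgather case above; each inputs[i] holds getSize() contiguous
  // chunks and outputs[i] receives the reduced chunk for this rank, e.g.:
  //
  //   at::Tensor in0 = at::ones({backend->getSize() * 4}, at::kFloat);
  //   at::Tensor out0 = at::empty({4}, at::kFloat);
  //   std::vector<at::Tensor> inputs = {in0};
  //   std::vector<at::Tensor> outputs = {out0};
  //   auto work = backend->reduce_scatter_tensor_coalesced(outputs, inputs);
  //   work->wait();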

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support alltoall_base"));
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not support monitoredBarrier; only the GLOO backend supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to the other ranks using the store. Currently
  // only implemented for the GLOO and NCCL backends.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync across ranks. If the returned number is not consistent across the
  // group, it may indicate some kind of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }
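
  // Illustrative sketch only (`backend` and the surrounding tooling are
  // hypothetical): a debugging utility could read the sequence number on
  // every rank and compare the values; a mismatch suggests that some rank has
  // fallen behind in its collectives, e.g.:
  //
  //   uint64_t seq = backend->getSequenceNumberForGroup();
  //   // Exchange `seq` across ranks (e.g. through a Store) and flag any rank
  //   // whose value differs from rank 0's.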

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support send"));
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support recv"));
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support recvAnysource"));
  }

  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support barrier"));
  }

  virtual void registerOnCompletionHook(
      std::function<void(std::shared_ptr<WorkInfo>)>&& /* hook */) {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports onCompletion hook, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void waitForPendingWorks() {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports waitForPendingWorks, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void enableCollectivesTiming() {
    TORCH_CHECK(
        false,
        "Backend ",
        getBackendName(),
        " is missing implementation of enableCollectivesTiming.");
  }

  bool hasHooks() const {
    return onCompletionHook_ != nullptr;
  }

  // Do not call this directly, use ProcessGroup::setGroupName instead.
  void setGroupUid(const std::string& pg_uid) {
    pg_uid_ = pg_uid;
  }

  const std::string& getGroupUid() const {
    return pg_uid_;
  }

  void setGroupDesc(const std::string& desc) {
    pg_desc_ = desc;
  }

  const std::string& getGroupDesc() const {
    return pg_desc_;
  }

  // See similar functions in ProcessGroup.hpp for context.
  std::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
  }

  // Perform an eager connect to the specified device if the backend supports
  // it.
  virtual void eagerConnectSingleDevice(at::Device device) {
    // no-op in the default case; this is an optimization some
    // backends may perform
  }

  void setBoundDeviceId(std::optional<at::Device> device) {
    if (device) {
      TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
    }
    bound_device_id_ = device;
  }
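
  // Illustrative sketch only (the device choice is hypothetical): the bound
  // device must carry an explicit index, e.g.
  //
  //   backend->setBoundDeviceId(at::Device(at::kCUDA, 0));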

 protected:
  // Implementations of this interface need to call this to set up
  // appropriate logging, etc.
  void init();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int rank_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int size_;
  // Debug level setting. It is parsed once when the ProcessGroup is
  // constructed and remains the same for the lifetime of this process group.
  DebugLevel dist_debug_level_;
  std::string pg_uid_;
  std::string pg_desc_;

  std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;

  std::optional<at::Device> bound_device_id_;
};

} // namespace c10d