#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/library.h>

namespace c10d {
namespace {

TORCH_LIBRARY(c10d, m) {
  // The following ProcessGroup, Work, and ReduceOp definitions are more like
  // declarations. They don't expose the details of these classes to
  // TorchScript.
  m.class_<ProcessGroup>("ProcessGroup").def(torch::init<int64_t, int64_t>());
  m.class_<Work>("Work")
      .def(torch::init<>())
      .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
  m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
  m.def(
      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
  m.def(
      "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int dst, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int src, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int tag) -> __torch__.torch.classes.c10d.Work");
}
} // namespace

namespace ops {

// Below are ProcessGroup's corresponding ops for each backend. Ops are simply
// routed through the dispatcher to the appropriate backend. Currently a no-op
// as the process group does not have a list of backends.
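//
// For illustration only (a sketch, not used by this file): callers typically
// reach these ops through the dispatcher rather than invoking a backend
// directly. Assuming the "c10d::send" schema declared above, a call looks
// roughly like this (the variable names here are hypothetical):
//
//   auto op = c10::Dispatcher::singleton()
//                 .findSchemaOrThrow("c10d::send", "")
//                 .typed<c10::intrusive_ptr<Work>(
//                     at::TensorList,
//                     const c10::intrusive_ptr<ProcessGroup>&,
//                     int64_t,
//                     int64_t)>();
//   auto work = op.call(tensors, process_group, dst_rank, tag);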

namespace {

#define IMPL_SEND(DEV)                                                        \
  c10::intrusive_ptr<Work> send##DEV(                                         \
      at::TensorList tensors,                                                 \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      int64_t dstRank,                                                        \
      int64_t tag) {                                                          \
    auto tensor_vec = tensors.vec();                                          \
    return process_group->getBackend(c10::DeviceType::DEV)                    \
        ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag)); \
  }

IMPL_SEND(CPU)
IMPL_SEND(CUDA)
IMPL_SEND(PrivateUse1)
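
// For reference, IMPL_SEND(CPU) above expands to roughly:
//
//   c10::intrusive_ptr<Work> sendCPU(
//       at::TensorList tensors,
//       const c10::intrusive_ptr<ProcessGroup>& process_group,
//       int64_t dstRank,
//       int64_t tag) {
//     auto tensor_vec = tensors.vec();
//     return process_group->getBackend(c10::DeviceType::CPU)
//         ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
//   }
//
// The remaining IMPL_* macros below follow the same pattern: unpack the op's
// arguments, look up the backend for the device key, and forward the call.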

#define IMPL_RECV(DEV)                                                        \
  c10::intrusive_ptr<Work> recv_##DEV(                                        \
      at::TensorList tensors,                                                 \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      int64_t srcRank,                                                        \
      int64_t tag) {                                                          \
    auto tensor_vec = tensors.vec();                                          \
    return process_group->getBackend(c10::DeviceType::DEV)                    \
        ->recv(tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag)); \
  }

IMPL_RECV(CPU)
IMPL_RECV(CUDA)
IMPL_RECV(PrivateUse1)

#define IMPL_RECV_ANY_SOURCE(DEV)                            \
  c10::intrusive_ptr<Work> recv_any_source_##DEV(            \
      at::TensorList tensors,                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group, \
      int64_t tag) {                                         \
    auto tensor_vec = tensors.vec();                         \
    return process_group->getBackend(c10::DeviceType::DEV)   \
        ->recvAnysource(tensor_vec, static_cast<int>(tag));  \
  }

IMPL_RECV_ANY_SOURCE(CPU)
IMPL_RECV_ANY_SOURCE(CUDA)
IMPL_RECV_ANY_SOURCE(PrivateUse1)

#define IMPL_REDUCE(DEV)                                     \
  c10::intrusive_ptr<Work> reduce_##DEV(                     \
      at::TensorList tensors,                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group, \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,         \
      int64_t root_rank,                                     \
      int64_t root_tensor,                                   \
      int64_t timeout) {                                     \
    auto tensor_vec = tensors.vec();                         \
    return process_group->getBackend(c10::DeviceType::DEV)   \
        ->reduce(                                            \
            tensor_vec,                                      \
            ReduceOptions{                                   \
                *reduce_op.get(),                            \
                root_rank,                                   \
                root_tensor,                                 \
                std::chrono::milliseconds(timeout)});        \
  }

IMPL_REDUCE(CPU)
IMPL_REDUCE(CUDA)
IMPL_REDUCE(PrivateUse1)

#define IMPL_BROADCAST(DEV)                                                   \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>               \
      broadcast_##DEV(                                                        \
          at::TensorList tensors,                                             \
          const c10::intrusive_ptr<ProcessGroup>& process_group,              \
          int64_t root_rank,                                                  \
          int64_t root_tensor,                                                \
          bool asyncOp,                                                       \
          int64_t timeout) {                                                  \
    auto tensor_vec = tensors.vec();                                          \
    auto work = process_group->getBackend(c10::DeviceType::DEV) -> broadcast( \
        tensor_vec,                                                           \
        BroadcastOptions{                                                     \
            root_rank,                                                        \
            root_tensor,                                                      \
            std::chrono::milliseconds(timeout),                               \
            asyncOp});                                                        \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(     \
        std::move(tensor_vec), work);                                         \
  }

IMPL_BROADCAST(CPU)
IMPL_BROADCAST(CUDA)
IMPL_BROADCAST(PrivateUse1)

// Return input tensors as output tensors to make inplace allreduce look like
// a functional API, so that make_fx can correctly build the dependencies in
// the graph later.
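//
// For example, a traced call site sees something like
//   (out_tensors, work) = c10d::allreduce_(tensors, ...)
// where out_tensors holds the same tensors that were passed in, so the
// in-place mutation shows up in the graph as a data dependency on
// out_tensors.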
#define IMPL_ALLREDUCE(DEV)                                                   \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>               \
      allreduce_##DEV(                                                        \
          at::TensorList tensors,                                             \
          const c10::intrusive_ptr<ProcessGroup>& process_group,              \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                      \
          const std::optional<at::Tensor>& sparse_indices,                    \
          int64_t timeout) {                                                  \
    auto tensor_vec = tensors.vec();                                          \
    auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \
        tensor_vec,                                                           \
        AllreduceOptions{                                                     \
            *reduce_op.get(), std::chrono::milliseconds(timeout)});           \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(     \
        std::move(tensor_vec), work);                                         \
  }

IMPL_ALLREDUCE(CPU)
IMPL_ALLREDUCE(CUDA)
IMPL_ALLREDUCE(PrivateUse1)

#define IMPL_ALLREDUCE_COALESCED(DEV)                             \
  c10::intrusive_ptr<Work> allreduce_coalesced_##DEV(             \
      at::TensorList tensors,                                     \
      const c10::intrusive_ptr<ProcessGroup>& process_group,      \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,              \
      int64_t timeout) {                                          \
    auto tensor_vec = tensors.vec();                              \
    AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \
    opts.reduceOp = *reduce_op.get();                             \
    opts.timeout = std::chrono::milliseconds(timeout);            \
    return process_group->getBackend(c10::DeviceType::DEV)        \
        ->allreduce_coalesced(tensor_vec, opts);                  \
  }

IMPL_ALLREDUCE_COALESCED(CPU)
IMPL_ALLREDUCE_COALESCED(CUDA)
IMPL_ALLREDUCE_COALESCED(PrivateUse1)

// Copy output tensors (not storage) so that this can be used in a functional
// manner
#define IMPL_ALLGATHER(DEV)                                                     \
  std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>    \
      allgather_##DEV(                                                          \
          const std::vector<std::vector<at::Tensor>>& output_tensors,           \
          at::TensorList input_tensors,                                         \
          const c10::intrusive_ptr<ProcessGroup>& process_group,                \
          int64_t timeout) {                                                    \
    auto input_tensors_vec = input_tensors.vec();                               \
    auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather(   \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),      \
        input_tensors_vec,                                                      \
        AllgatherOptions{std::chrono::milliseconds(timeout)});                  \
    return std::                                                                \
        tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(  \
            output_tensors, work);                                              \
  }

// NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
IMPL_ALLGATHER(CPU)
IMPL_ALLGATHER(CUDA)
IMPL_ALLGATHER(PrivateUse1)

#define IMPL__ALLGATHER_BASE(DEV)                                           \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_##DEV(   \
      at::Tensor& output_tensor,                                            \
      at::Tensor& input_tensor,                                             \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                \
      bool asyncOp,                                                         \
      int64_t timeout) {                                                    \
    auto work =                                                             \
        process_group->getBackend(c10::DeviceType::DEV) -> _allgather_base( \
            output_tensor,                                                  \
            input_tensor,                                                   \
            AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \
    return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(                \
        output_tensor, work);                                               \
  }

IMPL__ALLGATHER_BASE(CPU)
IMPL__ALLGATHER_BASE(CUDA)
IMPL__ALLGATHER_BASE(PrivateUse1)

#define IMPL_ALLGATHER_COALESCED(DEV)                                        \
  c10::intrusive_ptr<Work> allgather_coalesced_##DEV(                        \
      const std::vector<std::vector<at::Tensor>>& output_lists,              \
      const at::TensorList& input_list,                                      \
      const c10::intrusive_ptr<ProcessGroup>& process_group) {               \
    auto input_list_vec = input_list.vec();                                  \
    return process_group->getBackend(c10::DeviceType::DEV)                   \
        ->allgather_coalesced(                                               \
            const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists), \
            input_list_vec);                                                 \
  }

IMPL_ALLGATHER_COALESCED(CPU)
IMPL_ALLGATHER_COALESCED(CUDA)
IMPL_ALLGATHER_COALESCED(PrivateUse1)

#define IMPL_ALLGATHER_INTO_TENSOR_COALESCED(DEV)                       \
  c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV( \
      at::TensorList outputs,                                           \
      at::TensorList inputs,                                            \
      const c10::intrusive_ptr<ProcessGroup>& process_group) {          \
    auto output_vec = outputs.vec();                                    \
    auto input_vec = inputs.vec();                                      \
    return process_group->getBackend(c10::DeviceType::DEV)              \
        ->allgather_into_tensor_coalesced(output_vec, input_vec);       \
  }

IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)

#define IMPL_REDUCE_SCATTER(DEV)                                              \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>               \
      reduce_scatter_##DEV(                                                   \
          const at::TensorList& output_tensors,                               \
          const std::vector<std::vector<at::Tensor>>& input_tensors,          \
          const c10::intrusive_ptr<ProcessGroup>& process_group,              \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                      \
          int64_t timeout) {                                                  \
    auto output_tensors_vec = output_tensors.vec();                           \
    auto work =                                                               \
        process_group->getBackend(c10::DeviceType::DEV) -> reduce_scatter(    \
            output_tensors_vec,                                               \
            const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors), \
            ReduceScatterOptions{                                             \
                *reduce_op.get(), std::chrono::milliseconds(timeout)});       \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(     \
        output_tensors_vec, work);                                            \
  }

IMPL_REDUCE_SCATTER(CPU)
IMPL_REDUCE_SCATTER(CUDA)
IMPL_REDUCE_SCATTER(PrivateUse1)

#define IMPL__REDUCE_SCATTER_BASE(DEV)                                         \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_##DEV( \
      at::Tensor& output_tensor,                                               \
      at::Tensor& input_tensor,                                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                           \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto work = process_group->getBackend(c10::DeviceType::DEV)                \
                    -> _reduce_scatter_base(                                   \
                        output_tensor,                                         \
                        input_tensor,                                          \
                        ReduceScatterOptions{                                  \
                            *reduce_op.get(),                                  \
                            std::chrono::milliseconds(timeout),                \
                            asyncOp});                                         \
    return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(                   \
        output_tensor, work);                                                  \
  }

IMPL__REDUCE_SCATTER_BASE(CPU)
IMPL__REDUCE_SCATTER_BASE(CUDA)
IMPL__REDUCE_SCATTER_BASE(PrivateUse1)

#define IMPL_REDUCE_SCATTER_TENSOR_COALESCED(DEV)                       \
  c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_##DEV( \
      at::TensorList outputs,                                           \
      at::TensorList inputs,                                            \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                    \
      int64_t timeout) {                                                \
    auto output_vec = outputs.vec();                                    \
    auto input_vec = inputs.vec();                                      \
    return process_group->getBackend(c10::DeviceType::DEV)              \
        ->reduce_scatter_tensor_coalesced(                              \
            output_vec,                                                 \
            input_vec,                                                  \
            ReduceScatterOptions{                                       \
                *reduce_op.get(), std::chrono::milliseconds(timeout)}); \
  }

IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)

#define IMPL_GATHER(DEV)                                                       \
  c10::intrusive_ptr<Work> gather_##DEV(                                       \
      const std::vector<std::vector<at::Tensor>>& output_tensors,              \
      const at::TensorList& input_tensors,                                     \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t root_rank,                                                       \
      int64_t timeout) {                                                       \
    auto input_tensors_vec = input_tensors.vec();                              \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->gather(                                                              \
            const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
            input_tensors_vec,                                                 \
            GatherOptions{root_rank, std::chrono::milliseconds(timeout)});     \
  }

IMPL_GATHER(CPU)
IMPL_GATHER(CUDA)
IMPL_GATHER(PrivateUse1)

#define IMPL_SCATTER(DEV)                                                      \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_##DEV( \
      const at::TensorList& output_tensors,                                    \
      const std::vector<std::vector<at::Tensor>>& input_tensors,               \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t root_rank,                                                       \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto output_tensors_vec = output_tensors.vec();                            \
    auto work = process_group->getBackend(c10::DeviceType::DEV) -> scatter(    \
        output_tensors_vec,                                                    \
        const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),      \
        ScatterOptions{                                                        \
            root_rank, std::chrono::milliseconds(timeout), asyncOp});          \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(      \
        std::move(output_tensors_vec), work);                                  \
  }

IMPL_SCATTER(CPU)
IMPL_SCATTER(CUDA)
IMPL_SCATTER(PrivateUse1)

#define IMPL_ALLTOALL(DEV)                                                   \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>              \
      alltoall_##DEV(                                                        \
          const at::TensorList& output_tensors,                              \
          const at::TensorList& input_tensors,                               \
          const c10::intrusive_ptr<ProcessGroup>& process_group,             \
          int64_t timeout) {                                                 \
    auto output_tensors_vec = output_tensors.vec();                          \
    auto input_tensors_vec = input_tensors.vec();                            \
    auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \
        output_tensors_vec,                                                  \
        input_tensors_vec,                                                   \
        AllToAllOptions{std::chrono::milliseconds(timeout)});                \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(    \
        std::move(output_tensors_vec), work);                                \
  }

IMPL_ALLTOALL(CPU)
IMPL_ALLTOALL(CUDA)
IMPL_ALLTOALL(PrivateUse1)

#define IMPL_ALLTOALL_BASE(DEV)                                   \
  c10::intrusive_ptr<Work> alltoall_base_##DEV(                   \
      at::Tensor& output,                                         \
      at::Tensor& input,                                          \
      const c10::intrusive_ptr<ProcessGroup>& process_group,      \
      std::vector<int64_t> output_split_sizes,                    \
      std::vector<int64_t> input_split_sizes,                     \
      int64_t timeout) {                                          \
    return process_group->getBackend(c10::DeviceType::DEV)        \
        ->alltoall_base(                                          \
            output,                                               \
            input,                                                \
            output_split_sizes,                                   \
            input_split_sizes,                                    \
            AllToAllOptions{std::chrono::milliseconds(timeout)}); \
  }

IMPL_ALLTOALL_BASE(CPU)
IMPL_ALLTOALL_BASE(CUDA)
IMPL_ALLTOALL_BASE(PrivateUse1)

#define IMPL_BARRIER(DEV)                                                    \
  c10::intrusive_ptr<Work> barrier##DEV(                                     \
      at::Tensor /* unused */,                                               \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                 \
      const std::vector<int64_t>& device_ids,                                \
      int64_t timeout) {                                                     \
    return process_group->getBackend(c10::DeviceType::DEV)                   \
        ->barrier(                                                           \
            BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \
  }

IMPL_BARRIER(CPU)
IMPL_BARRIER(CUDA)
IMPL_BARRIER(PrivateUse1)
// NOLINTEND(cppcoreguidelines-pro-type-const-cast)

void monitored_barrier_CPU(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout,
    bool wait_all_ranks) {
  process_group->getBackend(c10::DeviceType::CPU)
      ->monitoredBarrier(
          BarrierOptions{device_ids, std::chrono::milliseconds(timeout)},
          wait_all_ranks);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
allreduce_sparse_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    const std::optional<at::Tensor>& sparse_indices,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work = process_group->getBackend(c10::DeviceType::CUDA)
                  ->allreduce_sparse(
                      tensor_vec,
                      AllreduceOptions{
                          *reduce_op,
                          std::chrono::milliseconds(timeout),
                          sparse_indices});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}
} // namespace

// register functions to dispatcher
namespace {

// 2nd level expansion
// FUNC: op name
// DEV: device
#define REGISTER_C10D_OP1(FUNC, DEV) \
  TORCH_LIBRARY_IMPL(c10d, DEV, m) { \
    m.impl(#FUNC, FUNC##DEV);        \
  }

// 1st level expansion
#define REGISTER_C10D_OP(FUNC)  \
  REGISTER_C10D_OP1(FUNC, CPU)  \
  REGISTER_C10D_OP1(FUNC, CUDA) \
  REGISTER_C10D_OP1(FUNC, PrivateUse1)
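
// For reference, REGISTER_C10D_OP(send) below expands to three registrations:
//
//   TORCH_LIBRARY_IMPL(c10d, CPU, m) {
//     m.impl("send", sendCPU);
//   }
//   TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
//     m.impl("send", sendCUDA);
//   }
//   TORCH_LIBRARY_IMPL(c10d, PrivateUse1, m) {
//     m.impl("send", sendPrivateUse1);
//   }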

// Now we start to register ops with the three device keys

REGISTER_C10D_OP(send)
REGISTER_C10D_OP(recv_)
REGISTER_C10D_OP(recv_any_source_)
REGISTER_C10D_OP(reduce_)
REGISTER_C10D_OP(broadcast_)
REGISTER_C10D_OP(allreduce_)
REGISTER_C10D_OP(allreduce_coalesced_)
REGISTER_C10D_OP(allgather_)
REGISTER_C10D_OP(_allgather_base_)
REGISTER_C10D_OP(allgather_coalesced_)
REGISTER_C10D_OP(allgather_into_tensor_coalesced_)
REGISTER_C10D_OP(reduce_scatter_)
REGISTER_C10D_OP(_reduce_scatter_base_)
REGISTER_C10D_OP(reduce_scatter_tensor_coalesced_)
REGISTER_C10D_OP(gather_)
REGISTER_C10D_OP(scatter_)
REGISTER_C10D_OP(alltoall_)
REGISTER_C10D_OP(alltoall_base_)
REGISTER_C10D_OP(barrier)

// The following ops are specialized, register them separately

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("monitored_barrier_", monitored_barrier_CPU);
}

// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend
TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
  m.impl("allreduce_", allreduce_CPU);
}

TORCH_LIBRARY_IMPL(c10d, SparseCUDA, m) {
  m.impl("allreduce_", allreduce_sparse_cuda_);
}
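
// Illustration: dispatch is keyed on tensor layout as well as device, so an
// allreduce_ call whose tensors are sparse CUDA tensors resolves to
// allreduce_sparse_cuda_ above, while sparse CPU tensors are routed to the
// dense allreduce_CPU wrapper registered under the SparseCPU key.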

} // namespace

} // namespace ops
} // namespace c10d