#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/library.h>

namespace c10d {
namespace {

TORCH_LIBRARY(c10d, m) {
  // The following ProcessGroup, Work, and ReduceOp definitions are more like
  // declarations. They don't expose the details of these classes into
  // TorchScript.
  m.class_<ProcessGroup>("ProcessGroup").def(torch::init<int64_t, int64_t>());
  m.class_<Work>("Work")
      .def(torch::init<>())
      .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
  m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
  m.def(
      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
  m.def(
      "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int dst, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int src, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int tag) -> __torch__.torch.classes.c10d.Work");
}
} // namespace

namespace ops {

// Below are ProcessGroup's corresponding ops for each backend. Ops are
// routed through the dispatcher to be dispatched to the appropriate backend.
// Currently a no-op as the process group does not have a list of backends.
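//
// For orientation: callers reach these ops through the dispatcher by schema
// name. A minimal sketch (not part of this file; the typed<> signature below
// is spelled out only for illustration and mirrors the "c10d::send" schema
// registered above):
//
//   static auto op = c10::Dispatcher::singleton()
//                        .findSchemaOrThrow("c10d::send", "")
//                        .typed<c10::intrusive_ptr<Work>(
//                            at::TensorList,
//                            const c10::intrusive_ptr<ProcessGroup>&,
//                            int64_t,
//                            int64_t)>();
//   auto work = op.call(tensors, process_group, dstRank, tag);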

namespace {

#define IMPL_SEND(DEV)                                                         \
  c10::intrusive_ptr<Work> send##DEV(                                          \
      at::TensorList tensors,                                                  \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t dstRank,                                                         \
      int64_t tag) {                                                           \
    auto tensor_vec = tensors.vec();                                           \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));  \
  }
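
// As a reading aid, IMPL_SEND(CPU) (instantiated just below) expands to
// roughly the following function (the preprocessor emits it as one line):
//
//   c10::intrusive_ptr<Work> sendCPU(
//       at::TensorList tensors,
//       const c10::intrusive_ptr<ProcessGroup>& process_group,
//       int64_t dstRank,
//       int64_t tag) {
//     auto tensor_vec = tensors.vec();
//     return process_group->getBackend(c10::DeviceType::CPU)
//         ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
//   }
//
// The remaining IMPL_* macros below follow the same pattern for the other
// collectives.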

IMPL_SEND(CPU)
IMPL_SEND(CUDA)
IMPL_SEND(PrivateUse1)

#define IMPL_RECV(DEV)                                                         \
  c10::intrusive_ptr<Work> recv_##DEV(                                         \
      at::TensorList tensors,                                                  \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t srcRank,                                                         \
      int64_t tag) {                                                           \
    auto tensor_vec = tensors.vec();                                           \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->recv(tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));  \
  }

IMPL_RECV(CPU)
IMPL_RECV(CUDA)
IMPL_RECV(PrivateUse1)

#define IMPL_RECV_ANY_SOURCE(DEV)                                              \
  c10::intrusive_ptr<Work> recv_any_source_##DEV(                              \
      at::TensorList tensors,                                                  \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t tag) {                                                           \
    auto tensor_vec = tensors.vec();                                           \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->recvAnysource(tensor_vec, static_cast<int>(tag));                    \
  }

IMPL_RECV_ANY_SOURCE(CPU)
IMPL_RECV_ANY_SOURCE(CUDA)
IMPL_RECV_ANY_SOURCE(PrivateUse1)

#define IMPL_REDUCE(DEV)                                                       \
  c10::intrusive_ptr<Work> reduce_##DEV(                                       \
      at::TensorList tensors,                                                  \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                           \
      int64_t root_rank,                                                       \
      int64_t root_tensor,                                                     \
      int64_t timeout) {                                                       \
    auto tensor_vec = tensors.vec();                                           \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->reduce(                                                              \
            tensor_vec,                                                        \
            ReduceOptions{                                                     \
                *reduce_op.get(),                                              \
                root_rank,                                                     \
                root_tensor,                                                   \
                std::chrono::milliseconds(timeout)});                          \
  }

IMPL_REDUCE(CPU)
IMPL_REDUCE(CUDA)
IMPL_REDUCE(PrivateUse1)

#define IMPL_BROADCAST(DEV)                                                    \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>                \
      broadcast_##DEV(                                                         \
          at::TensorList tensors,                                              \
          const c10::intrusive_ptr<ProcessGroup>& process_group,               \
          int64_t root_rank,                                                   \
          int64_t root_tensor,                                                 \
          bool asyncOp,                                                        \
          int64_t timeout) {                                                   \
    auto tensor_vec = tensors.vec();                                           \
    auto work = process_group->getBackend(c10::DeviceType::DEV)->broadcast(    \
        tensor_vec,                                                            \
        BroadcastOptions{                                                      \
            root_rank,                                                         \
            root_tensor,                                                       \
            std::chrono::milliseconds(timeout),                                \
            asyncOp});                                                         \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(      \
        std::move(tensor_vec), work);                                          \
  }

IMPL_BROADCAST(CPU)
IMPL_BROADCAST(CUDA)
IMPL_BROADCAST(PrivateUse1)

// Return input tensors as output tensors to make inplace allreduce look like
// a functional API, so that make_fx can correctly build the dependencies in
// the graph later.
#define IMPL_ALLREDUCE(DEV)                                                    \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>                \
      allreduce_##DEV(                                                         \
          at::TensorList tensors,                                              \
          const c10::intrusive_ptr<ProcessGroup>& process_group,               \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                       \
          const std::optional<at::Tensor>& sparse_indices,                     \
          int64_t timeout) {                                                   \
    auto tensor_vec = tensors.vec();                                           \
    auto work = process_group->getBackend(c10::DeviceType::DEV)->allreduce(    \
        tensor_vec,                                                            \
        AllreduceOptions{                                                      \
            *reduce_op.get(), std::chrono::milliseconds(timeout)});            \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(      \
        std::move(tensor_vec), work);                                          \
  }

IMPL_ALLREDUCE(CPU)
IMPL_ALLREDUCE(CUDA)
IMPL_ALLREDUCE(PrivateUse1)

#define IMPL_ALLREDUCE_COALESCED(DEV)                                          \
  c10::intrusive_ptr<Work> allreduce_coalesced_##DEV(                          \
      at::TensorList tensors,                                                  \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                           \
      int64_t timeout) {                                                       \
    auto tensor_vec = tensors.vec();                                           \
    AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};              \
    opts.reduceOp = *reduce_op.get();                                          \
    opts.timeout = std::chrono::milliseconds(timeout);                         \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->allreduce_coalesced(tensor_vec, opts);                               \
  }

IMPL_ALLREDUCE_COALESCED(CPU)
IMPL_ALLREDUCE_COALESCED(CUDA)
IMPL_ALLREDUCE_COALESCED(PrivateUse1)

// Copy output tensors (not storage) so that this can be used in a functional
// manner
#define IMPL_ALLGATHER(DEV)                                                    \
  std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>   \
      allgather_##DEV(                                                         \
          const std::vector<std::vector<at::Tensor>>& output_tensors,          \
          at::TensorList input_tensors,                                        \
          const c10::intrusive_ptr<ProcessGroup>& process_group,               \
          int64_t timeout) {                                                   \
    auto input_tensors_vec = input_tensors.vec();                              \
    auto work = process_group->getBackend(c10::DeviceType::DEV)->allgather(    \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),     \
        input_tensors_vec,                                                     \
        AllgatherOptions{std::chrono::milliseconds(timeout)});                 \
    return std::                                                               \
        tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>( \
            output_tensors, work);                                             \
  }

// NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
IMPL_ALLGATHER(CPU)
IMPL_ALLGATHER(CUDA)
IMPL_ALLGATHER(PrivateUse1)

#define IMPL__ALLGATHER_BASE(DEV)                                              \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_##DEV(      \
      at::Tensor& output_tensor,                                               \
      at::Tensor& input_tensor,                                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto work =                                                                \
        process_group->getBackend(c10::DeviceType::DEV)->_allgather_base(      \
            output_tensor,                                                     \
            input_tensor,                                                      \
            AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp});    \
    return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(                   \
        output_tensor, work);                                                  \
  }

IMPL__ALLGATHER_BASE(CPU)
IMPL__ALLGATHER_BASE(CUDA)
IMPL__ALLGATHER_BASE(PrivateUse1)

#define IMPL_ALLGATHER_COALESCED(DEV)                                          \
  c10::intrusive_ptr<Work> allgather_coalesced_##DEV(                          \
      const std::vector<std::vector<at::Tensor>>& output_lists,                \
      const at::TensorList& input_list,                                        \
      const c10::intrusive_ptr<ProcessGroup>& process_group) {                 \
    auto input_list_vec = input_list.vec();                                    \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->allgather_coalesced(                                                 \
            const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),   \
            input_list_vec);                                                   \
  }

IMPL_ALLGATHER_COALESCED(CPU)
IMPL_ALLGATHER_COALESCED(CUDA)
IMPL_ALLGATHER_COALESCED(PrivateUse1)

#define IMPL_ALLGATHER_INTO_TENSOR_COALESCED(DEV)                              \
  c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV(        \
      at::TensorList outputs,                                                  \
      at::TensorList inputs,                                                   \
      const c10::intrusive_ptr<ProcessGroup>& process_group) {                 \
    auto output_vec = outputs.vec();                                           \
    auto input_vec = inputs.vec();                                             \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->allgather_into_tensor_coalesced(output_vec, input_vec);              \
  }

IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)

#define IMPL_REDUCE_SCATTER(DEV)                                               \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>                \
      reduce_scatter_##DEV(                                                    \
          const at::TensorList& output_tensors,                                \
          const std::vector<std::vector<at::Tensor>>& input_tensors,           \
          const c10::intrusive_ptr<ProcessGroup>& process_group,               \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                       \
          int64_t timeout) {                                                   \
    auto output_tensors_vec = output_tensors.vec();                            \
    auto work =                                                                \
        process_group->getBackend(c10::DeviceType::DEV)->reduce_scatter(       \
            output_tensors_vec,                                                \
            const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),  \
            ReduceScatterOptions{                                              \
                *reduce_op.get(), std::chrono::milliseconds(timeout)});        \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(      \
        output_tensors_vec, work);                                             \
  }

IMPL_REDUCE_SCATTER(CPU)
IMPL_REDUCE_SCATTER(CUDA)
IMPL_REDUCE_SCATTER(PrivateUse1)

#define IMPL__REDUCE_SCATTER_BASE(DEV)                                         \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_##DEV( \
      at::Tensor& output_tensor,                                               \
      at::Tensor& input_tensor,                                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                           \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto work = process_group->getBackend(c10::DeviceType::DEV)                \
                    ->_reduce_scatter_base(                                    \
                        output_tensor,                                         \
                        input_tensor,                                          \
                        ReduceScatterOptions{                                  \
                            *reduce_op.get(),                                  \
                            std::chrono::milliseconds(timeout),                \
                            asyncOp});                                         \
    return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(                   \
        output_tensor, work);                                                  \
  }

IMPL__REDUCE_SCATTER_BASE(CPU)
IMPL__REDUCE_SCATTER_BASE(CUDA)
IMPL__REDUCE_SCATTER_BASE(PrivateUse1)

#define IMPL_REDUCE_SCATTER_TENSOR_COALESCED(DEV)                              \
  c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_##DEV(        \
      at::TensorList outputs,                                                  \
      at::TensorList inputs,                                                   \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                           \
      int64_t timeout) {                                                       \
    auto output_vec = outputs.vec();                                           \
    auto input_vec = inputs.vec();                                             \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->reduce_scatter_tensor_coalesced(                                     \
            output_vec,                                                        \
            input_vec,                                                         \
            ReduceScatterOptions{                                              \
                *reduce_op.get(), std::chrono::milliseconds(timeout)});        \
  }

IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)

#define IMPL_GATHER(DEV)                                                       \
  c10::intrusive_ptr<Work> gather_##DEV(                                       \
      const std::vector<std::vector<at::Tensor>>& output_tensors,              \
      const at::TensorList& input_tensors,                                     \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t root_rank,                                                       \
      int64_t timeout) {                                                       \
    auto input_tensors_vec = input_tensors.vec();                              \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->gather(                                                              \
            const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
            input_tensors_vec,                                                 \
            GatherOptions{root_rank, std::chrono::milliseconds(timeout)});     \
  }

IMPL_GATHER(CPU)
IMPL_GATHER(CUDA)
IMPL_GATHER(PrivateUse1)

#define IMPL_SCATTER(DEV)                                                      \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_##DEV( \
      const at::TensorList& output_tensors,                                    \
      const std::vector<std::vector<at::Tensor>>& input_tensors,               \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t root_rank,                                                       \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto output_tensors_vec = output_tensors.vec();                            \
    auto work = process_group->getBackend(c10::DeviceType::DEV)->scatter(      \
        output_tensors_vec,                                                    \
        const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),      \
        ScatterOptions{                                                        \
            root_rank, std::chrono::milliseconds(timeout), asyncOp});          \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(      \
        std::move(output_tensors_vec), work);                                  \
  }

IMPL_SCATTER(CPU)
IMPL_SCATTER(CUDA)
IMPL_SCATTER(PrivateUse1)

#define IMPL_ALLTOALL(DEV)                                                     \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>                \
      alltoall_##DEV(                                                          \
          const at::TensorList& output_tensors,                                \
          const at::TensorList& input_tensors,                                 \
          const c10::intrusive_ptr<ProcessGroup>& process_group,               \
          int64_t timeout) {                                                   \
    auto output_tensors_vec = output_tensors.vec();                            \
    auto input_tensors_vec = input_tensors.vec();                              \
    auto work = process_group->getBackend(c10::DeviceType::DEV)->alltoall(     \
        output_tensors_vec,                                                    \
        input_tensors_vec,                                                     \
        AllToAllOptions{std::chrono::milliseconds(timeout)});                  \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(      \
        std::move(output_tensors_vec), work);                                  \
  }

IMPL_ALLTOALL(CPU)
IMPL_ALLTOALL(CUDA)
IMPL_ALLTOALL(PrivateUse1)

#define IMPL_ALLTOALL_BASE(DEV)                                                \
  c10::intrusive_ptr<Work> alltoall_base_##DEV(                                \
      at::Tensor& output,                                                      \
      at::Tensor& input,                                                       \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      std::vector<int64_t> output_split_sizes,                                 \
      std::vector<int64_t> input_split_sizes,                                  \
      int64_t timeout) {                                                       \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->alltoall_base(                                                       \
            output,                                                            \
            input,                                                             \
            output_split_sizes,                                                \
            input_split_sizes,                                                 \
            AllToAllOptions{std::chrono::milliseconds(timeout)});              \
  }

IMPL_ALLTOALL_BASE(CPU)
IMPL_ALLTOALL_BASE(CUDA)
IMPL_ALLTOALL_BASE(PrivateUse1)

#define IMPL_BARRIER(DEV)                                                      \
  c10::intrusive_ptr<Work> barrier##DEV(                                       \
      at::Tensor /* unused */,                                                 \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const std::vector<int64_t>& device_ids,                                  \
      int64_t timeout) {                                                       \
    return process_group->getBackend(c10::DeviceType::DEV)                     \
        ->barrier(                                                             \
            BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});   \
  }

IMPL_BARRIER(CPU)
IMPL_BARRIER(CUDA)
IMPL_BARRIER(PrivateUse1)
// NOLINTEND(cppcoreguidelines-pro-type-const-cast)

void monitored_barrier_CPU(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout,
    bool wait_all_ranks) {
  process_group->getBackend(c10::DeviceType::CPU)
      ->monitoredBarrier(
          BarrierOptions{device_ids, std::chrono::milliseconds(timeout)},
          wait_all_ranks);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
allreduce_sparse_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    const std::optional<at::Tensor>& sparse_indices,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work = process_group->getBackend(c10::DeviceType::CUDA)
                  ->allreduce_sparse(
                      tensor_vec,
                      AllreduceOptions{
                          *reduce_op,
                          std::chrono::milliseconds(timeout),
                          sparse_indices});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}
} // namespace

// register functions to dispatcher
namespace {

// 2nd level expansion
// FUNC: op name
// DEV: device
#define REGISTER_C10D_OP1(FUNC, DEV) \
  TORCH_LIBRARY_IMPL(c10d, DEV, m) { \
    m.impl(#FUNC, FUNC##DEV);        \
  }

// 1st level expansion
#define REGISTER_C10D_OP(FUNC)  \
  REGISTER_C10D_OP1(FUNC, CPU)  \
  REGISTER_C10D_OP1(FUNC, CUDA) \
  REGISTER_C10D_OP1(FUNC, PrivateUse1)
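
// For reference, a single REGISTER_C10D_OP(send) expands to roughly the
// following three registrations (sketch of the two-level macro expansion
// above):
//
//   TORCH_LIBRARY_IMPL(c10d, CPU, m) {
//     m.impl("send", sendCPU);
//   }
//   TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
//     m.impl("send", sendCUDA);
//   }
//   TORCH_LIBRARY_IMPL(c10d, PrivateUse1, m) {
//     m.impl("send", sendPrivateUse1);
//   }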

// Now we start to register ops with the three device keys

REGISTER_C10D_OP(send)
REGISTER_C10D_OP(recv_)
REGISTER_C10D_OP(recv_any_source_)
REGISTER_C10D_OP(reduce_)
REGISTER_C10D_OP(broadcast_)
REGISTER_C10D_OP(allreduce_)
REGISTER_C10D_OP(allreduce_coalesced_)
REGISTER_C10D_OP(allgather_)
REGISTER_C10D_OP(_allgather_base_)
REGISTER_C10D_OP(allgather_coalesced_)
REGISTER_C10D_OP(allgather_into_tensor_coalesced_)
REGISTER_C10D_OP(reduce_scatter_)
REGISTER_C10D_OP(_reduce_scatter_base_)
REGISTER_C10D_OP(reduce_scatter_tensor_coalesced_)
REGISTER_C10D_OP(gather_)
REGISTER_C10D_OP(scatter_)
REGISTER_C10D_OP(alltoall_)
REGISTER_C10D_OP(alltoall_base_)
REGISTER_C10D_OP(barrier)

// The following ops are specialized; register them separately.

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("monitored_barrier_", monitored_barrier_CPU);
}

// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend
TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
  m.impl("allreduce_", allreduce_CPU);
}

TORCH_LIBRARY_IMPL(c10d, SparseCUDA, m) {
  m.impl("allreduce_", allreduce_sparse_cuda_);
}

} // namespace

} // namespace ops
} // namespace c10d