#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <numeric>
#include <utility>

namespace {

// Tracks the pending c10d::Work object associated with each output tensor's
// storage so that a later wait_tensor() call can look it up and wait on it.
class WorkRegistry {
 public:
  void register_work(
      const at::Tensor& tensor,
      const c10::intrusive_ptr<c10d::Work>& work) {
    auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto [it, inserted] = registry_.try_emplace(std::move(storage), work);
    TORCH_CHECK(
        inserted || it->second != work,
        "The tensor storage is already associated with another work.");
  }

  c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto it = registry_.find(storage);
    if (it == registry_.end()) {
      return nullptr;
    }
    auto work = it->second;
    registry_.erase(it);
    return work;
  }

  ~WorkRegistry() {
    // If there are still unwaited work objects, their corresponding process
    // groups should have already been destroyed at this stage. Any attempts to
    // wait for these work objects or to destroy them will only result in
    // confusing errors. Therefore, we simply issue a warning and intentionally
    // allow the unwaited work objects to leak.
    if (!registry_.empty()) {
      TORCH_WARN(
          "At the time of process termination, there are still ",
          registry_.size(),
          " unwaited c10d_functional collective calls. "
          "Please review your program to ensure c10d_functional.wait_tensor() "
          "is invoked on all tensors returned from c10d_functional collective "
          "ops before they are used.");
    }
    for (auto& it : registry_) {
      it.second.release();
    }
  }

 private:
  std::unordered_map<
      c10::weak_intrusive_ptr<c10::StorageImpl>,
      c10::intrusive_ptr<c10d::Work>>
      registry_;
  std::mutex lock_;
};

static WorkRegistry process_registry;

} // namespace

namespace c10d {

// Associates `work` with `tensor` in the rank-local registry. The work object
// is retrieved and waited on by a subsequent wait_tensor() call.
void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<c10d::Work>& work) {
  RankLocal<WorkRegistry>::get().register_work(tensor, work);
}

} // namespace c10d
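
// Illustrative sketch (not part of this file's API surface): an out-of-tree
// collective that returns an asynchronously produced tensor can make its
// output compatible with _c10d_functional::wait_tensor by registering the
// pending work for that output. The helper name `my_collective` below is
// hypothetical:
//
//   at::Tensor my_collective(const at::Tensor& input, const std::string& group_name) {
//     auto output = input.clone(at::MemoryFormat::Contiguous);
//     std::vector<at::Tensor> tensors{output};
//     auto group = c10d::resolve_process_group(group_name);
//     c10d::BroadcastOptions opts;
//     auto work = group->broadcast(tensors, opts);
//     c10d::register_work(output, work);
//     return output; // caller must later call wait_tensor(output)
//   }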

namespace {

const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
    {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
    {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
    {"product", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PRODUCT)},
    {"min", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MIN)},
    {"max", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MAX)},
    {"band", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BAND)},
    {"bor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BOR)},
    {"bxor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BXOR)},
    // TODO: support premul_sum
    // {"premul_sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PREMUL_SUM)},
    {"unused", c10d::ReduceOp(c10d::ReduceOp::RedOpType::UNUSED)}};

c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
  auto it = str_to_reduce_op.find(reduce_op);
  TORCH_CHECK(
      it != str_to_reduce_op.end(), "Unrecognized reduce_op: ", reduce_op);
  return it->second;
}

at::Tensor& all_reduce_(
    at::Tensor& input,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);

  std::vector<at::Tensor> inputs{input};
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce(inputs, opts);
  c10d::register_work(input, work);
  return input;
}

at::Tensor all_reduce(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  auto output = input.clone(at::MemoryFormat::Contiguous);
  return all_reduce_(output, std::move(reduce_op), std::move(group_name));
}

std::vector<at::Tensor> all_reduce_coalesced_(
    std::vector<at::Tensor> inputs,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::AllreduceCoalescedOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce_coalesced(inputs, opts);
  for (const auto& tensor : inputs) {
    c10d::register_work(tensor, work);
  }
  return inputs;
}

std::vector<at::Tensor> all_reduce_coalesced(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::vector<at::Tensor> inputs,
    std::string reduce_op,
    std::string group_name) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(tensor.clone(at::MemoryFormat::Contiguous));
  }
  return all_reduce_coalesced_(
      outputs, std::move(reduce_op), std::move(group_name));
}

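// Shape note (example): with group_size = 4, an input of shape [8, 16]
// yields an all-gather output of shape [32, 16]; contributions from the
// ranks are stacked along dim 0.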
at::Tensor allocate_all_gather_output(
    const at::Tensor& input,
    int64_t group_size) {
  auto output_size = input.sizes().vec();
  output_size[0] *= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> all_gather_into_tensor_coalesced(
    std::vector<at::Tensor> inputs,
    int64_t group_size,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_all_gather_output(tensor, group_size));
  }

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allgather_into_tensor_coalesced(outputs, inputs);
  for (const auto& tensor : outputs) {
    c10d::register_work(tensor, work);
  }
  return outputs;
}

at::Tensor all_gather_into_tensor(
    const at::Tensor& input,
    int64_t group_size,
    std::string group_name) {
  std::vector<at::Tensor> inputs{input};
  return all_gather_into_tensor_coalesced(
      inputs, group_size, std::move(group_name))[0];
}

at::Tensor& all_gather_into_tensor_out(
    at::Tensor& input,
    int64_t group_size,
    const std::string& group_name,
    at::Tensor& output) {
  c10d::AllgatherOptions opts;

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->_allgather_base(output, input, opts);
  c10d::register_work(output, work);
  return output;
}

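// Shape note (example): with group_size = 4, an input of shape [32, 16]
// yields a reduce-scatter output of shape [8, 16]; the first dimension is
// expected to be divisible by the group size, otherwise the warning below
// is emitted.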
at::Tensor allocate_reduce_scatter_output(
    const at::Tensor& input,
    const int64_t group_size) {
  auto output_size = input.sizes().vec();
  if (output_size[0] % group_size != 0) {
    LOG(WARNING) << "The first dimension of the reduce_scatter input ("
                 << output_size[0] << ") is not divisible by the group size ("
                 << group_size << ").";
  }
  output_size[0] /= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
    std::vector<at::Tensor> inputs,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    int64_t group_size,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::ReduceScatterOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_reduce_scatter_output(tensor, group_size));
  }

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
  for (const auto& tensor : outputs) {
    c10d::register_work(tensor, work);
  }
  return outputs;
}

at::Tensor reduce_scatter_tensor(
    const at::Tensor& input,
    std::string reduce_op,
    int64_t group_size,
    std::string group_name) {
  std::vector<at::Tensor> inputs{input};
  return reduce_scatter_tensor_coalesced(
      inputs, std::move(reduce_op), group_size, std::move(group_name))[0];
}

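// Split-size example (illustrative, group size 2): with
// input_split_sizes = {2, 3} this rank sends rows [0, 2) of `input` to rank 0
// and rows [2, 5) to rank 1; with output_split_sizes = {4, 1} it receives 4
// rows from rank 0 and 1 row from rank 1, so the output below is allocated
// with sum(output_split_sizes) = 5 rows.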
at::Tensor all_to_all_single(
    const at::Tensor& input,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  std::vector<int64_t> output_sizes = input.sizes().vec();
  output_sizes[0] = std::accumulate(
      output_split_sizes.begin(), output_split_sizes.end(), int64_t(0));
  auto output = input.new_empty(output_sizes);

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->alltoall_base(
      output,
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<at::Tensor&>(input),
      output_split_sizes,
      input_split_sizes);
  c10d::register_work(output, work);
  return output;
}

// NOLINTNEXTLINE(performance-unnecessary-value-param)
at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) {
  c10d::BroadcastOptions opts;
  opts.rootRank = src;
  std::vector<at::Tensor> inputs{input};

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->broadcast(inputs, opts);
  c10d::register_work(input, work);
  return input;
}

at::Tensor broadcast(
    const at::Tensor& input,
    int64_t src,
    std::string group_name) {
  auto output = input.clone(at::MemoryFormat::Contiguous);
  return broadcast_(output, src, std::move(group_name));
}

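// Pops the work object registered for `tensor` (if any) and blocks until it
// completes. The registry entry is removed on the first call, so calling
// wait_tensor again on the same tensor is a no-op.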
at::Tensor wait_tensor(const at::Tensor& tensor) {
  auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
  if (work != nullptr) {
    work->wait();
  }
  return tensor;
}

} // namespace

TORCH_LIBRARY(_c10d_functional, m) {
  m.def(
      "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_out),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_coalesced),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::reduce_scatter_tensor_coalesced),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_to_all_single("
      "Tensor input, "
      "SymInt[] output_split_sizes, "
      "SymInt[] input_split_sizes, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "broadcast(Tensor input, int src, str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "wait_tensor(Tensor tensor) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor),
      {at::Tag::pt2_compliant_tag});
}
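
// Invocation sketch (illustrative): once registered, these ops can be called
// through the dispatcher from C++, mirroring the pattern used by the autograd
// wrappers below; `input` and `group_name` are placeholders and the group is
// assumed to have been registered elsewhere:
//
//   auto out = c10::Dispatcher::singleton()
//                  .findSchemaOrThrow("_c10d_functional::all_reduce", "")
//                  .typed<decltype(all_reduce)>()
//                  .call(input, "sum", group_name);
//   out = c10::Dispatcher::singleton()
//             .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
//             .typed<decltype(wait_tensor)>()
//             .call(out);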

namespace {

class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::vector<int64_t> output_split_sizes,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::vector<int64_t> input_split_sizes,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::string group_name) {
    // swap sizes for backwards pass
    ctx->saved_data["output_split_sizes"] = input_split_sizes;
    ctx->saved_data["input_split_sizes"] = output_split_sizes;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
        .typed<decltype(all_to_all_single)>()
        .call(input, output_split_sizes, input_split_sizes, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_out_list) {
    const std::vector<int64_t>& output_split_sizes =
        ctx->saved_data["output_split_sizes"].toIntVector();
    const std::vector<int64_t>& input_split_sizes =
        ctx->saved_data["input_split_sizes"].toIntVector();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    auto grad_out = grad_out_list[0].contiguous();

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
            .typed<decltype(all_to_all_single)>()
            .call(grad_out, output_split_sizes, input_split_sizes, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {out, at::Tensor(), at::Tensor(), at::Tensor()};
  }
};

at::Tensor all_to_all_single_autograd(
    const at::Tensor& input,
    const std::vector<int64_t>& output_split_sizes,
    const std::vector<int64_t>& input_split_sizes,
    const std::string& group_name) {
  return AllToAllSingle::apply(
      input, output_split_sizes, input_split_sizes, group_name);
}

class ReduceScatterTensor
    : public torch::autograd::Function<ReduceScatterTensor> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      const std::string& reduce_op,
      int64_t group_size,
      const std::string& group_name) {
    TORCH_CHECK(reduce_op == "sum", "Only sum reduce op is supported");

    ctx->saved_data["group_size"] = group_size;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
        .typed<decltype(reduce_scatter_tensor)>()
        .call(input, reduce_op, group_size, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_out_list) {
    const int64_t group_size = ctx->saved_data["group_size"].toInt();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    auto grad_out = grad_out_list[0];

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
            .typed<decltype(all_gather_into_tensor)>()
            .call(grad_out, group_size, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {
        out,
        at::Tensor(),
        at::Tensor(),
        at::Tensor(),
    };
  }
};

at::Tensor reduce_scatter_tensor_autograd(
    const at::Tensor& input,
    const std::string& reduce_op,
    int64_t group_size,
    const std::string& group_name) {
  return ReduceScatterTensor::apply(input, reduce_op, group_size, group_name);
}

class AllGatherIntoTensor
    : public torch::autograd::Function<AllGatherIntoTensor> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      int64_t group_size,
      const std::string& group_name) {
    ctx->saved_data["group_size"] = group_size;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
        .typed<decltype(all_gather_into_tensor)>()
        .call(input, group_size, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_out_list) {
    const int64_t group_size = ctx->saved_data["group_size"].toInt();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    auto grad_out = grad_out_list[0];

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
            .typed<decltype(reduce_scatter_tensor)>()
            .call(grad_out, "sum", group_size, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {
        out,
        at::Tensor(),
        at::Tensor(),
    };
  }
};

at::Tensor all_gather_into_tensor_autograd(
    const at::Tensor& input,
    int64_t group_size,
    const std::string& group_name) {
  return AllGatherIntoTensor::apply(input, group_size, group_name);
}

} // namespace

TORCH_LIBRARY(_c10d_functional_autograd, m) {
  m.def(
      "all_to_all_single("
      "Tensor input, "
      "SymInt[] output_split_sizes, "
      "SymInt[] input_split_sizes, "
      "str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::Autograd, ::all_to_all_single_autograd),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "reduce_scatter_tensor("
      "Tensor input, "
      "str reduce_op, "
      "int group_size, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::Autograd, ::reduce_scatter_tensor_autograd),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_gather_into_tensor("
      "Tensor input, "
      "int group_size, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::Autograd, ::all_gather_into_tensor_autograd),
      {at::Tag::pt2_compliant_tag});
}

namespace {
// DTensor-related comm operations, sharing code with the functional
// collectives for now.
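//
// Shape example (illustrative): with group_size = 2, gather_dim = 0 and
// shard_dim = 1, an input of shape [4, 6] is split into two [4, 3] chunks
// along dim 1, exchanged via alltoall, and concatenated along dim 0 into an
// [8, 3] output.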
at::Tensor shard_dim_alltoall(
    const at::Tensor& input,
    int64_t gather_dim,
    int64_t shard_dim,
    const std::string& group_name) {
  auto group = c10d::resolve_process_group(group_name);
  auto group_size = group->getSize();
  std::vector<int64_t> output_sizes = input.sizes().vec();
  if (output_sizes[shard_dim] % group_size != 0) {
    LOG(WARNING) << "The shard dimension of the shard_dim_alltoall input ("
                 << output_sizes[shard_dim]
                 << ") is not divisible by the group size (" << group_size
                 << ").";
  }
  output_sizes[shard_dim] = output_sizes[shard_dim] / group_size;
  std::vector<at::Tensor> inputs;
  inputs.reserve(group_size);
  auto length = output_sizes[shard_dim];
  for (int i = 0; i < group_size; i++) {
    inputs.push_back(input.narrow(shard_dim, i * length, length).contiguous());
  }
  // allocate outputs
  std::vector<at::Tensor> outputs;
  outputs.reserve(group_size);
  for (int i = 0; i < group_size; i++) {
    outputs.push_back(input.new_empty(output_sizes).contiguous());
  }
  auto work = group->alltoall(outputs, inputs);

  work->wait();
  // TODO: it is tricky to make the current async behavior work for shard-dim
  // alltoall, so for now we keep this comm op synchronous. We can revisit
  // later how to support the async case with the Work registry.
  return at::cat(outputs, gather_dim);
}
} // namespace

// DTensor comm op registry
TORCH_LIBRARY(_dtensor, m) {
  m.def(
      "shard_dim_alltoall(Tensor input, int gather_dim, int shard_dim, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::shard_dim_alltoall),
      {at::Tag::pt2_compliant_tag});
}