#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <numeric> // std::accumulate (used by all_to_all_single)
#include <utility>

namespace {

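// Tracks pending collective work by the storage of the tensor(s) a collective
// wrote into, so that a later wait_tensor() call on an output tensor can look
// up and wait on the corresponding c10d::Work. Instances are accessed through
// RankLocal<WorkRegistry>, i.e. each rank keeps its own registry.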
class WorkRegistry {
 public:
  void register_work(
      const at::Tensor& tensor,
      const c10::intrusive_ptr<c10d::Work>& work) {
    auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto [it, inserted] = registry_.try_emplace(std::move(storage), work);
    TORCH_CHECK(
        inserted || it->second != work,
        "The tensor storage is already associated with another work.");
  }

  c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto it = registry_.find(storage);
    if (it == registry_.end()) {
      return nullptr;
    }
    auto work = it->second;
    registry_.erase(it);
    return work;
  }

  ~WorkRegistry() {
    // If there are still unwaited work objects, their corresponding process
    // groups should have already been destroyed at this stage. Any attempts to
    // wait for these work objects or to destroy them will only result in
    // confusing errors. Therefore, we simply issue a warning and intentionally
    // allow the unwaited work objects to leak.
    if (!registry_.empty()) {
      TORCH_WARN(
          "At the time of process termination, there are still ",
          registry_.size(),
          " unwaited c10d_functional collective calls. "
          "Please review your program to ensure c10d_functional.wait_tensor() "
          "is invoked on all tensors returned from c10d_functional collective "
          "ops before they are used.");
    }
    for (auto& it : registry_) {
      it.second.release();
    }
  }

 private:
  std::unordered_map<
      c10::weak_intrusive_ptr<c10::StorageImpl>,
      c10::intrusive_ptr<c10d::Work>>
      registry_;
  std::mutex lock_;
};

static WorkRegistry process_registry;

} // namespace

namespace c10d {

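// Associates an asynchronous Work object with a collective's output tensor so
// that a later wait_tensor() call can synchronize on it. The association is
// rank-local (see RankLocal<WorkRegistry> above).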
void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<c10d::Work>& work) {
  RankLocal<WorkRegistry>::get().register_work(tensor, work);
}

} // namespace c10d

namespace {

const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
    {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
    {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
    {"product", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PRODUCT)},
    {"min", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MIN)},
    {"max", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MAX)},
    {"band", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BAND)},
    {"bor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BOR)},
    {"bxor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BXOR)},
    // TODO: support premul_sum
    // {"premul_sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PREMUL_SUM)},
    {"unused", c10d::ReduceOp(c10d::ReduceOp::RedOpType::UNUSED)}};

c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
  auto it = str_to_reduce_op.find(reduce_op);
  TORCH_CHECK(
      it != str_to_reduce_op.end(), "Unrecognized reduce_op: ", reduce_op);
  return it->second;
}

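// In-place all-reduce: reduces `input` across the group and registers the
// resulting Work against the input's storage so that a subsequent
// wait_tensor() call can synchronize on it.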
at::Tensor& all_reduce_(
    at::Tensor& input,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);

  std::vector<at::Tensor> inputs{input};
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce(inputs, opts);
  c10d::register_work(input, work);
  return input;
}

at::Tensor all_reduce(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  auto output = input.clone(at::MemoryFormat::Contiguous);
  return all_reduce_(output, std::move(reduce_op), std::move(group_name));
}

std::vector<at::Tensor> all_reduce_coalesced_(
    std::vector<at::Tensor> inputs,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::AllreduceCoalescedOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce_coalesced(inputs, opts);
  for (const auto& tensor : inputs) {
    c10d::register_work(tensor, work);
  }
  return inputs;
}

std::vector<at::Tensor> all_reduce_coalesced(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::vector<at::Tensor> inputs,
    std::string reduce_op,
    std::string group_name) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(tensor.clone(at::MemoryFormat::Contiguous));
  }
  return all_reduce_coalesced_(
      outputs, std::move(reduce_op), std::move(group_name));
}

at::Tensor allocate_all_gather_output(
    const at::Tensor& input,
    int64_t group_size) {
  auto output_size = input.sizes().vec();
  output_size[0] *= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> all_gather_into_tensor_coalesced(
    std::vector<at::Tensor> inputs,
    int64_t group_size,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_all_gather_output(tensor, group_size));
  }

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allgather_into_tensor_coalesced(outputs, inputs);
  for (const auto& tensor : outputs) {
    c10d::register_work(tensor, work);
  }
  return outputs;
}

at::Tensor all_gather_into_tensor(
    const at::Tensor& input,
    int64_t group_size,
    std::string group_name) {
  std::vector<at::Tensor> inputs{input};
  return all_gather_into_tensor_coalesced(
      inputs, group_size, std::move(group_name))[0];
}

at::Tensor& all_gather_into_tensor_out(
    at::Tensor& input,
    int64_t group_size,
    const std::string& group_name,
    at::Tensor& output) {
  c10d::AllgatherOptions opts;

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->_allgather_base(output, input, opts);
  c10d::register_work(output, work);
  return output;
}

at::Tensor allocate_reduce_scatter_output(
    const at::Tensor& input,
    const int64_t group_size) {
  auto output_size = input.sizes().vec();
  if (output_size[0] % group_size != 0) {
    LOG(WARNING) << "The first dimension of the reduce_scatter input ("
                 << output_size[0] << ") is not divisible by the group size ("
                 << group_size << ").";
  }
  output_size[0] /= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
    std::vector<at::Tensor> inputs,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    int64_t group_size,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::ReduceScatterOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_reduce_scatter_output(tensor, group_size));
  }

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
  for (const auto& tensor : outputs) {
    c10d::register_work(tensor, work);
  }
  return outputs;
}

at::Tensor reduce_scatter_tensor(
    const at::Tensor& input,
    std::string reduce_op,
    int64_t group_size,
    std::string group_name) {
  std::vector<at::Tensor> inputs{input};
  return reduce_scatter_tensor_coalesced(
      inputs, std::move(reduce_op), group_size, std::move(group_name))[0];
}

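// The output of all_to_all_single has the same shape as the input except along
// dim 0, whose size is the sum of output_split_sizes (i.e. the total number of
// rows this rank receives from its peers).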
at::Tensor all_to_all_single(
    const at::Tensor& input,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  std::vector<int64_t> output_sizes = input.sizes().vec();
  output_sizes[0] = std::accumulate(
      output_split_sizes.begin(), output_split_sizes.end(), int64_t(0));
  auto output = input.new_empty(output_sizes);

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->alltoall_base(
      output,
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<at::Tensor&>(input),
      output_split_sizes,
      input_split_sizes);
  c10d::register_work(output, work);
  return output;
}

// NOLINTNEXTLINE(performance-unnecessary-value-param)
at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) {
  c10d::BroadcastOptions opts;
  opts.rootRank = src;
  std::vector<at::Tensor> inputs{input};

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->broadcast(inputs, opts);
  c10d::register_work(input, work);
  return input;
}

at::Tensor broadcast(
    const at::Tensor& input,
    int64_t src,
    std::string group_name) {
  auto output = input.clone(at::MemoryFormat::Contiguous);
  return broadcast_(output, src, std::move(group_name));
}

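// Blocks until the collective that produced `tensor` (if any) has completed,
// by popping and waiting on the Work registered for the tensor's storage.
// Calling it on a tensor with no pending work is a no-op.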
at::Tensor wait_tensor(const at::Tensor& tensor) {
  auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
  if (work != nullptr) {
    work->wait();
  }
  return tensor;
}

} // namespace

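// Registers the functional collective ops under torch.ops._c10d_functional.
// All ops are dispatched on CompositeExplicitAutograd and tagged pt2_compliant.
// Illustrative Python usage (a minimal sketch; assumes a process group has
// been registered under the name "default"):
//
//   out = torch.ops._c10d_functional.all_reduce(t, "sum", "default")
//   out = torch.ops._c10d_functional.wait_tensor(out)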
TORCH_LIBRARY(_c10d_functional, m) {
  m.def(
      "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_out),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_coalesced),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::reduce_scatter_tensor_coalesced),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_to_all_single("
      "Tensor input, "
      "SymInt[] output_split_sizes, "
      "SymInt[] input_split_sizes, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "broadcast(Tensor input, int src, str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "wait_tensor(Tensor tensor) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor),
      {at::Tag::pt2_compliant_tag});
}

namespace {
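// Autograd wrapper for all_to_all_single. The backward pass is another
// all_to_all_single with the input/output split sizes swapped, followed by an
// explicit wait_tensor to avoid CUDA stream issues (see the TODO below).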
class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::vector<int64_t> output_split_sizes,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::vector<int64_t> input_split_sizes,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::string group_name) {
    // swap sizes for backwards pass
    ctx->saved_data["output_split_sizes"] = input_split_sizes;
    ctx->saved_data["input_split_sizes"] = output_split_sizes;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
        .typed<decltype(all_to_all_single)>()
        .call(input, output_split_sizes, input_split_sizes, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_out_list) {
    const std::vector<int64_t>& output_split_sizes =
        ctx->saved_data["output_split_sizes"].toIntVector();
    const std::vector<int64_t>& input_split_sizes =
        ctx->saved_data["input_split_sizes"].toIntVector();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    auto grad_out = grad_out_list[0].contiguous();

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
            .typed<decltype(all_to_all_single)>()
            .call(grad_out, output_split_sizes, input_split_sizes, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {out, at::Tensor(), at::Tensor(), at::Tensor()};
  }
};

at::Tensor all_to_all_single_autograd(
    const at::Tensor& input,
    const std::vector<int64_t>& output_split_sizes,
    const std::vector<int64_t>& input_split_sizes,
    const std::string& group_name) {
  return AllToAllSingle::apply(
      input, output_split_sizes, input_split_sizes, group_name);
}

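// Autograd wrapper for reduce_scatter_tensor. Only the "sum" reduce op is
// supported; its gradient is an all_gather_into_tensor of the incoming
// gradient, followed by an explicit wait_tensor.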
class ReduceScatterTensor
    : public torch::autograd::Function<ReduceScatterTensor> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      const std::string& reduce_op,
      int64_t group_size,
      const std::string& group_name) {
    TORCH_CHECK(reduce_op == "sum", "Only sum reduce op is supported");

    ctx->saved_data["group_size"] = group_size;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
        .typed<decltype(reduce_scatter_tensor)>()
        .call(input, reduce_op, group_size, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_out_list) {
    const int64_t group_size = ctx->saved_data["group_size"].toInt();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    auto grad_out = grad_out_list[0];

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
            .typed<decltype(all_gather_into_tensor)>()
            .call(grad_out, group_size, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {
        out,
        at::Tensor(),
        at::Tensor(),
        at::Tensor(),
    };
  }
};

at::Tensor reduce_scatter_tensor_autograd(
    const at::Tensor& input,
    const std::string& reduce_op,
    int64_t group_size,
    const std::string& group_name) {
  return ReduceScatterTensor::apply(input, reduce_op, group_size, group_name);
}

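// Autograd wrapper for all_gather_into_tensor. Its gradient is a
// reduce_scatter_tensor of the incoming gradient with the "sum" reduce op,
// followed by an explicit wait_tensor.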
class AllGatherIntoTensor
    : public torch::autograd::Function<AllGatherIntoTensor> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      int64_t group_size,
      const std::string& group_name) {
    ctx->saved_data["group_size"] = group_size;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
        .typed<decltype(all_gather_into_tensor)>()
        .call(input, group_size, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_out_list) {
    const int64_t group_size = ctx->saved_data["group_size"].toInt();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    auto grad_out = grad_out_list[0];

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
            .typed<decltype(reduce_scatter_tensor)>()
            .call(grad_out, "sum", group_size, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(wait_tensor)>()
              .call(out);

    return {
        out,
        at::Tensor(),
        at::Tensor(),
    };
  }
};

at::Tensor all_gather_into_tensor_autograd(
    const at::Tensor& input,
    int64_t group_size,
    const std::string& group_name) {
  return AllGatherIntoTensor::apply(input, group_size, group_name);
}

} // namespace

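// Registers the autograd-enabled variants under
// torch.ops._c10d_functional_autograd, dispatched on the Autograd key so that
// gradients flow through the wrappers defined above.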
TORCH_LIBRARY(_c10d_functional_autograd, m) {
  m.def(
      "all_to_all_single("
      "Tensor input, "
      "SymInt[] output_split_sizes, "
      "SymInt[] input_split_sizes, "
      "str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::Autograd, ::all_to_all_single_autograd),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "reduce_scatter_tensor("
      "Tensor input, "
      "str reduce_op, "
      "int group_size, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::Autograd, ::reduce_scatter_tensor_autograd),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_gather_into_tensor("
      "Tensor input, "
      "int group_size, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::Autograd, ::all_gather_into_tensor_autograd),
      {at::Tag::pt2_compliant_tag});
}

namespace {
// DTensor-related comm operations, sharing code with the functional
// collectives for now.
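// shard_dim_alltoall splits the local tensor along `shard_dim` into group_size
// chunks, exchanges them via alltoall, and concatenates the received chunks
// along `gather_dim` (i.e. it gathers along gather_dim while scattering along
// shard_dim). The op is synchronous (see the TODO below).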
at::Tensor shard_dim_alltoall(
    const at::Tensor& input,
    int64_t gather_dim,
    int64_t shard_dim,
    const std::string& group_name) {
  auto group = c10d::resolve_process_group(group_name);
  auto group_size = group->getSize();
  std::vector<int64_t> output_sizes = input.sizes().vec();
  if (output_sizes[shard_dim] % group_size != 0) {
    LOG(WARNING) << "The shard dimension of the shard_dim_alltoall input ("
                 << output_sizes[shard_dim]
                 << ") is not divisible by the group size (" << group_size
                 << ").";
  }
  output_sizes[shard_dim] = output_sizes[shard_dim] / group_size;
  std::vector<at::Tensor> inputs;
  inputs.reserve(group_size);
  auto length = output_sizes[shard_dim];
  for (int i = 0; i < group_size; i++) {
    inputs.push_back(input.narrow(shard_dim, i * length, length).contiguous());
  }
  // allocate outputs
  std::vector<at::Tensor> outputs;
  outputs.reserve(group_size);
  for (int i = 0; i < group_size; i++) {
    outputs.push_back(input.new_empty(output_sizes).contiguous());
  }
  auto work = group->alltoall(outputs, inputs);

  // TODO: it's very tricky to get the current async behavior to work for shard
  // dim alltoall, so for now we keep this comm op synchronous. We can revisit
  // later how to support the async case with the Work registry.
  work->wait();
  return at::cat(outputs, gather_dim);
}
} // namespace

// DTensor comm op registry
TORCH_LIBRARY(_dtensor, m) {
  m.def(
      "shard_dim_alltoall(Tensor input, int gather_dim, int shard_dim, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::shard_dim_alltoall),
      {at::Tag::pt2_compliant_tag});
}