// torch/csrc/autograd/functions/comm.cpp
#include <torch/csrc/autograd/functions/comm.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <memory>
#include <vector>

namespace torch::autograd {
Scatter::Scatter(
    std::vector<at::Device> devices,
    std::optional<std::vector<int64_t>> chunk_sizes,
    int64_t dim,
    std::optional<std::vector<std::optional<at::cuda::CUDAStream>>> streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(std::move(chunk_sizes)),
      dim_(dim),
      streams_(std::move(streams)),
      unsqueeze_scalars_(unsqueeze_scalars) {}

Scatter::~Scatter() = default;

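// Forward: splits a single input tensor across `devices_` along `dim_`.
// Its backward is a Gather that collects the per-device gradients back
// onto the input's device.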
variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& input = inputs.front();

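  // If the input requires grad, record a Gather node so the backward pass
  // concatenates the per-device gradients back onto the input's device.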
  std::shared_ptr<Node> grad_fn;
  if (compute_requires_grad(input)) {
    grad_fn =
        std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
    grad_fn->set_next_edges(collect_next_edges(input));
  }

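  // Do the actual scatter: chunk `input` along `dim_` (honoring
  // `chunk_sizes_` if provided) and copy each chunk to its target device,
  // optionally on the caller-supplied CUDA streams.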
  auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
    return device.index();
  });
  auto tensors =
      torch::cuda::scatter(input, device_indices, chunk_sizes_, dim_, streams_);

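  // Wrap the chunks as Variables. When this Scatter is the backward of a
  // Gather that unsqueezed scalar inputs, squeeze each one-element chunk
  // back down to a 0-dim tensor.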
  std::vector<Variable> variables;
  variables.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (unsqueeze_scalars_) {
      AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
      variables.push_back(tensor[0]);
    } else {
      variables.push_back(std::move(tensor));
    }
  }

  if (grad_fn) {
    set_history(variables, grad_fn);
  }

  return variables;
}

Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}

Gather::~Gather() = default;

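// Forward: concatenates one CUDA tensor per device along `dim_` onto
// `destination_device_`. Its backward is a Scatter back to the source
// devices.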
variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    TORCH_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.toString());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    TORCH_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  std::shared_ptr<Node> grad_fn;
  // compute this before moving variables from `inputs`
  if (compute_requires_grad(inputs)) {
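    // Record each input's device and size along `dim_` so the backward
    // Scatter can split the gradient into chunks that match the inputs.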
    std::vector<at::Device> source_devices;
    source_devices.reserve(inputs.size());
    std::vector<int64_t> input_sizes;
    input_sizes.reserve(inputs.size());
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/std::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }

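  // Collect the inputs as plain tensors, viewing 0-dim scalars as
  // one-element vectors when gathering scalars.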
  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // Disable autograd during the actual computation.
  // torch::cuda::gather does not return a view or change anything in place,
  // so no extra view/in-place handling is needed here.
  at::Tensor variable;
  {
    at::AutoDispatchBelowAutograd mode;
    // Special calling convention for torch::cuda::gather: a destination
    // index of -1 means "gather to the CPU".
    const auto destination_index =
        destination_device_.is_cpu() ? -1 : destination_device_.index();
    variable = torch::cuda::gather(tensors, dim_, destination_index);
  }
  if (grad_fn) {
    set_history(variable, grad_fn);
  }
  return {variable};
}


} // namespace torch::autograd
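
// A minimal usage sketch (illustrative, not part of this file): the C++ data
// parallel utilities construct these nodes directly, roughly like this. The
// device indices and input names here are assumptions for the example.
//
//   using torch::autograd::Scatter;
//   using torch::autograd::Gather;
//
//   std::vector<at::Device> devices = {
//       at::Device(at::kCUDA, 0), at::Device(at::kCUDA, 1)};
//   auto scatter = std::make_shared<Scatter>(
//       devices,
//       /*chunk_sizes=*/std::nullopt,
//       /*dim=*/0,
//       /*streams=*/std::nullopt,
//       /*unsqueeze_scalars=*/false);
//   auto chunks = scatter->apply({input}); // one piece per device
//
//   auto gather =
//       std::make_shared<Gather>(at::Device(at::kCUDA, 0), /*dim=*/0);
//   auto output = gather->apply(std::move(chunks)); // concatenated on cuda:0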