xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/parallel/data_parallel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/cuda.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <ATen/Device.h>
#include <ATen/Parallel.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <vector>

namespace torch {
namespace nn {

namespace {

// Note [Replicating Modules]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// Module replication is implemented in the following two steps:
// 1) create a module replica on each destination device using Module.clone().
// 2) manually add a gradient edge pointing from every parameter X in every
//    module replica to the same parameter X in the original module, using
//    ReduceAdd as the grad_fn.
//
// ReduceAdd can ONLY be used during the backward pass of data parallel. The
// forward pass cannot use this function, as it does not set up the gradient
// function and history at all. Do NOT try to use ReduceAdd for any other
// purpose.
//
// NB: An alternative is to add Broadcast and ReduceAddCoalesce to
// torch/csrc/autograd/functions/comm.cpp as normal autograd functions,
// implement a Replicatable (like Cloneable) class, and add it as a friend
// class in Module.h. In the forward pass, the Replicatable could use the
// Broadcast function to replicate every module parameter and set gradient
// functions using ReduceAddCoalesce (as is done in Python). However, unlike in
// Python, where changes to Linear._parameters["weight"] also apply to
// Linear.weight (using Linear as an example), in C++ Linear.weight and
// Linear.parameters_["weight"] are two tensor objects pointing to the same
// TensorImpl. Assigning a new tensor to Linear.parameters_["weight"] will not
// change Linear.weight. To make this work, we would have to:
// 1) force every module to also inherit from Replicatable, and
// 2) force every module to implement an additional function, e.g.,
//    Replicatable::load_params(), to pick up changes from parameters_ to their
//    own member fields.
// This would be overkill, as Replicatable would only be used in data_parallel,
// not even in DDP.

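// As a rough sketch of step (2) above: for a module with a single parameter
// `weight` replicated onto two devices, the wiring amounts to the following
// (illustrative only; `original`, `replica0`, and `replica1` are placeholder
// names):
//
//   auto grad_fn = std::make_shared<ReduceAdd>(original->weight.device());
//   grad_fn->set_next_edges(autograd::collect_next_edges(original->weight));
//   autograd::set_history(replica0->weight, grad_fn);
//   autograd::set_history(replica1->weight, grad_fn);
//
// During the backward pass, the gradients of both replica parameters feed into
// the same ReduceAdd node, which sums them on the original parameter's device
// and passes the result on to the original parameter's gradient history.
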
// Autograd function for the replicate step in data parallel. This is only used
// in data parallel, and should not be exposed as a user API.
struct ReduceAdd : public autograd::Node {
  explicit ReduceAdd(const at::Device& destination_device)
      : destination_device_(destination_device){};
  ~ReduceAdd() override {}

  autograd::variable_list apply(autograd::variable_list&& inputs) override {
    TORCH_CHECK(
        !torch::autograd::compute_requires_grad(inputs),
        "ReduceAdd can only be used during the backward pass of data parallel.");

    Tensor output = torch::zeros_like(inputs[0], {destination_device_});

    for (auto& input : inputs) {
      TORCH_CHECK(
          input.sizes() == inputs[0].sizes(),
          "All inputs of ReduceAdd must have the same size, but got ",
          input.sizes(),
          " and ",
          inputs[0].sizes());

      TORCH_CHECK(
          input.dtype() == inputs[0].dtype(),
          "All inputs of ReduceAdd must have the same dtype, but got ",
          input.dtype(),
          " and ",
          inputs[0].dtype());

      // TODO: use nccl reduce
      output.add_(input.to(destination_device_));
    }

    return {output};
  }

 private:
  at::Device destination_device_;
};

} // namespace

// A friend function of Module; it recursively sets gradient edges pointing
// from every parameter X in every module replica to the same parameter X in
// the original module. See [Replicating Modules].
template <typename ModuleType>
void replicate_grad_edges(
    const std::shared_ptr<Module>& module,
    const std::vector<std::shared_ptr<ModuleType>>& replicas,
    const std::vector<Device>& devices) {
  for (auto& parameter : module->named_parameters(/*recurse=*/false)) {
    auto grad_fn = std::make_shared<ReduceAdd>((*parameter).device());
    grad_fn->set_next_edges(autograd::collect_next_edges(*parameter));

    for (const auto i : c10::irange(devices.size())) {
      autograd::set_history(replicas[i]->parameters_[parameter.key()], grad_fn);
    }
  }

  for (auto& buffer : module->named_buffers(/*recurse=*/false)) {
    if (buffer.value().requires_grad()) {
      auto grad_fn = std::make_shared<ReduceAdd>((*buffer).device());
      grad_fn->set_next_edges(autograd::collect_next_edges(*buffer));

      for (const auto i : c10::irange(devices.size())) {
        autograd::set_history(replicas[i]->buffers_[buffer.key()], grad_fn);
      }
    }
  }

  for (auto& child : module->children_) {
    std::vector<std::shared_ptr<Module>> child_replicas;
    child_replicas.reserve(devices.size());
    for (auto& replica : replicas) {
      child_replicas.push_back(replica->children_[child.key()]);
    }

    // recursively set gradient edges for all children
    replicate_grad_edges(*child, child_replicas, devices);
  }
}

namespace parallel {

/// Replicates a module on the given list of devices.
/// A replica is created by calling `clone()` on the module. For this, the
/// module must inherit from `nn::Cloneable`, or define its own `clone()`
/// method, which is expected to perform a deep copy of the module.
template <typename ModuleType>
std::vector<std::shared_ptr<ModuleType>> replicate(
    const std::shared_ptr<ModuleType>& module,
    const std::vector<Device>& devices) {
  std::vector<std::shared_ptr<ModuleType>> replicas;
  replicas.reserve(devices.size());
  for (const auto& device : devices) {
    replicas.push_back(
        std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
  }
  // Configure gradient edges to point from replica parameters to original
  // module parameters. See [Replicating Modules]
  replicate_grad_edges(module, replicas, devices);
  return replicas;
}

/// Replicates a module holder on the given list of devices.
/// This method allows calling `replicate()` with a module holder, such as
/// `Linear`.
template <typename ModuleType>
std::vector<ModuleHolder<ModuleType>> replicate(
    const ModuleHolder<ModuleType>& module,
    const std::vector<Device>& devices) {
  auto ptrs = replicate(module.ptr(), devices);
  return std::vector<ModuleHolder<ModuleType>>(ptrs.begin(), ptrs.end());
}

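// A minimal usage sketch (assuming two CUDA devices are available; `Linear`
// is used purely for illustration):
//
//   torch::nn::Linear linear(3, 4);
//   std::vector<torch::Device> devices = {
//       torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};
//   auto replicas = torch::nn::parallel::replicate(linear, devices);
//   // replicas[i] is a deep copy of `linear` placed on devices[i]; gradients
//   // computed through a replica flow back to `linear`'s own parameters.
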
/// Applies the given inputs to the given modules in a parallel fashion.
/// Conceptually, a thread is spawned for each `(module, input)` pair, in which
/// `forward()` is called on the module with its corresponding input. The
/// outputs of the individual calls are stored in a vector and returned.
///
/// The first exception caught by any thread is stashed and rethrown after all
/// threads have completed their operation.
///
/// Further remarks:
/// 1. The length of the module container must match the length of the inputs.
/// 2. If a list of devices is supplied, it must match the list of modules in
/// length. Each device will be set to the current default device during the
/// invocation of the respective module. This means any tensors allocated on the
/// default device inside the module will be constructed on this device.
template <typename ModuleType>
std::vector<Tensor> parallel_apply(
    std::vector<ModuleType>& modules,
    const std::vector<Tensor>& inputs,
    const std::optional<std::vector<Device>>& devices = std::nullopt) {
  TORCH_CHECK(
      modules.size() == inputs.size(), "Must have as many inputs as modules");
  if (devices) {
    TORCH_CHECK(
        modules.size() == devices->size(),
        "Must have as many devices as modules");
  }

  std::vector<Tensor> outputs(modules.size());
  std::mutex mutex;

  // std::exception_ptr can be passed between threads:
  // > An instance of std::exception_ptr may be passed to another function,
  // > possibly on another thread, where the exception may be rethrown [...].
  // https://en.cppreference.com/w/cpp/error/exception_ptr
  std::exception_ptr exception;

  at::parallel_for(
      /*begin=*/0,
      /*end=*/modules.size(),
      /*grain_size=*/1,
      [&modules, &inputs, &devices, &outputs, &mutex, &exception](
          int64_t index, int64_t stop) {
        for (; index < stop; ++index) {
          try {
            auto output = modules[index]->forward(inputs[index]);
            output =
                output.to(devices ? (*devices)[index] : inputs[index].device());
            std::lock_guard<std::mutex> lock(mutex);
            outputs[index] = output;
          } catch (...) {
            std::lock_guard<std::mutex> lock(mutex);
            if (!exception) {
              exception = std::current_exception();
            }
          }
        }
      });

  if (exception) {
    std::rethrow_exception(exception);
  }

  return outputs;
}

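// A minimal usage sketch (assuming `replicas` and `devices` from the
// `replicate()` example above; the input shapes are placeholders):
//
//   std::vector<torch::Tensor> inputs = {
//       torch::randn({8, 3}, torch::Device(torch::kCUDA, 0)),
//       torch::randn({8, 3}, torch::Device(torch::kCUDA, 1))};
//   auto outputs =
//       torch::nn::parallel::parallel_apply(replicas, inputs, devices);
//   // outputs[i] holds replicas[i]->forward(inputs[i]), moved to devices[i].
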
/// Evaluates `module(input)` in parallel across the given `devices`. If
/// `devices` is not supplied, the invocation is parallelized across all
/// available CUDA devices. If `output_device` is supplied, the final, combined
/// tensor will be placed on this device. If not, it defaults to the first
/// device in `devices`.
///
/// In detail, this method performs the following four distinct steps:
/// 1. *Scatter* the input to the given devices,
/// 2. *Replicate* (deep clone) the model on each device,
/// 3. *Evaluate* each module with its input on its device,
/// 4. *Gather* the outputs of each replica into a single output tensor, located
/// on the `output_device`.
template <typename ModuleType>
Tensor data_parallel(
    ModuleType module,
    Tensor input,
    std::optional<std::vector<Device>> devices = std::nullopt,
    std::optional<Device> output_device = std::nullopt,
    int64_t dim = 0) {
  if (!devices) {
    const auto device_count = torch::cuda::device_count();
    TORCH_CHECK(
        device_count > 0, "Expected at least one CUDA device to be available");
    devices = std::vector<Device>();
    devices->reserve(device_count);
    for (const auto index : c10::irange(device_count)) {
      devices->emplace_back(kCUDA, static_cast<torch::DeviceIndex>(index));
    }
  }
  if (!output_device) {
    output_device = devices->front();
  }

  if (devices->size() == 1) {
    module->to(devices->front());
    input = input.to(devices->front());
    return module->forward(std::move(input)).to(*output_device);
  }

  autograd::Scatter scatter(*devices, /*chunk_sizes=*/nullopt, dim);
  auto scattered_inputs = fmap<Tensor>(scatter.apply({std::move(input)}));
  // Input tensor might not be big enough to scale across all available devices
  if (scattered_inputs.size() < devices->size()) {
    devices->resize(
        scattered_inputs.size(),
        Device(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES));
  }

  auto replicas = replicate(module, *devices);
  auto outputs = parallel_apply(replicas, scattered_inputs, *devices);
  return autograd::Gather(*output_device, dim)
      .apply(fmap<autograd::Variable>(std::move(outputs)))
      .front();
}

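// A minimal end-to-end sketch (assuming at least one CUDA device is visible;
// `Linear` and the shapes are placeholders):
//
//   torch::nn::Linear model(3, 4);
//   auto input = torch::randn({16, 3});
//   auto output = torch::nn::parallel::data_parallel(model, input);
//   // The batch is scattered along dim 0, the model is replicated onto every
//   // available CUDA device, each replica runs forward() on its chunk, and
//   // the per-device outputs are gathered on the first device.
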
} // namespace parallel
} // namespace nn
} // namespace torch