// xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/comm.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)

#include <torch/csrc/cuda/comm.h>

#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/utils/tensor_flatten.h>

#ifdef USE_NCCL
#include <torch/csrc/cuda/nccl.h>
#endif

#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>
#include <optional>

#include <cstddef>
#include <vector>

namespace torch::cuda {
using namespace at;
using namespace torch::autograd;

// Some operations can be performed more efficiently if we're handling tensors
// of a single type only. Adding this logic directly in the loop makes it a bit
// ugly, so here's a helper for it.
struct unique_type_checker {
  void show(size_t type_id) {
    if (!unique) {
      return;
    }
    if (!type_id_) {
      type_id_ = type_id;
    }

    unique = type_id_.value() == type_id;
  }

  std::optional<size_t> type_id_;
  bool unique = true;
};
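
// A minimal sketch of how the helper above behaves (illustrative only; the
// real call sites are the coalesced loops below, where show() is fed the
// type_id() of each chunk produced by torch::utils::take_tensors):
//
//   unique_type_checker checker;
//   checker.show(/*type_id=*/0);  // first type seen; unique stays true
//   checker.show(/*type_id=*/0);  // same type again; unique stays true
//   checker.show(/*type_id=*/1);  // a second type appears; unique becomes false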

// ***************** Broadcast *******************
//
// Broadcast a source tensor (CPU or CUDA) to a list of CUDA devices, or CUDA
// tensors on one or more devices.

// no checks
static inline std::vector<Tensor>& _broadcast_out_impl(
    const Tensor& tensor,
    std::vector<Tensor>& out_tensors) {
#ifdef USE_NCCL
  std::vector<Tensor> nccl_list;
  nccl_list.reserve(out_tensors.size() + 1);
  nccl_list.emplace_back(tensor);
  for (auto& out_tensor : out_tensors) {
    nccl_list.emplace_back(out_tensor);
  }
  if (nccl::is_available(nccl_list)) {
    nccl::broadcast(nccl_list);
  } else {
#else
  {
#endif
    for (auto& out_tensor : out_tensors) {
      out_tensor.copy_(tensor, /*non_blocking=*/true);
    }
  }
  return out_tensors;
}

std::vector<Tensor>& broadcast_out(
    const Tensor& tensor,
    std::vector<Tensor>& out_tensors) {
  for (const auto i : c10::irange(out_tensors.size())) {
    TORCH_CHECK(
        out_tensors[i].is_cuda(),
        "Expected all output tensors to be CUDA tensors, but output tensor at index ",
        i,
        " has device '",
        out_tensors[i].device(),
        "'");
    TORCH_CHECK(
        out_tensors[i].sizes() == tensor.sizes(),
        "Expected all output tensors to have same shape as the source tensor ",
        tensor.sizes(),
        ", but output tensor at index ",
        i,
        " has shape ",
        out_tensors[i].sizes());
  }
  return _broadcast_out_impl(tensor, out_tensors);
}

std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<Tensor> diff_device_dst_tensors;
  diff_device_dst_tensors.reserve(devices.size());
  for (auto device : devices) {
    TORCH_CHECK(
        device >= 0, "Expected non-negative device index, but got ", device);
    if (device != tensor.get_device()) {
      diff_device_dst_tensors.emplace_back(at::empty(
          tensor.sizes(),
          tensor.options().device(at::Device(
              DeviceType::CUDA,
              static_cast<DeviceIndex>(device))))); // preserve memory format
    }
  }
  _broadcast_out_impl(tensor, diff_device_dst_tensors);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<Tensor> dst_tensors;
  dst_tensors.reserve(devices.size());
  auto it = diff_device_dst_tensors.begin();
  for (auto device : devices) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (device != tensor.get_device()) {
      dst_tensors.emplace_back(*it++);
    } else {
      dst_tensors.emplace_back(tensor);
    }
  }
  TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end());
  return dst_tensors;
}
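
// A hedged usage sketch for the two functions above (assumes at least two
// visible CUDA devices; values and shapes are arbitrary):
//
//   auto src = at::ones({2, 3}, at::TensorOptions().device(at::kCUDA, 0));
//   // Copies of `src` on cuda:0 and cuda:1; the entry for the source device
//   // is the source tensor itself, not a copy.
//   auto copies = torch::cuda::broadcast(src, /*devices=*/{0, 1});
//
//   // Or broadcast into preallocated CUDA outputs of matching shape:
//   std::vector<at::Tensor> outs = {
//       at::empty_like(src, src.options().device(at::Device(at::kCUDA, 1)))};
//   torch::cuda::broadcast_out(src, outs);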

// NOTE [ Version Counter in comm.*_coalesced ]
//
// broadcast_coalesced
// ~~~~~~~~~~~~~~~~~~~
//
// In broadcast_coalesced, multiple variables may be coalesced into a single
// large one, broadcast to other devices, and then split according to the
// original shapes.
//
// When splitting, the view operations will make all Variables broadcast
// together share a single version counter, because they are all views of the
// large Variable. However, that large Variable is immediately discarded and
// these Variables do not share storage at all.
//
// For example, when two buffers are broadcast together in `DataParallel` and
// one of them is modified in-place during `forward` but the other is needed in
// backward, the autograd engine will complain.
//
// We thus re-wrap these Variables after broadcasting (i.e., effectively doing
// what is equivalent to .data in Python), and give them individual version
// counters.
//
// NB: Just calling detach() on the variables is not sufficient.
//
// NB: For `device[0]` in broadcast_coalesced, the input Variables are always
//     returned as-is, so **do not** re-wrap them.
//
// reduce_add_coalesced
// ~~~~~~~~~~~~~~~~~~~~
//
// Similarly for reduce_add_coalesced, when the outputs are newly created
// Variables.
tensor_list2d broadcast_coalesced(
    TensorList tensors,
    IntArrayRef devices,
    size_t buffer_size) {
  TORCH_CHECK(
      std::all_of(
          tensors.begin(),
          tensors.end(),
          [&](const at::Tensor& t) { return t.get_device() == devices[0]; }),
      "All tensors must be on devices[0]: ",
      devices[0]);
#ifdef USE_NCCL
  buffer_size = std::min(torch::cuda::nccl::get_max_count(), buffer_size);
#endif

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  tensor_list2d outputs(devices.size());
  outputs[0] = tensors.vec();
  for (auto& o : outputs)
    o.reserve(tensors.size());

  unique_type_checker type_checker;
  at::cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(devices[0]));
  for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
    auto type_id = chunk.type_id();
    type_checker.show(type_id);
    std::vector<at::Tensor> results;
    if (chunk.options().is_sparse()) {
      auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors);
      auto broadcast_indices = broadcast(flat_tuple.first, devices);
      auto broadcast_values = broadcast(flat_tuple.second, devices);
      results.reserve(devices.size());
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(static_cast<DeviceIndex>(devices[i]));
        auto& device_outputs = outputs[i];
        auto& inds = broadcast_indices[i];
        auto& vals = broadcast_values[i];
        for (const auto& var : torch::utils::unflatten_sparse_tensors(
                 inds, vals, chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.emplace_back(make_variable(var.tensor_data(), false));
        }
      }
    } else {
      auto results = broadcast(
          torch::utils::flatten_dense_tensors(chunk.tensors), devices);
      for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
        device_guard.set_index(static_cast<DeviceIndex>(devices[i]));
        auto& device_outputs = outputs[i];
        for (auto& var :
             torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
          // See NOTE [ Version Counter in comm.*_coalesced ]
          device_outputs.emplace_back(make_variable(var.tensor_data(), false));
        }
      }
    }
  }

  // If we only saw a single tensor type, then we can skip expensive reordering
  if (!type_checker.unique) {
    for (auto& o : outputs)
      torch::utils::reorder_tensors_like(o, tensors);
  }
  return outputs;
}
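
// A hedged usage sketch for the coalesced path (assumes two visible CUDA
// devices; the 256 MB buffer size is an arbitrary illustrative value):
//
//   std::vector<at::Tensor> params = {
//       at::randn({1024}, at::TensorOptions().device(at::kCUDA, 0)),
//       at::randn({1024}, at::TensorOptions().device(at::kCUDA, 0))};
//   auto per_device = torch::cuda::broadcast_coalesced(
//       params, /*devices=*/{0, 1}, /*buffer_size=*/256 * 1024 * 1024);
//   // per_device[0] holds the inputs as-is; per_device[1] holds re-wrapped
//   // copies on cuda:1 with individual version counters (see NOTE above).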

// ***************** Scatter *******************
//
// Scatter a source tensor (CPU or CUDA) to a list of CUDA tensors on one or
// more devices.

std::vector<at::Tensor>& scatter_out(
    const at::Tensor& tensor,
    std::vector<at::Tensor>& out_tensors,
    int64_t dim,
    const std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(
      !out_tensors.empty(),
      "Expected at least one output tensor to scatter to");
  dim = at::maybe_wrap_dim(dim, tensor);
  int64_t total_size = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> chunk_sizes;
  chunk_sizes.reserve(out_tensors.size());
  for (const auto i : c10::irange(out_tensors.size())) {
    TORCH_CHECK(
        out_tensors[i].is_cuda(),
        "Expected all output tensors to be CUDA tensors, but output tensor at index ",
        i,
        " has device '",
        out_tensors[i].device(),
        "'");
    auto out_sizes = out_tensors[i].sizes().vec();
    bool same_ndim = out_sizes.size() == static_cast<size_t>(tensor.dim());
    if (same_ndim) {
      total_size += out_sizes[dim];
      chunk_sizes.emplace_back(out_sizes[dim]);
      out_sizes[dim] = tensor.size(dim);
    }
    TORCH_CHECK(
        same_ndim && out_sizes == tensor.sizes(),
        "Output tensor at index ",
        i,
        " has incorrect shape: ",
        out_tensors[i].sizes(),
        ". Expected same "
        "shape except for scatter dim ",
        dim,
        " as the source tensor: ",
        at::IntArrayRef(tensor.sizes()));
  }
  TORCH_CHECK(
      total_size == tensor.size(dim),
      "Total size for output tensors along scatter dim ",
      dim,
      " does not match "
      "the source tensor size at dim ",
      dim,
      ". Expected ",
      tensor.size(dim),
      ", but got total size ",
      total_size);

  auto chunks =
      tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
      const auto device_index =
          static_cast<int16_t>(out_tensors[i].get_device());
      TORCH_CHECK(
          (*streams)[i]->device_index() == device_index,
          "Expected the device associated with the stream at index ",
          i,
          " (was ",
          (*streams)[i]->device_index(),
          ") ",
          "to match the device supplied at that index ",
          "(expected ",
          device_index,
          ")");
      cuda_guard.reset_stream(*(*streams)[i]);
    }
    // NB: We don't detect the case where `out_tensor` is already the correct
    //     view of `tensor` since that would be nontrivial and involve checking
    //     ptr, offset, and strides. So `scatter_out(src, src.chunk(...))` does
    //     more copying than `scatter(src)`.
    out_tensors[i].copy_(chunks[i], /*non_blocking=*/true);
  }
  return out_tensors;
}
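
// A hedged usage sketch for scatter_out (assumes two visible CUDA devices and
// preallocated outputs that partition the source along dim 0):
//
//   auto src = at::arange(8, at::TensorOptions().device(at::kCUDA, 0))
//                  .reshape({4, 2});
//   std::vector<at::Tensor> outs = {
//       at::empty({2, 2}, src.options()),
//       at::empty({2, 2}, src.options().device(at::Device(at::kCUDA, 1)))};
//   torch::cuda::scatter_out(src, outs, /*dim=*/0, /*streams=*/std::nullopt);
//   // outs[0] receives rows 0-1, outs[1] receives rows 2-3.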

std::vector<at::Tensor> scatter(
    const at::Tensor& tensor,
    at::IntArrayRef devices,
    const std::optional<std::vector<int64_t>>& chunk_sizes,
    int64_t dim,
    const std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>&
        streams) {
  TORCH_CHECK(!devices.empty(), "Expected at least one device to scatter to");
  if (chunk_sizes.has_value()) {
    TORCH_CHECK(
        chunk_sizes->size() == devices.size(),
        "Expected devices and chunk_sizes to be of same length, but got "
        "len(devices) = ",
        devices.size(),
        " and len(chunk_sizes) = ",
        chunk_sizes->size());
  }
  dim = at::maybe_wrap_dim(dim, tensor);
  std::vector<at::Tensor> chunks = chunk_sizes
      ? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim)
      : tensor.chunk(
            /*chunks=*/static_cast<int64_t>(devices.size()), /*dim=*/dim);
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
    const auto device_index = static_cast<int16_t>(devices[i]);
    if (device_index != tensor.get_device()) {
      if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
        TORCH_CHECK(
            (*streams)[i]->device_index() == device_index,
            "Expected the device associated with the stream at index ",
            i,
            " (was ",
            (*streams)[i]->device_index(),
            ") ",
            "to match the device supplied at that index ",
            "(expected ",
            device_index,
            ")");
        cuda_guard.reset_stream(*(*streams)[i]);
      }
      TORCH_CHECK(
          device_index >= 0,
          "Expected non-negative device index, but got ",
          device_index);
      chunks[i] = chunks[i].to(
          {DeviceType::CUDA, device_index},
          /*non_blocking=*/true,
          /*copy=*/false,
          /*memory_format=*/at::MemoryFormat::Preserve);
    }
  }
  return chunks;
}
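
// A hedged usage sketch for scatter (assumes two visible CUDA devices; the
// uneven chunk_sizes are an arbitrary illustrative choice):
//
//   auto src = at::randn({6, 4}, at::TensorOptions().device(at::kCUDA, 0));
//   auto parts = torch::cuda::scatter(
//       src,
//       /*devices=*/{0, 1},
//       /*chunk_sizes=*/std::vector<int64_t>{4, 2},
//       /*dim=*/0,
//       /*streams=*/std::nullopt);
//   // parts[0] (4x4) stays on cuda:0; parts[1] (2x4) is copied to cuda:1.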

// ***************** Gather *******************
//
// Gather a list of CUDA tensors on one or more devices to a target tensor or
// device, either CPU or CUDA.

// no checks
static inline at::Tensor& _gather_out_impl(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> chunk_sizes;
  chunk_sizes.reserve(tensors.size());
  for (auto& tensor : tensors) {
    chunk_sizes.emplace_back(tensor.size(dim));
  }
  auto chunks =
      out_tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);
  for (const auto i : c10::irange(tensors.size())) {
    chunks[i].copy_(tensors[i], /*non_blocking=*/out_tensor.is_cuda());
  }
  return out_tensor;
}

at::Tensor& gather_out(
    at::TensorList tensors,
    at::Tensor& out_tensor,
    int64_t dim) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device '",
        tensor.device(),
        "'");
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions (expected ",
        expected_size.size(),
        ")");
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
  }
  expected_size[dim] = total_size;
  TORCH_CHECK(
      out_tensor.sizes() == expected_size,
      "Expected out tensor to have shape ",
      at::IntArrayRef(expected_size),
      ", but got ",
      out_tensor.sizes());

  return _gather_out_impl(tensors, out_tensor, dim);
}
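
// A hedged usage sketch for gather_out (assumes two visible CUDA devices and
// an output preallocated with the concatenated shape along dim 0):
//
//   std::vector<at::Tensor> parts = {
//       at::ones({2, 3}, at::TensorOptions().device(at::kCUDA, 0)),
//       at::ones({2, 3}, at::TensorOptions().device(at::kCUDA, 1))};
//   auto out = at::empty({4, 3}, parts[0].options());
//   torch::cuda::gather_out(parts, out, /*dim=*/0);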

at::Tensor gather(
    at::TensorList tensors,
    int64_t dim,
    std::optional<int32_t> destination_index) {
  TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  dim = at::maybe_wrap_dim(dim, first);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  auto memory_format = first.suggest_memory_format();
  for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors[i];
    TORCH_CHECK(
        tensor.is_cuda(),
        "Expected all input tensors to be CUDA tensors, but "
        "tensor at index ",
        i,
        " has device ",
        tensor.device());
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
        "Expected all input tensors to have the same number of dimensions, but ",
        "tensor at index ",
        i,
        " has ",
        tensor.ndimension(),
        " dimensions (expected ",
        expected_size.size(),
        ")");
    expected_size[dim] = tensor.size(dim);
    for (const auto dimension : c10::irange(expected_size.size())) {
      TORCH_CHECK(
          expected_size[dimension] == tensor.size(dimension),
          "Input tensor at index ",
          i,
          " has invalid shape ",
          tensor.sizes(),
          ", but expected ",
          at::IntArrayRef(expected_size));
    }
    total_size += tensor.size(dim);
    if (memory_format != MemoryFormat::Contiguous &&
        tensor.suggest_memory_format() != memory_format) {
      memory_format = MemoryFormat::Contiguous;
    }
  }
  expected_size[dim] = total_size;
  at::Device device(DeviceType::CPU);
  if (!destination_index || *destination_index != -1) {
    device = at::Device(
        DeviceType::CUDA,
        destination_index ? static_cast<DeviceIndex>(*destination_index)
                          : DeviceIndex(-1));
  }

  at::Tensor result =
      at::empty(expected_size, first.options().device(device), memory_format);
  return _gather_out_impl(tensors, result, dim);
}
} // namespace torch::cuda