#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030

#include <ATen/ATen.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <torch/library.h>

#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>

namespace {

using namespace c10d::symmetric_memory;

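// Returns the largest alignment (in bytes) that holds for both the input's
// byte offset within its storage and its total byte size, and checks that it
// meets the per-dtype minimum (at least 4 bytes and at least element_size).
// The result selects the vectorized access width used by the kernels below.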
size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
  const size_t min_alignment = std::max(4l, input.element_size());
  // Only check the offset since the multicast address is always at least
  // 128-bit aligned
  const size_t ptr_alignment = get_alignment(
      static_cast<size_t>(input.storage_offset() * input.element_size()));
  TORCH_CHECK(
      ptr_alignment >= min_alignment,
      op_name,
      "<",
      input.scalar_type(),
      ">: input ptr + offset must be at least ",
      min_alignment,
      "-byte aligned.");

  const size_t size_alignment =
      get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
  TORCH_CHECK(
      size_alignment >= min_alignment,
      op_name,
      "<",
      input.scalar_type(),
      ">: input size must be at least ",
      min_alignment,
      "-byte aligned.");
  return std::min(ptr_alignment, size_alignment);
}

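// Picks a 1-D launch configuration for an elementwise kernel over `numel`
// elements of `element_size` bytes, where each thread handles
// `alignment / element_size` elements and the work is divided into `splits`
// equally sized, alignment-preserving chunks (e.g. one chunk per rank).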
void init_elementwise_launch_config(
    size_t numel,
    size_t element_size,
    size_t alignment,
    size_t splits,
    int& num_blocks,
    int& num_threads) {
  // Align to preserve alignment in each split
  const size_t aligned_numel = at::round_up(numel, alignment * splits);
  const size_t numel_per_split = aligned_numel / splits;
  const size_t numel_per_thread = alignment / element_size;

  if (numel_per_split <= max_num_threads_per_block * numel_per_thread) {
    num_blocks = 1;
    num_threads = at::round_up(
        at::ceil_div(numel_per_split, numel_per_thread),
        static_cast<size_t>(C10_WARP_SIZE));
  } else {
    num_blocks = std::min(
        at::ceil_div(
            numel_per_split, max_num_threads_per_block * numel_per_thread),
        max_num_blocks);
    num_threads = max_num_threads_per_block;
  }
}

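// In-place multimem all-reduce: each rank reduces its 1/world_size slice of
// the buffer with a multimem load-reduce and broadcasts the result back to
// all ranks with a multimem store, so every element is reduced exactly once
// across the group. The leading barrier acquires other ranks' prior writes;
// the trailing barrier plus system fence orders this kernel's stores before
// subsequent work on all devices.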
template <typename T, int alignment>
static __global__ void multimem_all_reduce_kernel(
    T* input_mc_ptr,
    size_t numel,
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  static_assert(alignment % sizeof(T) == 0);
  constexpr size_t numel_per_thread = alignment / sizeof(T);

  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);

  const size_t numel_per_rank =
      at::round_up(numel, alignment * world_size) / world_size;
  const size_t start = numel_per_rank * rank;

  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
  auto stride = blockDim.x * gridDim.x * numel_per_thread;
  for (size_t i = offset; i < numel_per_rank; i += stride) {
    if (start + i >= numel) {
      continue;
    }
    auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + start + i);
    multimem_st<alignment>(input_mc_ptr + start + i, vec);
  }
  // Establish observation order - all writes are in-flight beyond this point.
  barrier(signal_pads, rank, world_size);
  // Establish causality order - all writes are visible to all devices beyond
  // this point.
  __threadfence_system();
}

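// In-place sum all-reduce over a tensor allocated from symmetric memory with
// multicast support. The work is partitioned across ranks (splits ==
// world_size) and each rank reduces and broadcasts only its own slice.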
at::Tensor multimem_all_reduce_(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  TORCH_CHECK(
      input.is_contiguous(), "multimem_all_reduce_: input must be contiguous.");
  TORCH_CHECK(
      reduce_op == "sum",
      "multimem_all_reduce_: only sum is supported for now.");

  auto symm_mem = c10d::symmetric_memory::rendezvous(input);
  TORCH_CHECK(
      symm_mem != nullptr,
      "multimem_all_reduce_: input must be allocated with empty_strided_p2p().");
  TORCH_CHECK(
      symm_mem->has_multicast_support(),
      "multimem_all_reduce_: multicast support is required.");

  const size_t alignment =
      get_and_verify_alignment(input, "multimem_all_reduce_");

  int num_blocks = 0, num_threads = 0;
  init_elementwise_launch_config(
      input.numel(),
      input.element_size(),
      alignment,
      symm_mem->get_world_size(),
      num_blocks,
      num_threads);

#define DISPATCH(scalar_t, kernel_alignment)                                   \
  if (alignment == kernel_alignment) {                                         \
    multimem_all_reduce_kernel<scalar_t, kernel_alignment>                     \
        <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(    \
            reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) +       \
                input.storage_offset(),                                        \
            input.numel(),                                                     \
            reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
            symm_mem->get_rank(),                                              \
            symm_mem->get_world_size());                                       \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                            \
  }

  AT_DISPATCH_SWITCH(
      input.scalar_type(),
      "multimem_all_reduce",
      AT_DISPATCH_CASE(at::kBFloat16, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }) AT_DISPATCH_CASE(at::kFloat, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }));

#undef DISPATCH
  return input;
}

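// One-shot variant: every rank reduces the entire buffer with multimem
// load-reduce and writes the result into its local output buffer. Since no
// remote stores are issued, there is no trailing barrier or fence.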
template <typename T, int alignment>
static __global__ void multimem_one_shot_all_reduce_kernel(
    T* input_mc_ptr,
    T* output_ptr,
    size_t numel,
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  static_assert(alignment % sizeof(T) == 0);
  constexpr size_t numel_per_thread = alignment / sizeof(T);

  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);

  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
  auto stride = blockDim.x * gridDim.x * numel_per_thread;
  for (size_t i = offset; i < numel; i += stride) {
    auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
    *reinterpret_cast<decltype(vec.as_scalar)*>(output_ptr + i) = vec.as_scalar;
  }
}

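// Out-of-place sum all-reduce: same preconditions as multimem_all_reduce_,
// but each rank reduces the whole input (splits == 1) and the result is
// returned in a freshly allocated tensor instead of mutating the
// symmetric-memory input.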
at::Tensor multimem_one_shot_all_reduce(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  TORCH_CHECK(
      input.is_contiguous(),
      "multimem_one_shot_all_reduce: input must be contiguous.");
  TORCH_CHECK(
      reduce_op == "sum",
      "multimem_one_shot_all_reduce: only sum is supported for now.");

  auto symm_mem = c10d::symmetric_memory::rendezvous(input);
  TORCH_CHECK(
      symm_mem != nullptr,
      "multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
  TORCH_CHECK(
      symm_mem->has_multicast_support(),
      "multimem_one_shot_all_reduce: requires multicast support.");

  auto output = at::empty_like(input);

  const size_t alignment =
      get_and_verify_alignment(input, "multimem_one_shot_all_reduce");

  int num_blocks = 0, num_threads = 0;
  init_elementwise_launch_config(
      input.numel(),
      input.element_size(),
      alignment,
      1,
      num_blocks,
      num_threads);

#define DISPATCH(scalar_t, kernel_alignment)                                   \
  if (alignment == kernel_alignment) {                                         \
    multimem_one_shot_all_reduce_kernel<scalar_t, kernel_alignment>            \
        <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(    \
            reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) +       \
                input.storage_offset(),                                        \
            output.data_ptr<scalar_t>(),                                       \
            input.numel(),                                                     \
            reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
            symm_mem->get_rank(),                                              \
            symm_mem->get_world_size());                                       \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                            \
  }

  AT_DISPATCH_SWITCH(
      input.scalar_type(),
      "multimem_one_shot_all_reduce",
      AT_DISPATCH_CASE(at::kBFloat16, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }) AT_DISPATCH_CASE(at::kFloat, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }));

#undef DISPATCH
  return output;
}

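// Register both ops in the `symm_mem` library namespace for the CUDA
// dispatch key only. multimem_all_reduce_ mutates and returns `input`;
// multimem_one_shot_all_reduce leaves `input` intact and returns a new
// tensor.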
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
  m.def(
      "multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce),
      {at::Tag::pt2_compliant_tag});
}

} // namespace

#endif