#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030

#include <ATen/ATen.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <torch/library.h>

#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>

namespace {

using namespace c10d::symmetric_memory;

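// Returns the effective alignment (in bytes) usable for vectorized accesses
// on `input`: the smaller of the byte-offset alignment and the total byte-size
// alignment. Raises via TORCH_CHECK if either falls below the minimum
// (the max of 4 bytes and the element size).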
size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
  const size_t min_alignment = std::max(4l, input.element_size());
  // Only check the offset since the multicast address is always at least
  // 128-bit aligned
  const size_t ptr_alignment = get_alignment(
      static_cast<size_t>(input.storage_offset() * input.element_size()));
  TORCH_CHECK(
      ptr_alignment >= min_alignment,
      op_name,
      "<",
      input.scalar_type(),
      ">: input ptr + offset must be at least ",
      min_alignment,
      "-byte aligned.");

  const size_t size_alignment =
      get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
  TORCH_CHECK(
      size_alignment >= min_alignment,
      op_name,
      "<",
      input.scalar_type(),
      ">: input size must be at least ",
      min_alignment,
      "-byte aligned.");
  return std::min(ptr_alignment, size_alignment);
}

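// Computes a grid/block configuration for an elementwise kernel in which each
// thread handles `alignment` bytes per iteration and the work is divided into
// `splits` equal partitions (e.g. one partition per rank).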
void init_elementwise_launch_config(
    size_t numel,
    size_t element_size,
    size_t alignment,
    size_t splits,
    int& num_blocks,
    int& num_threads) {
  // Align to preserve alignment in each split
  const size_t aligned_numel = at::round_up(numel, alignment * splits);
  const size_t numel_per_split = aligned_numel / splits;
  const size_t numel_per_thread = alignment / element_size;

  if (numel_per_split <= max_num_threads_per_block * numel_per_thread) {
    num_blocks = 1;
    num_threads = at::round_up(
        at::ceil_div(numel_per_split, numel_per_thread),
        static_cast<size_t>(C10_WARP_SIZE));
  } else {
    num_blocks = std::min(
        at::ceil_div(
            numel_per_split, max_num_threads_per_block * numel_per_thread),
        max_num_blocks);
    num_threads = max_num_threads_per_block;
  }
}

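// In-place all-reduce over a multicast (multimem) address: each rank reduces
// its own 1/world_size slice with a vectorized multimem ld_reduce and
// broadcasts the result back with a multimem st; the trailing barrier and
// fence order those writes across devices.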
template <typename T, int alignment>
static __global__ void multimem_all_reduce_kernel(
    T* input_mc_ptr,
    size_t numel,
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  static_assert(alignment % sizeof(T) == 0);
  constexpr size_t numel_per_thread = alignment / sizeof(T);

  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);

  const size_t numel_per_rank =
      at::round_up(numel, alignment * world_size) / world_size;
  const size_t start = numel_per_rank * rank;

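  // Grid-stride loop over this rank's slice. The bounds check skips the
  // padded tail when numel is not a multiple of alignment * world_size.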
  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
  auto stride = blockDim.x * gridDim.x * numel_per_thread;
  for (size_t i = offset; i < numel_per_rank; i += stride) {
    if (start + i >= numel) {
      continue;
    }
    auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + start + i);
    multimem_st<alignment>(input_mc_ptr + start + i, vec);
  }
  // Establish observation order - all writes are in-flight beyond this point.
  barrier(signal_pads, rank, world_size);
  // Establish causality order - all writes are visible to all devices beyond
  // this point.
  __threadfence_system();
}

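// Operator entry point for the in-place multimem all-reduce. Validates the
// input, derives the vectorization width from its alignment, and launches
// multimem_all_reduce_kernel on the current CUDA stream.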
at::Tensor multimem_all_reduce_(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  TORCH_CHECK(
      input.is_contiguous(), "multimem_all_reduce_: input must be contiguous.");
  TORCH_CHECK(
      reduce_op == "sum",
      "multimem_all_reduce_: only sum is supported for now.");

  auto symm_mem = c10d::symmetric_memory::rendezvous(input);
  TORCH_CHECK(
      symm_mem != nullptr,
      "multimem_all_reduce_: input must be allocated with empty_strided_p2p().");
  TORCH_CHECK(
      symm_mem->has_multicast_support(),
      "multimem_all_reduce_: multicast support is required.");

  const size_t alignment =
      get_and_verify_alignment(input, "multimem_all_reduce_");

  int num_blocks = 0, num_threads = 0;
  init_elementwise_launch_config(
      input.numel(),
      input.element_size(),
      alignment,
      symm_mem->get_world_size(),
      num_blocks,
      num_threads);

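// Launch the kernel specialized for the runtime-selected alignment
// (16/8/4 bytes) and the input's scalar type, pointing it at the multicast
// address of the input.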
#define DISPATCH(scalar_t, kernel_alignment)                                \
  if (alignment == kernel_alignment) {                                      \
    multimem_all_reduce_kernel<scalar_t, kernel_alignment>                  \
        <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
            reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) +    \
                input.storage_offset(),                                     \
            input.numel(),                                                  \
            reinterpret_cast<uint32_t**>(                                   \
                symm_mem->get_signal_pad_ptrs_dev()),                       \
            symm_mem->get_rank(),                                           \
            symm_mem->get_world_size());                                    \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  }

  AT_DISPATCH_SWITCH(
      input.scalar_type(),
      "multimem_all_reduce",
      AT_DISPATCH_CASE(at::kBFloat16, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }) AT_DISPATCH_CASE(at::kFloat, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }));

#undef DISPATCH
  return input;
}

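// One-shot variant: every rank reduces the entire buffer with multimem
// ld_reduce and stores the result into its own local output, so no trailing
// cross-device barrier is issued before the kernel exits.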
template <typename T, int alignment>
static __global__ void multimem_one_shot_all_reduce_kernel(
    T* input_mc_ptr,
    T* output_ptr,
    size_t numel,
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  static_assert(alignment % sizeof(T) == 0);
  constexpr size_t numel_per_thread = alignment / sizeof(T);

  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);

  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
  auto stride = blockDim.x * gridDim.x * numel_per_thread;
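  // Grid-stride loop over the full buffer; each iteration reduces one
  // alignment-sized vector across ranks and writes it to the local output.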
  for (size_t i = offset; i < numel; i += stride) {
    auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
    *reinterpret_cast<decltype(vec.as_scalar)*>(output_ptr + i) = vec.as_scalar;
  }
}

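// Out-of-place one-shot all-reduce. Performs the same validation and dispatch
// as multimem_all_reduce_, but writes into a freshly allocated output and
// uses a single split since every rank processes the whole buffer.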
at::Tensor multimem_one_shot_all_reduce(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  TORCH_CHECK(
      input.is_contiguous(),
      "multimem_one_shot_all_reduce: input must be contiguous.");
  TORCH_CHECK(
      reduce_op == "sum",
      "multimem_one_shot_all_reduce: only sum is supported for now.");

  auto symm_mem = c10d::symmetric_memory::rendezvous(input);
  TORCH_CHECK(
      symm_mem != nullptr,
      "multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
  TORCH_CHECK(
      symm_mem->has_multicast_support(),
      "multimem_one_shot_all_reduce: requires multicast support.");

  auto output = at::empty_like(input);

  const size_t alignment =
      get_and_verify_alignment(input, "multimem_one_shot_all_reduce");

  int num_blocks = 0, num_threads = 0;
  init_elementwise_launch_config(
      input.numel(),
      input.element_size(),
      alignment,
      1,
      num_blocks,
      num_threads);

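// Same alignment/dtype dispatch as above, except the reduced vectors are
// written to the local `output` tensor instead of back through the multicast
// pointer.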
#define DISPATCH(scalar_t, kernel_alignment)                                \
  if (alignment == kernel_alignment) {                                      \
    multimem_one_shot_all_reduce_kernel<scalar_t, kernel_alignment>         \
        <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
            reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) +    \
                input.storage_offset(),                                     \
            output.data_ptr<scalar_t>(),                                    \
            input.numel(),                                                  \
            reinterpret_cast<uint32_t**>(                                   \
                symm_mem->get_signal_pad_ptrs_dev()),                       \
            symm_mem->get_rank(),                                           \
            symm_mem->get_world_size());                                    \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  }

  AT_DISPATCH_SWITCH(
      input.scalar_type(),
      "multimem_one_shot_all_reduce",
      AT_DISPATCH_CASE(at::kBFloat16, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }) AT_DISPATCH_CASE(at::kFloat, [&] {
        DISPATCH(scalar_t, 16);
        DISPATCH(scalar_t, 8);
        DISPATCH(scalar_t, 4);
      }));

#undef DISPATCH
  return output;
}

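// Register both ops in the `symm_mem` library; once registered they are
// callable from Python as, e.g.,
// torch.ops.symm_mem.multimem_all_reduce_(t, "sum", group_name).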
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
  m.def(
      "multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce),
      {at::Tag::pt2_compliant_tag});
}

} // namespace

#endif