#pragma once

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
#define NVCC_SUPPORTS_MULTICAST 1
#endif

#include <cstddef>
#include <cstdint>

#include <ATen/ATen.h>

namespace c10d::symmetric_memory {

constexpr size_t max_num_threads_per_block = 1024;
constexpr size_t max_num_blocks = 8;

// Returns the largest power-of-two alignment (up to 16 bytes) of a pointer,
// or of a byte count when called through the size_t specialization below.
template <typename T>
size_t get_alignment(T ptr_or_size) {
  auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
  if (val % 16 == 0) {
    return 16;
  } else if (val % 8 == 0) {
    return 8;
  } else if (val % 4 == 0) {
    return 4;
  } else if (val % 2 == 0) {
    return 2;
  } else {
    return 1;
  }
}

template <>
inline size_t get_alignment<size_t>(size_t size) {
  return get_alignment(reinterpret_cast<void*>(size));
}
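
// Usage sketch (illustrative; `ptr` and `numel` are hypothetical names, not
// part of this header): callers typically take the minimum alignment over the
// data pointer and the byte count to pick the widest safe vector width:
//
//   const size_t alignment = std::min(
//       get_alignment(ptr), get_alignment(numel * sizeof(float)));
//   // alignment == 16 selects the Vec<16> (uint4) path defined below;
//   // 8 and 4 fall back to the narrower variants.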

// System-scope 32-bit compare-and-swap (no acquire/release ordering).
__device__ __forceinline__ uint32_t
cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
  return 0; // unreachable; silences missing-return warnings
#else
  uint32_t old_val;
  asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;"
               : "=r"(old_val)
               : "l"(addr), "r"(compare), "r"(val)
               : "memory");
  return old_val;
#endif
}

// System-scope 32-bit compare-and-swap with release semantics.
__device__ __forceinline__ uint32_t
cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
  return 0; // unreachable; silences missing-return warnings
#else
  uint32_t old_val;
  asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;"
               : "=r"(old_val)
               : "l"(addr), "r"(compare), "r"(val)
               : "memory");
  return old_val;
#endif
}

// Sets *addr from 0 to 1 with release semantics, spinning while the slot still
// holds a prior, unconsumed signal.
__device__ __forceinline__ void release_signal(uint32_t* addr) {
  while (cas_release_sys(addr, 0, 1) != 0)
    ;
}

// Spins until *addr transitions from 1 back to 0, consuming the signal.
__device__ __forceinline__ void wait_signal(uint32_t* addr) {
  while (cas_sys(addr, 1, 0) != 1)
    ;
}

// System-scope acquire load of a signal value.
__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
  return 0; // unreachable; silences missing-return warnings
#else
  uint32_t val;
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(val)
               : "l"(addr)
               : "memory");
  return val;
#endif
}

// Perform a barrier to establish observation order between memory operations
// issued before and after the barrier.
__device__ __forceinline__ void barrier(
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
    wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
  __syncthreads();
}
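
// Signal pad layout implied by the indexing above (spelled out here as an
// assumption, for illustration): each rank's pad holds one uint32_t slot per
// (block, peer) pair, addressed as
//
//   signal_pads[r][blockIdx.x * world_size + peer]
//
// so a pad must provide at least gridDim.x * world_size slots, with
// gridDim.x <= max_num_blocks. A typical (hypothetical) call pattern brackets
// peer-buffer access with two barriers:
//
//   barrier(signal_pads, rank, world_size);  // all ranks have arrived
//   // ... access peer buffers ...
//   barrier(signal_pads, rank, world_size);  // all ranks are done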

// Perform a barrier and establish causality order between memory operations
// issued before the calling kernel on all devices and memory operations
// issued after this function by all threads in the calling kernel.
//
// NOTE: this function does NOT ensure that memory operations issued in the
// current kernel are visible to all threads in the current kernel.
//
// | mem ops (guaranteed to be visible to all threads at point T)
// | kernel K
// | +- mem ops (not guaranteed to be visible to all threads at point T)
// | +- barrier_and_acquire_previous_kernel_writes()
// | +- point T
// v
__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes(
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
    wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
  __syncthreads();
  // At this point, we have established observation order between memory
  // operations issued before and after the barrier. Now we convert the
  // observation order into causality order by having every thread acquire the
  // signals released by threads on peer devices. Due to the implicit
  // synchronizes-with relationships at task/kernel boundaries, acquiring the
  // signal released by thread T in kernel K transitively acquires memory
  // operations issued prior to kernel K.
  //
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference
  for (size_t target_rank = 0; target_rank < world_size; ++target_rank) {
    acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
}
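
// Usage sketch (an assumption about the intended call pattern, not taken from
// upstream docs): a kernel that consumes buffers written by peers in
// previously launched kernels typically opens with
// barrier_and_acquire_previous_kernel_writes() and closes with a plain
// barrier() so that no rank reuses the buffers while a peer is still reading:
//
//   __global__ void consume_peer_data(  // hypothetical kernel
//       float** bufs, uint32_t** signal_pads, size_t rank, size_t world_size) {
//     barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
//     // ... read bufs[peer], written before this kernel was launched ...
//     barrier(signal_pads, rank, world_size);
//   }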

template <bool Value, class... Args>
inline constexpr bool dependent_bool_value = Value;

template <class... Args>
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;

// Register-resident vector of Size bytes, used to issue the 4/8/16-byte
// vectorized multimem accesses below.
template <int Size>
union Vec;

template <>
union Vec<4> {
  uint16_t u16[2];
  uint32_t u32, as_scalar;
};

template <>
union Vec<8> {
  uint16_t u16[4];
  uint32_t u32[2];
  uint64_t u64, as_scalar;
};

template <>
union alignas(16) Vec<16> {
  uint16_t u16[8];
  uint32_t u32[4];
  uint64_t u64[2];
  uint4 u128, as_scalar;
};

// multimem.ld_reduce.add loads from a multicast (multimem) address and returns
// the element-wise sum over all devices mapped by that address. Per-type
// specializations are generated by the macro below.
template <typename T>
struct MultimemLdReduce {
  template <int Alignment>
  __device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
    static_assert(dependent_false<T>);
  }
};

template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
  MultimemLdReduce<T> functor;
  return functor.template operator()<Alignment>(mc_ptr);
}

#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type)        \
  template <>                                                       \
  struct MultimemLdReduce<type> {                                   \
    template <int Alignment>                                        \
    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
      CUDA_KERNEL_ASSERT(false);                                    \
      return Vec<Alignment>{}; /* unreachable */                    \
    }                                                               \
  };
#else
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type)                   \
  template <>                                                                  \
  struct MultimemLdReduce<type> {                                              \
    template <int Alignment>                                                   \
    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) {            \
      Vec<Alignment> vec;                                                      \
      if constexpr (Alignment == 16) {                                         \
        asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type           \
            " {%0,%1,%2,%3}, [%4];"                                            \
            : "=r"(vec.u32[0]),                                                \
              "=r"(vec.u32[1]),                                                \
              "=r"(vec.u32[2]),                                                \
              "=r"(vec.u32[3])                                                 \
            : "l"(mc_ptr)                                                      \
            : "memory");                                                       \
      } else if constexpr (Alignment == 8) {                                   \
        asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type           \
            " {%0,%1}, [%2];"                                                  \
            : "=r"(vec.u32[0]), "=r"(vec.u32[1])                               \
            : "l"(mc_ptr)                                                      \
            : "memory");                                                       \
      } else if constexpr (Alignment == 4) {                                   \
        asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \
            : "=r"(vec.u32)                                                    \
            : "l"(mc_ptr)                                                      \
            : "memory");                                                       \
      }                                                                        \
      return vec;                                                              \
    }                                                                          \
  };
#endif

SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2");
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32");

// multimem.st stores the register contents to every memory location mapped by
// the multicast address. The store is a raw bit-level copy, so the .f32
// qualifier below serves only as a 32-bit payload type; bf16x2 payloads pass
// through unchanged.
template <int Alignment, typename T>
__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
  CUDA_KERNEL_ASSERT(false);
#else
  if constexpr (Alignment == 16) {
    asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
        :
        : "l"(mc_ptr),
          "r"(vec.u32[0]),
          "r"(vec.u32[1]),
          "r"(vec.u32[2]),
          "r"(vec.u32[3])
        : "memory");
  } else if constexpr (Alignment == 8) {
    asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
        :
        : "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
        : "memory");
  } else if constexpr (Alignment == 4) {
    asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
        :
        : "l"(mc_ptr), "r"(vec.u32)
        : "memory");
  } else {
    static_assert(dependent_false<T>);
  }
#endif
}
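
// Illustrative sketch, not part of the upstream API (the function name and
// all parameter names here are hypothetical): a one-shot all-reduce body that
// combines the primitives above. `mc_ptr` is assumed to be a multicast
// (multimem) mapping of a symmetric float buffer that is 16-byte aligned,
// with `numel` divisible by world_size * 4, and `signal_pads` the per-rank
// signal pads described earlier.
__device__ __forceinline__ void one_shot_all_reduce_sketch(
    float* mc_ptr,
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size,
    size_t numel) {
  // Make the peers' input writes from previously launched kernels visible.
  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
  constexpr int alignment = 16;
  constexpr size_t numel_per_vec = alignment / sizeof(float); // 4 floats
  // Each rank reduces and re-broadcasts a distinct 1/world_size slice so that
  // no element is read by one rank after another rank has already overwritten
  // it with the reduced value.
  const size_t numel_per_rank = numel / world_size;
  const size_t start = numel_per_rank * rank;
  const size_t stride = blockDim.x * gridDim.x * numel_per_vec;
  for (size_t i = (blockIdx.x * blockDim.x + threadIdx.x) * numel_per_vec;
       i < numel_per_rank;
       i += stride) {
    // Sum the i-th vector across all ranks via the multicast address, then
    // broadcast the result back through the same multicast mapping.
    Vec<alignment> vec = multimem_ld_reduce_add<alignment>(mc_ptr + start + i);
    multimem_st<alignment>(mc_ptr + start + i, vec);
  }
  // Keep every rank from reusing the buffer until all ranks are done.
  barrier(signal_pads, rank, world_size);
}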

} // namespace c10d::symmetric_memory