#pragma once

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
#define NVCC_SUPPORTS_MULTICAST 1
#endif

#include <ATen/ATen.h>

namespace c10d::symmetric_memory {

constexpr size_t max_num_threads_per_block = 1024;
constexpr size_t max_num_blocks = 8;

// Returns the largest alignment (16/8/4/2/1 bytes) satisfied by the given
// pointer or size.
template <typename T>
size_t get_alignment(T ptr_or_size) {
  auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
  if (val % 16 == 0) {
    return 16;
  } else if (val % 8 == 0) {
    return 8;
  } else if (val % 4 == 0) {
    return 4;
  } else if (val % 2 == 0) {
    return 2;
  } else {
    return 1;
  }
}

template <>
size_t get_alignment<size_t>(size_t size) {
  return get_alignment(reinterpret_cast<void*>(size));
}
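
// Usage sketch (illustrative only; `input_ptr` and `numel` are assumptions,
// not part of this header): the alignment shared by the data pointer and the
// byte count picks the widest vectorized access path.
//
//   size_t alignment = std::min(
//       get_alignment(input_ptr), get_alignment(numel * sizeof(float)));
//   // alignment is now 16, 8, 4, 2, or 1; dispatch to a kernel that uses
//   // Vec<alignment>-sized accesses (see the Vec unions below).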

// System-scoped 32-bit compare-and-swap on global memory.
__device__ __forceinline__ uint32_t
cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  uint32_t old_val;
  asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;"
               : "=r"(old_val)
               : "l"(addr), "r"(compare), "r"(val)
               : "memory");
  return old_val;
#endif
}

// System-scoped 32-bit compare-and-swap on global memory with release
// semantics.
__device__ __forceinline__ uint32_t
cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  uint32_t old_val;
  asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;"
               : "=r"(old_val)
               : "l"(addr), "r"(compare), "r"(val)
               : "memory");
  return old_val;
#endif
}

// Sets the signal word at *addr from 0 to 1 with release semantics, spinning
// until any previous, not-yet-consumed signal has been cleared.
__device__ __forceinline__ void release_signal(uint32_t* addr) {
  while (cas_release_sys(addr, 0, 1) != 0)
    ;
}

// Waits for the signal word at *addr to become 1 and atomically resets it to
// 0 (i.e. consumes the signal).
__device__ __forceinline__ void wait_signal(uint32_t* addr) {
  while (cas_sys(addr, 1, 0) != 1)
    ;
}

// Acquire-loads the signal word at *addr at system scope.
__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  uint32_t val;
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(val)
               : "l"(addr)
               : "memory");
  return val;
#endif
}

// Perform a barrier to establish observation order between memory operations
// issued before and after the barrier.
__device__ __forceinline__ void barrier(
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
    wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
  __syncthreads();
}
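
// Usage sketch (illustrative only; the kernel and buffers below are
// assumptions, not part of this header): a kernel typically calls barrier()
// between a phase that writes its own symmetric-memory segment and a phase
// that reads the segments written by peer ranks.
//
//   __global__ void exchange_kernel(
//       uint32_t** signal_pads,
//       float** buffers,
//       size_t rank,
//       size_t world_size) {
//     buffers[rank][threadIdx.x] = static_cast<float>(rank);
//     barrier(signal_pads, rank, world_size);
//     float peer_val = buffers[(rank + 1) % world_size][threadIdx.x];
//     (void)peer_val;
//   }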

// Perform a barrier and establish causality order between memory operations
// issued before the calling kernel on all devices and memory operations
// issued after this function by all threads in the calling kernel.
//
// NOTE: this function does NOT ensure that memory operations issued in the
// current kernel are visible to all threads in the current kernel.
//
// | mem ops (guaranteed to be visible to all threads at point T)
// | kernel K
// | +- mem ops (not guaranteed to be visible to all threads at point T)
// | +- barrier_and_acquire_previous_kernel_writes()
// | +- point T
// v
__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes(
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
    wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
  __syncthreads();
  // At this point, we have established observation order between memory
  // operations issued before and after the barrier. Now we convert the
  // observation order into causality order by having every thread acquire the
  // signals released by threads on peer devices. Due to the implicit
  // synchronizes-with relationships at task/kernel boundaries, acquiring the
  // signal released by thread T in kernel K transitively acquires memory
  // operations issued prior to kernel K.
  //
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference
  for (size_t target_rank = 0; target_rank < world_size; ++target_rank) {
    acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
}
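
// Usage sketch (illustrative only; the kernel below is an assumption, not
// part of this header): a kernel that consumes data produced by previously
// launched kernels on peer devices calls this barrier before its first read.
//
//   __global__ void consumer_kernel(
//       uint32_t** signal_pads,
//       float** buffers,
//       size_t rank,
//       size_t world_size) {
//     barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
//     float v = buffers[(rank + 1) % world_size][threadIdx.x];
//     (void)v;
//   }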

template <bool Value, class... Args>
inline constexpr bool dependent_bool_value = Value;

// An always-false value that depends on template parameters, so it can be
// used in static_assert inside templates that must not be instantiated.
template <class... Args>
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;

// Unions used to issue 4/8/16-byte vectorized accesses.
template <int Size>
union Vec;

template <>
union Vec<4> {
  uint16_t u16[2];
  uint32_t u32, as_scalar;
};

template <>
union Vec<8> {
  uint16_t u16[4];
  uint32_t u32[2];
  uint64_t u64, as_scalar;
};

template <>
union alignas(16) Vec<16> {
  uint16_t u16[8];
  uint32_t u32[4];
  uint64_t u64[2];
  uint4 u128, as_scalar;
};
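
// Usage sketch (illustrative only; the helper below is an assumption, not
// part of this header): with 16-byte-aligned pointers and a byte count that
// is a multiple of 16, one Vec<16> can be moved per thread per iteration.
//
//   __device__ void copy_16_aligned(char* dst, const char* src, size_t nbytes) {
//     const size_t stride = blockDim.x * sizeof(Vec<16>);
//     for (size_t i = threadIdx.x * sizeof(Vec<16>); i < nbytes; i += stride) {
//       *reinterpret_cast<Vec<16>*>(dst + i) =
//           *reinterpret_cast<const Vec<16>*>(src + i);
//     }
//   }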

// Multimem ld_reduce: loads from a multicast pointer and returns the result
// of adding the values from all devices in the multicast group.
template <typename T>
struct MultimemLdReduce {
  template <int Alignment>
  __device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
    static_assert(dependent_false<T>);
  }
};

template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
  MultimemLdReduce<T> functor;
  return functor.template operator()<Alignment>(mc_ptr);
}

#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type)        \
  template <>                                                       \
  struct MultimemLdReduce<type> {                                   \
    template <int Alignment>                                        \
    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
      CUDA_KERNEL_ASSERT(false);                                    \
    }                                                               \
  };
#else
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type)                   \
  template <>                                                                  \
  struct MultimemLdReduce<type> {                                              \
    template <int Alignment>                                                   \
    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) {            \
      Vec<Alignment> vec;                                                      \
      if constexpr (Alignment == 16) {                                         \
        asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type           \
            " {%0,%1,%2,%3}, [%4];"                                            \
            : "=r"(vec.u32[0]),                                                \
              "=r"(vec.u32[1]),                                                \
              "=r"(vec.u32[2]),                                                \
              "=r"(vec.u32[3])                                                 \
            : "l"(mc_ptr)                                                      \
            : "memory");                                                       \
      } else if constexpr (Alignment == 8) {                                   \
        asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type           \
            " {%0,%1}, [%2];"                                                  \
            : "=r"(vec.u32[0]), "=r"(vec.u32[1])                               \
            : "l"(mc_ptr)                                                      \
            : "memory");                                                       \
      } else if constexpr (Alignment == 4) {                                   \
        asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \
            : "=r"(vec.u32)                                                    \
            : "l"(mc_ptr)                                                      \
            : "memory");                                                       \
      }                                                                        \
      return vec;                                                              \
    }                                                                          \
  };
#endif

SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2");
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32");

// Multimem store: stores to a multicast pointer, writing the value to the
// corresponding location on every device in the multicast group.
template <int Alignment, typename T>
__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
  CUDA_KERNEL_ASSERT(false);
#else
  if constexpr (Alignment == 16) {
    asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
        :
        : "l"(mc_ptr),
          "r"(vec.u32[0]),
          "r"(vec.u32[1]),
          "r"(vec.u32[2]),
          "r"(vec.u32[3])
        : "memory");
  } else if constexpr (Alignment == 8) {
    asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
        :
        : "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
        : "memory");
  } else if constexpr (Alignment == 4) {
    asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
        :
        : "l"(mc_ptr), "r"(vec.u32)
        : "memory");
  } else {
    static_assert(dependent_false<T>);
  }
#endif
}
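
// Usage sketch (illustrative only; the kernel below is an assumption, not
// part of this header): a one-shot all-reduce over a multicast (multimem)
// pointer pairs multimem_ld_reduce_add with multimem_st, assuming a
// 16-byte-aligned float buffer whose length is a multiple of 4. A real kernel
// would typically also call barrier() before and after the loop so that every
// rank has finished writing its input before any rank reduces it.
//
//   __global__ void one_shot_all_reduce_kernel(float* mc_ptr, size_t numel) {
//     const size_t start = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
//     const size_t stride = blockDim.x * gridDim.x * 4;
//     for (size_t i = start; i < numel; i += stride) {
//       Vec<16> vec = multimem_ld_reduce_add<16>(mc_ptr + i);
//       multimem_st<16>(mc_ptr + i, vec);
//     }
//   }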

} // namespace c10d::symmetric_memory