/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cassert>
#include <cstdint>
#include <cstdlib>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace pytorch_flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

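// relu2<T> applies an elementwise ReLU to a pair of 16-bit values (fp16 or bf16) packed into a
// single uint32_t, using one packed PTX instruction; see relu_() below for how it is applied to
// whole register fragments.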
template<typename T>
__forceinline__ __device__ uint32_t relu2(const uint32_t x);

template<>
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n" \
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

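// convert_relu2<T> converts a pair of fp32 values to a pair of packed fp16/bf16 values and applies
// ReLU in the same cvt.rn.relu instruction, which is only available on SM80 and above.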
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

template<typename T>
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);

template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
    return x;
}
};
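
// Minimal usage sketch (illustrative only; the variable names are not from this file): reducing a
// per-thread partial value across the 4 threads that share a row of an MMA tile:
//     MaxOp<float> max_op;
//     float row_max = Allreduce<4>::run(partial_max, max_op);
//     SumOp<float> sum_op;
//     float row_sum = Allreduce<4>::run(partial_sum, sum_op);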

////////////////////////////////////////////////////////////////////////////////////////////////////

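// gemm: computes acc += A * B one K-slice at a time. The copy of the next K-slice from shared
// memory into registers is issued before the MMA on the current slice, so smem traffic overlaps
// with tensor-core work. A_in_regs / B_in_regs skip the copy for an operand that already lives in
// registers.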
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
         typename ThrCopyA, typename ThrCopyB>
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                            Tensor4 const& tCsB, TiledMma tiled_mma,
                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

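// gemm_rs: same K-slice pipelining as gemm() above, but the A operand is assumed to already be in
// registers, so only B is staged from shared memory.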
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                               TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                               ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};
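
// Minimal usage sketch (illustrative only; acc_s / scores are not names from this file): view an
// fp32 accumulator as (row, col) so each row can be iterated for softmax:
//     Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));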

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    if constexpr (mma_shape_K == 8) {
        return acc_layout;
    } else {
        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
    }
};
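
// Minimal usage sketch (illustrative only; rP / tOrP / TiledMma are not names from this file):
// reinterpret the converted probabilities so they can be fed as the A operand of a second GEMM:
//     Tensor tOrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMma>(rP.layout()));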

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
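
// Minimal usage sketch (illustrative only; acc_s / Element are not names from this file): convert
// an fp32 accumulator fragment to the kernel's 16-bit element type:
//     Tensor rP = pytorch_flash::convert_type<Element>(acc_s);  // Element = cutlass::half_t or cutlass::bfloat16_t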

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
    using value_t = typename Engine::value_type;
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_uint32 = recast<uint32_t>(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor_uint32); ++i) {
        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
    static_assert(std::is_same_v<float, From_type>);
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_float2 = recast<float2>(tensor);
    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
    #pragma unroll
    for (int i = 0; i < size(out_uint32); ++i) {
        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
    }
    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
    Tensor out = pytorch_flash::convert_type<To_type>(tensor);
    pytorch_flash::relu_(out);
#endif
    return out;
}
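
// Minimal usage sketch (illustrative only; acc_o is not a name from this file): fused convert+ReLU
// of an fp32 fragment; before SM80 this falls back to convert_type() followed by relu_():
//     Tensor rO = pytorch_flash::convert_type_relu<cutlass::half_t>(acc_o);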

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
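
// Minimal usage sketch (illustrative only): the usual pattern around an asynchronous gmem -> smem
// tile load is to commit the outstanding copies, wait for them, then synchronize the block:
//     cute::cp_async_fence();
//     pytorch_flash::cp_async_wait<0>();
//     __syncthreads();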

////////////////////////////////////////////////////////////////////////////////////////////////////

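// copy: predicated tile copy (typically gmem <-> smem). identity_MN carries the global (row, col)
// coordinate of each element, so rows at or beyond max_MN are skipped when Is_even_MN is false;
// predicate_K masks a partial K tile when Is_even_K is false. Out-of-bounds destination elements
// are optionally zero-filled via Clear_OOB_MN / Clear_OOB_K.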
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
    // TD [2023-04-13]: Strange that the code below can cause a race condition.
    // I think it's because the copies are under an if statement.
    // if (Is_even_K) {
    //     #pragma unroll
    //     for (int m = 0; m < size<1>(S); ++m) {
    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //             copy(tiled_copy, S(_, m, _), D(_, m, _));
    //         } else if (Clear_OOB_MN) {
    //             clear(D(_, m, _));
    //         }
    //     }
    // } else {  // It's slightly faster in this case if we iterate over K first
    //     #pragma unroll
    //     for (int k = 0; k < size<2>(S); ++k) {
    //         if (predicate_K(k)) {
    //             #pragma unroll
    //             for (int m = 0; m < size<1>(S); ++m) {
    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));
    //                 } else if (Clear_OOB_MN) {
    //                     clear(D(_, m, k));
    //                 }
    //             }
    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
    //             if (Clear_OOB_MN || Is_even_MN) {
    //                 clear(D(_, _, k));
    //             } else {
    //                 #pragma unroll
    //                 for (int m = 0; m < size<1>(S); ++m) {
    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
    //                         clear(D(_, m, k));
    //                     }
    //                 }
    //             }
    //         }
    //     }
    // }
}
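
// Minimal usage sketch (illustrative only; the tensor names are not from this file): loading a
// Q tile from global to shared memory while guarding the last row/column block:
//     pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
//                                                actual_seqlen_q - m_block * kBlockM);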

////////////////////////////////////////////////////////////////////////////////////////////////////

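// copy_w_min_idx: like copy() above, but an element's row is copied only if its global index lies
// in [min_MN, max_MN), and nothing is zero-filled.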
template <bool Is_even_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                      Tensor<Engine3, Layout3> const &predicate_K,
                                      const int max_MN=0, const int min_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(S(_, m, k), D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace pytorch_flash