xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /******************************************************************************
2  * Copyright (c) 2024, Tri Dao.
3  ******************************************************************************/
4 #include <c10/core/ScalarType.h>
5 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
6 
7 #include <cstdint>
8 #include <tuple>
9 
10 
11 #ifdef USE_FLASH_ATTENTION
12 #include <ATen/core/Tensor.h>
13 #include <ATen/cuda/CUDAContext.h>
14 #include <c10/cuda/CUDAGuard.h>
15 #include <ATen/cuda/CUDAGraphsUtils.cuh>
16 
17 #ifndef AT_PER_OPERATOR_HEADERS
18 #include <ATen/Functions.h>
19 #include <ATen/NativeFunctions.h>
20 #else
21 #include <ATen/ops/empty.h>
22 #include <ATen/ops/empty_like.h>
23 #include <ATen/ops/reshape.h>
24 #include <ATen/ops/scalar_tensor.h>
25 #include <ATen/ops/sum.h>
26 #include <ATen/ops/slice.h>
27 #include <ATen/ops/narrow.h>
28 #include <ATen/ops/pad.h>
29 #include <ATen/ops/zeros.h>
30 #endif
31 
32 
33 #include <cutlass/numeric_types.h>
34 
35 #include <ATen/native/transformers/cuda/flash_attn/flash.h>
36 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
37 #include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
38 
39 #include <c10/util/Exception.h>
40 
41 namespace pytorch_flash {
42 
43 #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
44 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
45 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
46 
47 
48 void set_params_fprop(Flash_fwd_params &params,
49                       // sizes
50                       const size_t b,
51                       const size_t seqlen_q,
52                       const size_t seqlen_k,
53                       const size_t seqlen_q_rounded,
54                       const size_t seqlen_k_rounded,
55                       const size_t h,
56                       const size_t h_k,
57                       const size_t d,
58                       const size_t d_rounded,
59                       // device pointers
60                       const at::Tensor q,
61                       const at::Tensor k,
62                       const at::Tensor v,
63                       at::Tensor out,
64                       void *cu_seqlens_q_d,
65                       void *cu_seqlens_k_d,
66                       void *seqused_k,
67                       void *p_d,
68                       void *softmax_lse_d,
69                       float p_dropout,
70                       float softmax_scale,
71                       int window_size_left,
72                       int window_size_right,
73                       bool seqlenq_ngroups_swapped=false) {
74 
75     // Reset the parameters
76     params = {};
77 
78     params.is_bf16 = q.dtype() == at::kBFloat16;
79 
80     // Set the pointers and strides.
81     params.q_ptr = q.data_ptr();
82     params.k_ptr = k.data_ptr();
83     params.v_ptr = v.data_ptr();
84     // All strides are in elements, not bytes.
85     params.q_row_stride = q.stride(-3);
86     params.k_row_stride = k.stride(-3);
87     params.v_row_stride = v.stride(-3);
88     params.q_head_stride = q.stride(-2);
89     params.k_head_stride = k.stride(-2);
90     params.v_head_stride = v.stride(-2);
91     params.o_ptr = out.data_ptr();
92     params.o_row_stride = out.stride(-3);
93     params.o_head_stride = out.stride(-2);
94 
95     if (cu_seqlens_q_d == nullptr) {
96         params.q_batch_stride = q.stride(0);
97         params.k_batch_stride = k.stride(0);
98         params.v_batch_stride = v.stride(0);
99         params.o_batch_stride = out.stride(0);
100         if (seqlenq_ngroups_swapped) {
101              params.q_batch_stride *= seqlen_q;
102              params.o_batch_stride *= seqlen_q;
103         }
104     }
105 
106     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
107     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
108     params.seqused_k = static_cast<int *>(seqused_k);
109 
110     // P = softmax(QK^T)
111     params.p_ptr = p_d;
112 
113     // Softmax sum
114     params.softmax_lse_ptr = softmax_lse_d;
115 
116     // Set the dimensions.
117     params.b = b;
118     params.h = h;
119     params.h_k = h_k;
120     params.h_h_k_ratio = h / h_k;
121     params.seqlen_q = seqlen_q;
122     params.seqlen_k = seqlen_k;
123     params.seqlen_q_rounded = seqlen_q_rounded;
124     params.seqlen_k_rounded = seqlen_k_rounded;
125     params.d = d;
126     params.d_rounded = d_rounded;
127 
128     // Set the different scale values.
129     params.scale_softmax = softmax_scale;
130     params.scale_softmax_log2 = softmax_scale * M_LOG2E;
131 
132     // Set this to probability of keeping an element to simplify things.
133     params.p_dropout = 1.f - p_dropout;
134     // Convert p from float to int so we don't have to convert the random uint to float to compare.
135     // [Minor] We want to round down since when we do the comparison we use <= instead of <
136     // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
137     // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
138     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
139     params.rp_dropout = 1.f / params.p_dropout;
140     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
141     TORCH_CHECK(p_dropout < 1.f);
142     #ifdef FLASHATTENTION_DISABLE_DROPOUT
143         TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
144     #endif
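    // Worked example (illustrative; the concrete rate is an assumption): with
    // p_dropout = 0.1 the keep probability is 0.9, so p_dropout_in_uint8_t =
    // floor(0.9 * 255) = 229 and rp_dropout = 1 / 0.9 ~= 1.11; an element is kept
    // when its random uint8 draw is <= 229.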
145 
146     // Causal is the special case where window_size_right == 0 and window_size_left < 0.
147     // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
148     params.is_causal = window_size_left < 0 && window_size_right == 0;
149 
150     if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
151     if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
152     params.window_size_left = window_size_left;
153     params.window_size_right = window_size_right;
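    // Illustrative mapping of the normalization above (seqlen_k = 512 is an assumed
    // example value):
    //   (-1,  0)  -> is_causal = true, window becomes (512, 0)
    //   (-1, 128) -> local attention with window (512, 128)
    //   (128, -1) -> local attention with window (128, 512)
    //   (-1, -1)  -> full attention (both sides stay negative)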
154 
155     #ifdef FLASHATTENTION_DISABLE_LOCAL
156         TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
157             "This flash attention build does not support local attention.");
158     #endif
159 
160     params.is_seqlens_k_cumulative = true;
161 
162     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
163         TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
164     #endif
165 }
166 
167 void set_params_dgrad(Flash_bwd_params &params,
168                       // sizes
169                       const size_t b,
170                       const size_t seqlen_q,
171                       const size_t seqlen_k,
172                       const size_t seqlen_q_rounded,
173                       const size_t seqlen_k_rounded,
174                       const size_t h,
175                       const size_t h_k,
176                       const size_t d,
177                       const size_t d_rounded,
178                       // device pointers
179                       const at::Tensor q,
180                       const at::Tensor k,
181                       const at::Tensor v,
182                       const at::Tensor out,
183                       const at::Tensor dout,
184                       at::Tensor dq,
185                       at::Tensor dk,
186                       at::Tensor dv,
187                       void *cu_seqlens_q_d,
188                       void *cu_seqlens_k_d,
189                       void *dq_accum_d,
190                       void *dk_accum_d,
191                       void *dv_accum_d,
192                       void *softmax_lse_d,
193                       void *dsoftmax_sum_d,
194                       float p_dropout,
195                       float softmax_scale,
196                       int window_size_left,
197                       int window_size_right,
198                       bool deterministic) {
199 
200     set_params_fprop(params,
201                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
202                      q, k, v, out,
203                      cu_seqlens_q_d,
204                      cu_seqlens_k_d,
205                      nullptr,
206                      nullptr,
207                      softmax_lse_d,
208                      p_dropout,
209                      softmax_scale,
210                      window_size_left,
211                      window_size_right);
212 
213     // Set the pointers and strides.
214     params.do_ptr = dout.data_ptr();
215     params.do_row_stride = dout.stride(-3);
216     params.do_head_stride = dout.stride(-2);
217     params.dq_ptr = dq.data_ptr();
218     params.dk_ptr = dk.data_ptr();
219     params.dv_ptr = dv.data_ptr();
220     params.dq_row_stride = dq.stride(-3);
221     params.dk_row_stride = dk.stride(-3);
222     params.dv_row_stride = dv.stride(-3);
223     params.dq_head_stride = dq.stride(-2);
224     params.dk_head_stride = dk.stride(-2);
225     params.dv_head_stride = dv.stride(-2);
226 
227     if (cu_seqlens_q_d == nullptr) {
228         params.do_batch_stride = dout.stride(0);
229         params.dq_batch_stride = dq.stride(0);
230         params.dk_batch_stride = dk.stride(0);
231         params.dv_batch_stride = dv.stride(0);
232     }
233 
234     params.dq_accum_ptr = dq_accum_d;
235     params.dk_accum_ptr = dk_accum_d;
236     params.dv_accum_ptr = dv_accum_d;
237 
238     // Softmax sum
239     params.dsoftmax_sum = dsoftmax_sum_d;
240 
241     params.deterministic = deterministic;
242 }
243 
244 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
245     FP16_SWITCH(!params.is_bf16, [&] {
246         HEADDIM_SWITCH(params.d, [&] {
247             if (params.num_splits <= 1 && !force_split_kernel) {  // num_splits == 0 means it has not been set
248                 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
249             } else {
250                 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
251             }
252         });
253     });
254 }
255 
256 // Find the number of splits that maximizes the occupancy. For example, if we have
257 // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
258 // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
259 // splits as that would incur more HBM reads/writes.
260 // So we find the best efficiency, then find the smallest number of splits that gets 85%
261 // of the best efficiency.
262 inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
263     // If we have enough to almost fill the SMs, then just use 1 split
264     if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
265     max_splits = std::min({max_splits, num_SMs, num_n_blocks});
266     float max_efficiency = 0.f;
267     std::vector<float> efficiency;
268     efficiency.reserve(max_splits);
269     auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
270     // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
271     // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
272     // (i.e. it's 11 splits anyway).
273     // So we check if the number of blocks per split is the same as the previous num_splits.
274     auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
275         return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
276     };
277     for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
278         if (!is_split_eligible(num_splits)) {
279             efficiency.push_back(0.f);
280         } else {
281             float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
282             float eff = n_waves / ceil(n_waves);
283             // printf("num_splits = %d, eff = %f\n", num_splits, eff);
284             if (eff > max_efficiency) { max_efficiency = eff; }
285             efficiency.push_back(eff);
286         }
287     }
288     for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
289         if (!is_split_eligible(num_splits)) { continue; }
290         if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
291             // printf("num_splits chosen = %d\n", num_splits);
292             return num_splits;
293         }
294     }
295     return 1;
296 }
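// Illustrative sketch of the heuristic above, using the numbers from the comment
// (batch * n_heads = 48, 108 SMs); num_n_blocks = 32 and max_splits = 128 are assumed
// values for the example only:
//
//   int chosen = num_splits_heuristic(/*batch_nheads_mblocks=*/48, /*num_SMs=*/108,
//                                     /*num_n_blocks=*/32, /*max_splits=*/128);
//   // 1 split : n_waves = 48/108  ~= 0.44 -> eff ~= 0.44
//   // 2 splits: n_waves = 96/108  ~= 0.89 -> eff ~= 0.89
//   // 3 splits: n_waves = 144/108 ~= 1.33 -> eff = 1.33 / 2 ~= 0.67
//   // chosen == 2: the smallest split count within 85% of the best efficiency.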
297 std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
298     const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
299     const int head_size_rounded, const float p_dropout,
300     const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
301 
302     // This needs to match with run_mha_fwd_splitkv_dispatch
303     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
304     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
305     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
306     // In any case we don't expect seqlen_q to be larger than 64 for inference.
307     const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
308     params.num_splits = num_splits;
309     at::Tensor softmax_lse_accum;
310     at::Tensor out_accum;
311 
312     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
313         if (num_splits < 1) {
314             // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
315             params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
316         }
317         if (params.num_splits > 1) {
318             softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
319             out_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
320             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
321             params.oaccum_ptr = out_accum.data_ptr();
322         }
323         TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
324     }
325 
326     return std::make_tuple(softmax_lse_accum, out_accum);
327 }
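// Illustrative numbers for the block math above (all values assumed for the example):
// head_size = 128 selects block_n = 128, so max_seqlen_k = 4096 gives num_n_blocks = 32
// and max_seqlen_q = 1 gives num_m_blocks = 1; the heuristic is then fed
// batch_size * num_heads * 1 m-blocks against 2x the SM count.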
328 
329 void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads) {
330 #ifdef FLASHATTENTION_DISABLE_ALIBI
331     TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
332     params.alibi_slopes_ptr = nullptr;
333 #else
334     if (alibi_slopes_.has_value()) {
335         auto alibi_slopes = alibi_slopes_.value();
336         TORCH_CHECK(alibi_slopes.dtype() == at::kFloat, "ALiBi slopes must have dtype fp32");
337         CHECK_DEVICE(alibi_slopes);
338         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
339         TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({num_heads}) || alibi_slopes.sizes() == at::IntArrayRef({batch_size, num_heads}));
340         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
341         params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
342     } else {
343         params.alibi_slopes_ptr = nullptr;
344     }
345 #endif
346 }
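// Shape summary for the checks above: alibi_slopes of shape {num_heads} is shared
// across the batch (alibi_slopes_batch_stride = 0), while {batch_size, num_heads}
// gives each batch element its own slopes (batch stride = alibi_slopes.stride(0)).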
347 
348 // return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
349 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
350 mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
351         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
352         const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
353         std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
354         std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
355         const float p_dropout,
356         const float softmax_scale,
357         bool is_causal,
358         int window_size_left,
359         int window_size_right,
360         const bool return_softmax,
361         std::optional<at::Generator> gen_) {
362 
363     auto dprops = at::cuda::getCurrentDeviceProperties();
364     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
365     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
366     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
367     TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
368     // We will support Turing in the near future
369     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
370 
371     auto q_dtype = q.dtype();
372     TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
373                 "FlashAttention only supports fp16 and bf16 data types");
374     if (q_dtype == at::kBFloat16) {
375         TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
376     }
377     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
378     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
379 
380     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
381 
382     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
383     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
384     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
385 
386     const auto sizes = q.sizes();
387 
388     const int batch_size = sizes[0];
389     int seqlen_q = sizes[1];
390     int num_heads = sizes[2];
391     const int head_size_og = sizes[3];
392     const int seqlen_k = k.size(1);
393     const int num_heads_k = k.size(2);
394     TORCH_CHECK(batch_size > 0, "batch size must be positive");
395     TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
396     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
397     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
398 
399     if (window_size_left >= seqlen_k) { window_size_left = -1; }
400     if (window_size_right >= seqlen_k) { window_size_right = -1; }
401 
402     // causal=true is the same as causal=false in this case
403     if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
404     if (is_causal) { window_size_right = 0; }
405 
406     // Faster to transpose q from (b, 1, nheads_kv * ngroups, d) to (b, ngroups, nheads_kv, d) in this case
407     // H/t Daniel Haziza
408     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
409     const int ngroups = num_heads / num_heads_k;
410     at::Tensor temp_q = q;
411     if (seqlenq_ngroups_swapped) {
412         temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
413         seqlen_q = ngroups;
414         num_heads = num_heads_k;
415     }
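    // Illustrative example of the swap above (assumed sizes): batch_size = 2,
    // num_heads = 32, num_heads_k = 4 and seqlen_q = 1 give ngroups = 8, so temp_q
    // becomes (2, 8, 4, head_size_og); from here on seqlen_q = 8 and num_heads = 4,
    // and the output is transposed back at the end of this function.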
416 
417     CHECK_SHAPE(temp_q, batch_size, seqlen_q, num_heads, head_size_og);
418     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
419     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
420 
421     at::Tensor q_padded, k_padded, v_padded;
422     q_padded = temp_q;
423     k_padded = k;
424     v_padded = v;
425 
426     at::Tensor out;
427     if (out_.has_value()) {
428         out = out_.value();
429         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
430         CHECK_DEVICE(out);
431         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
432         CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
433         if (seqlenq_ngroups_swapped) {
434             out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
435         }
436         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
437         if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
438     } else {
439         out = at::empty_like(q_padded);
440     }
441 
442     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
443     const int head_size = round_multiple(head_size_og, 8);
444     const int head_size_rounded = round_multiple(head_size, 32);
445     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
446     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
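    // Illustrative rounding (assumed sizes): head_size_og = 80 gives head_size = 80
    // and head_size_rounded = 96, while seqlen_q = 1000 and seqlen_k = 2000 round up
    // to 1024 and 2048 respectively.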
447 
448     // Otherwise the kernel will be launched from cuda:0 device
449     // Cast to char to avoid compiler warning about narrowing
450     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
451 
452     auto opts = q.options();
453 
454     auto softmax_lse = at::empty({batch_size, num_heads, seqlen_q }, opts.dtype(at::kFloat));
455 
456     at::Tensor p;
457     // Only return softmax if there's dropout, to reduce compilation time
458     if (return_softmax) {
459         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
460         p = at::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
461     }
462 
463     Flash_fwd_params params;
464     set_params_fprop(params,
465                      batch_size,
466                      seqlen_q, seqlen_k,
467                      seqlen_q_rounded, seqlen_k_rounded,
468                      num_heads, num_heads_k,
469                      head_size, head_size_rounded,
470                      q_padded, k_padded, v_padded, out,
471                      /*cu_seqlens_q_d=*/nullptr,
472                      /*cu_seqlens_k_d=*/nullptr,
473                      /*seqused_k=*/nullptr,
474                      return_softmax ? p.data_ptr() : nullptr,
475                      softmax_lse.data_ptr(),
476                      p_dropout,
477                      softmax_scale,
478                      window_size_left,
479                      window_size_right);
480 
481 
482     // Keep references to these tensors to extend their lifetime
483     at::Tensor softmax_lse_accum, out_accum;
484     std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
485                         head_size, seqlen_k, seqlen_q,
486                         head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
487 
488     // We want to checkpoint and save the RNG state for backward if dropout
489     // We get the default generator and return the seed and offset which will
490     // be used in the backward function
491     at::Tensor seed_t, offset_t;
492     if (p_dropout > 0.0)  {
493         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
494         // Number of random values generated per thread, used to offset the Philox
495         // counter in the RNG state.
496         // We use a custom RNG that increases the offset by batch_size * nheads * 32.
497         int64_t counter_offset = params.b * params.h * 32;
498         // See Note [Acquire lock when using random generators]
499         std::lock_guard<std::mutex> lock(gen->mutex_);
500         at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
501         if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
502           auto [seed, offset] = at::cuda::philox::unpack(philox_state);
503           seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
504           offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
505         } else {
506           seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
507           offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
508           params.seed = seed_t.data_ptr<int64_t>();
509           params.extragraph_offset = offset_t.data_ptr<int64_t>();
510         }
511         params.philox_args = philox_state;
512     } else {
513         if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
514             seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
515             offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
516         } else {
517             seed_t = at::empty({}, at::dtype(at::kLong));
518             offset_t = at::empty({}, at::dtype(at::kLong));
519         }
520 
521     }
522 
523     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
524 
525     if (seqlen_k > 0) {
526         auto stream = at::cuda::getCurrentCUDAStream().stream();
527         run_mha_fwd(params, stream);
528     } else {
529         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
530         out.zero_();
531         softmax_lse.fill_(std::numeric_limits<float>::infinity());
532     }
533 
534     if (seqlenq_ngroups_swapped) {
535         out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
536         q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
537         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
538     }
539     return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
540 }
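// A minimal, hypothetical usage sketch of mha_fwd (shapes follow the parameter
// comments above; the sizes and the 1/sqrt(head_size) scale are assumptions for the
// example, not something this file mandates):
//
//   // q: (batch, seqlen_q, num_heads, head_size); k/v: (batch, seqlen_k, num_heads_k, head_size)
//   auto q = at::randn({2, 128, 8, 64}, at::dtype(at::kHalf).device(at::kCUDA));
//   auto k = at::randn({2, 256, 8, 64}, q.options());
//   auto v = at::randn({2, 256, 8, 64}, q.options());
//   std::optional<at::Tensor> out_opt, alibi_opt;
//   auto result = pytorch_flash::mha_fwd(q, k, v, out_opt, alibi_opt,
//                                        /*p_dropout=*/0.f, /*softmax_scale=*/0.125f,
//                                        /*is_causal=*/true, /*window_size_left=*/-1,
//                                        /*window_size_right=*/-1, /*return_softmax=*/false,
//                                        /*gen_=*/std::nullopt);
//   at::Tensor out = std::get<0>(result);          // (2, 128, 8, 64)
//   at::Tensor softmax_lse = std::get<4>(result);  // (2, 8, 128), fp32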
541 
542 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
543 mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
544                const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
545                const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
546                std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
547                const at::Tensor &cu_seqlens_q,  // b+1
548                const at::Tensor &cu_seqlens_k,  // b+1
549                std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
550                std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
551                std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
552                int max_seqlen_q,
553                const int max_seqlen_k,
554                const float p_dropout,
555                const float softmax_scale,
556                const bool zero_tensors,
557                bool is_causal,
558                int window_size_left,
559                int window_size_right,
560                const bool return_softmax,
561                std::optional<at::Generator> gen_) {
562 
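    // Illustrative packing (assumed lengths): per-sequence query lengths {3, 5, 2}
    // give cu_seqlens_q = {0, 3, 8, 10} (b + 1 entries), total_q = 10, and q packed
    // as (10, num_heads, head_size); the key side follows the same convention via
    // cu_seqlens_k unless a block_table is supplied.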
563     auto dprops = at::cuda::getCurrentDeviceProperties();
564     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
565     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
566     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
567     TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
568     // We will support Turing in the near future
569     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
570 
571     auto q_dtype = q.dtype();
572     TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
573                 "FlashAttention only supports fp16 and bf16 data types");
574     if (q_dtype == at::kBFloat16) {
575         TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
576     }
577     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
578     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
579     TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
580     TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
581 
582     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
583     CHECK_DEVICE(cu_seqlens_q);
584     CHECK_DEVICE(cu_seqlens_k);
585 
586     at::Tensor block_table;
587     const bool paged_KV = block_table_.has_value();
588     if (paged_KV) {
589         block_table = block_table_.value();
590         CHECK_DEVICE(block_table);
591         TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32");
592         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
593     }
594 
595     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
596     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
597     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
598     CHECK_CONTIGUOUS(cu_seqlens_q);
599     CHECK_CONTIGUOUS(cu_seqlens_k);
600 
601     const auto sizes = q.sizes();
602 
603     const int batch_size = cu_seqlens_q.numel() - 1;
604     int num_heads = sizes[1];
605     const int head_size_og = sizes[2];
606     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
607 
608     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
609     const int num_blocks = !paged_KV ? 0 : k.size(0);
610     const int page_block_size = !paged_KV ? 1 : k.size(1);
611     TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
612 
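    // Illustrative paged-KV layout (assumed sizes): k / v are
    // (num_blocks, page_block_size, num_heads_k, head_size) and block_table[b][i]
    // holds the index of the i-th block for batch element b, so page_block_size = 256
    // with max_num_blocks_per_seq = 16 covers up to 4096 keys per sequence.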
613     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
614     if (is_causal) { window_size_right = 0; }
615 
616     void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
617 
618     // Faster to transpose q from (b, 1, nheads_kv * ngroups, d) to (b, ngroups, nheads_kv, d) in this case
619     // H/t Daniel Haziza
620     const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
621     at::Tensor temp_q = q;
622     const int ngroups = num_heads / num_heads_k;
623     if (seqlenq_ngroups_swapped) {
624         temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
625         max_seqlen_q = ngroups;
626         num_heads = num_heads_k;
627         cu_seqlens_q_d = nullptr;
628     }
629 
630     const int total_q = temp_q.sizes()[0];
631 
632     TORCH_CHECK(batch_size > 0, "batch size must be positive");
633     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
634     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
635     TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
636 
637     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
638     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
639 
640     CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
641     if (!paged_KV) {
642         const int total_k = k.size(0);
643         CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
644         CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
645     } else {
646         CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
647         CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
648         CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
649     }
650     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
651     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
652     if (seqused_k.has_value()){
653         auto seqused_k_ = seqused_k.value();
654         TORCH_CHECK(seqused_k_.dtype() == at::kInt, "seqused_k must have dtype int32");
655         TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
656         TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
657         CHECK_SHAPE(seqused_k_, batch_size);
658     }
659 
660     at::Tensor q_padded, k_padded, v_padded;
661     q_padded = temp_q;
662     k_padded = k;
663     v_padded = v;
664 
665     at::Tensor out;
666     if (out_.has_value()) {
667         out = out_.value();
668         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
669         CHECK_DEVICE(out);
670         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
671         CHECK_SHAPE(out, total_q, num_heads, head_size_og);
672         CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
673         if (seqlenq_ngroups_swapped) {
674             out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
675         }
676         if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
677     } else {
678         out = at::empty_like(q_padded);
679     }
680 
681     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
682     const int head_size = round_multiple(head_size_og, 8);
683     const int head_size_rounded = round_multiple(head_size, 32);
684     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
685     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
686 
687     // Otherwise the kernel will be launched from cuda:0 device
688     // Cast to char to avoid compiler warning about narrowing
689     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
690 
691     auto opts = q.options();
692 
693     auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
694     at::Tensor p;
695     // Only return softmax if there's dropout, to reduce compilation time
696     if (return_softmax) {
697         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
698         p = at::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
699     }
700 
701     if (zero_tensors) {
702         out.zero_();
703         softmax_lse.fill_(-std::numeric_limits<float>::infinity());
704         if (return_softmax) {p.zero_();}
705     }
706 
707     Flash_fwd_params params;
708     set_params_fprop(params,
709                      batch_size,
710                      max_seqlen_q, max_seqlen_k,
711                      seqlen_q_rounded, seqlen_k_rounded,
712                      num_heads, num_heads_k,
713                      head_size, head_size_rounded,
714                      q_padded, k_padded, v_padded, out,
715                      cu_seqlens_q_d,
716                      cu_seqlens_k.data_ptr(),
717                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
718                      return_softmax ? p.data_ptr() : nullptr,
719                      softmax_lse.data_ptr(),
720                      p_dropout,
721                      softmax_scale,
722                      window_size_left,
723                      window_size_right,
724                      seqlenq_ngroups_swapped);
725     if (paged_KV) {
726         params.block_table = block_table.data_ptr<int>();
727         params.block_table_batch_stride = block_table.stride(0);
728         params.k_batch_stride = k_padded.stride(0);
729         params.v_batch_stride = v_padded.stride(0);
730     }
731     params.page_block_size = page_block_size;
732     // Keep references to these tensors to extend their lifetime
733     at::Tensor softmax_lse_accum, out_accum;
734     if (seqlenq_ngroups_swapped) {
735         // Only apply split-k for decoding
736         std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
737                            head_size, max_seqlen_k, max_seqlen_q,
738                            head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
739     }
740 
741     // We want to checkpoint and save the RNG state for backward if dropout
742     // We get the default generator and return the seed and offset which will
743     // be used in the backward function
744     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
745     at::Tensor seed_t, offset_t;
746     if (p_dropout > 0.0)  {
747         // Number of random values generated per thread, used to offset the Philox
748         // counter in the RNG state.
749         // We use a custom RNG that increases the offset by batch_size * nheads * 32.
750         int64_t counter_offset = params.b * params.h * 32;
751         // See Note [Acquire lock when using random generators]
752         std::lock_guard<std::mutex> lock(gen->mutex_);
753         at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
754         if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
755           auto [seed, offset] = at::cuda::philox::unpack(philox_state);
756           seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
757           offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
758         } else {
759           seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
760           offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
761           params.seed = seed_t.data_ptr<int64_t>();
762           params.extragraph_offset = offset_t.data_ptr<int64_t>();
763         }
764         params.philox_args = philox_state;
765     } else {
766         if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
767             seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
768             offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
769         } else {
770             seed_t = at::empty({}, at::dtype(at::kLong));
771             offset_t = at::empty({}, at::dtype(at::kLong));
772         }
773 
774     }
775 
776     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
777 
778     if (max_seqlen_k > 0) {
779         auto stream = at::cuda::getCurrentCUDAStream().stream();
780         run_mha_fwd(params, stream, paged_KV);
781     } else {
782         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
783         out.zero_();
784         softmax_lse.fill_(std::numeric_limits<float>::infinity());
785     }
786 
787     if (seqlenq_ngroups_swapped) {
788         std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
789         std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
790         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
791         q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
792         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
793     }
794 
795     return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
796 }
797 
798 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
799     FP16_SWITCH(!params.is_bf16, [&] {
800         HEADDIM_SWITCH(params.d, [&] {
801             run_mha_bwd_<elem_type, kHeadDim>(params, stream);
802         });
803     });
804 }
805 
806 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
807 mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
808         const at::Tensor &q,   // batch_size x seqlen_q x num_heads x head_size
809         const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
810         const at::Tensor &v,   // batch_size x seqlen_k x num_heads_k x head_size
811         const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
812         const at::Tensor &softmax_lse,     // b x h x seqlen_q
813         std::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
814         std::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
815         std::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
816         std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
817         const float p_dropout,         // probability to drop
818         const float softmax_scale,
819         const bool is_causal,
820         int window_size_left,
821         int window_size_right,
822         const bool deterministic,
823         const at::Tensor philox_seed,
824         const at::Tensor philox_offset) {
825 
826     #ifdef FLASHATTENTION_DISABLE_BACKWARD
827         TORCH_CHECK(false, "This flash attention build does not support backward.");
828     #endif
829     if (is_causal) { window_size_right = 0; }
830     auto dprops = at::cuda::getCurrentDeviceProperties();
831     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
832     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
833     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
834     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
835     TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
836     // We will support Turing in the near future
837     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
838 
839     bool is_dropout = p_dropout > 0.0;
840     auto stream = at::cuda::getCurrentCUDAStream().stream();
841 
842     auto q_dtype = q.dtype();
843     TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
844                 "FlashAttention only supports fp16 and bf16 data types");
845     if (q_dtype == at::kBFloat16) {
846         TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
847     }
848     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
849     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
850     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
851     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
852 
853     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
854     CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
855 
856     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
857     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
858     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
859     TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
860     TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
861 
862     const auto sizes = q.sizes();
863 
864     const int batch_size = sizes[0];
865     const int seqlen_q = sizes[1];
866     const int num_heads = sizes[2];
867     const int head_size_og = dout.size(3);
868     const int head_size = sizes[3];
869     const int seqlen_k = k.size(1);
870     const int num_heads_k = k.size(2);
871     TORCH_CHECK(batch_size > 0, "batch size must be positive");
872     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
873     TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
874     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
875     if (head_size > 192 && (head_size <= 224 || is_dropout)) {
876         TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
877     }
878     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
879 
880     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
881     const int head_size_rounded = round_multiple(head_size, 32);
882     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
883     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
884 
885     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
886 
887     if (window_size_left >= seqlen_k) { window_size_left = -1; }
888     if (window_size_right >= seqlen_k) { window_size_right = -1; }
889 
890     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
891     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
892     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
893     CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
894     CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
895 
896     at::Tensor dq, dk, dv;
897     if (dq_.has_value()) {
898         dq = dq_.value();
899         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
900         CHECK_DEVICE(dq);
901         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
902         CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
903     } else {
904         dq = at::empty_like(q);
905     }
906     if (dk_.has_value()) {
907         dk = dk_.value();
908         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
909         CHECK_DEVICE(dk);
910         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
911         CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
912     } else {
913         dk = at::empty_like(k);
914     }
915     if (dv_.has_value()) {
916         dv = dv_.value();
917         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
918         CHECK_DEVICE(dv);
919         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
920         CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
921     } else {
922         dv = at::empty_like(v);
923     }
924 
925     // bool loop = seqlen_k > blocksize_c;
926     // TODO: change later, for now set to true for simplicity
927     bool loop = true;
928 
929     // Otherwise the kernel will be launched from cuda:0 device
930     // Cast to char to avoid compiler warning about narrowing
931     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
932 
933     auto opts = q.options();
934     auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
935     at::Tensor dq_accum;
936     at::Tensor dk_accum, dv_accum;
937     if (loop) {
938         if (!deterministic) {
939             dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
940         } else {
941             const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
942             dq_accum = at::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
943         }
944         // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
945         // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
946     }
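    // Illustrative deterministic split (assumed sizes): 108 SMs with
    // batch_size * num_heads = 16 give nsplits = ceil(108 / 16) = 7, so dq_accum is
    // (7, batch_size, seqlen_q_rounded, num_heads, head_size_rounded) and, unlike the
    // non-deterministic path, is zero-initialized.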
947 
948     at::Tensor dk_expanded, dv_expanded;
949     if (num_heads_k != num_heads) {  // MQA / GQA
950         dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
951         dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
952     } else {
953         dk_expanded = dk;
954         dv_expanded = dv;
955     }
956 
957     Flash_bwd_params params;
958 
959     set_params_dgrad(params,
960                      batch_size,
961                      seqlen_q, seqlen_k,
962                      seqlen_q_rounded, seqlen_k_rounded,
963                      num_heads, num_heads_k,
964                      head_size, head_size_rounded,
965                      q, k, v, out,
966                      dout, dq, dk_expanded, dv_expanded,
967                      nullptr,
968                      nullptr,
969                      loop ? dq_accum.data_ptr() : nullptr,
970                      // loop ? dk_accum.data_ptr() : nullptr,
971                      // loop ? dv_accum.data_ptr() : nullptr,
972                      nullptr,
973                      nullptr,
974                      softmax_lse.data_ptr(),
975                      softmax_d.data_ptr(),
976                      p_dropout,
977                      softmax_scale,
978                      window_size_left,
979                      window_size_right,
980                      deterministic);
981     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
982 
983     auto launch = &run_mha_bwd;
984 
985     at::PhiloxCudaState philox_args;
986     if (is_dropout) {
987         if (at::cuda::currentStreamCaptureStatus() ==
988                 at::cuda::CaptureStatus::None)
989         {
990             philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
991         } else { // dropout + capture
992             philox_args = at::PhiloxCudaState(
993                 philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
994         }
995     }
996     params.philox_args = philox_args;
997 
998     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
999 
1000     if (seqlen_q > 0) {
1001         launch(params, stream);
1002     } else {
1003         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1004         dk_expanded.zero_();
1005         dv_expanded.zero_();
1006         softmax_d.zero_();
1007     }
1008 
1009     // For MQA/GQA we need to sum dK and dV across the groups
1010     if (num_heads_k != num_heads) {
1011         at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
1012         at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
1013     }
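    // Illustrative reduction (assumed sizes): num_heads = 32 with num_heads_k = 4 means
    // dk_expanded is (batch, seqlen_k, 32, head_size); it is viewed as
    // (batch, seqlen_k, 4, 8, head_size) and summed over dim 3 into dk of shape
    // (batch, seqlen_k, 4, head_size), and likewise for dv.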
1014     return { dq, dk, dv, softmax_d };
1015 }
1016 
1017 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
1018 mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
1019                const at::Tensor &q,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
1020                const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1021                const at::Tensor &v,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1022                const at::Tensor &out,   // total_q x num_heads x head_size
1023                const at::Tensor &softmax_lse,     // b x h x s   softmax logsumexp
1024                std::optional<at::Tensor> &dq_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
1025                std::optional<at::Tensor> &dk_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1026                std::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1027                const at::Tensor &cu_seqlens_q,  // b+1
1028                const at::Tensor &cu_seqlens_k,  // b+1
1029                std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
1030                const int max_seqlen_q,
1031                const int max_seqlen_k,          // max sequence length to choose the kernel
1032                const float p_dropout,         // probability to drop
1033                const float softmax_scale,
1034                const bool zero_tensors,
1035                const bool is_causal,
1036                int window_size_left,
1037                int window_size_right,
1038                const bool deterministic,
1039                const at::Tensor philox_seed,
1040                const at::Tensor philox_offset)
1041 {
1042 
1043     #ifdef FLASHATTENTION_DISABLE_BACKWARD
1044         TORCH_CHECK(false, "This flash attention build does not support backward.");
1045     #endif
1046 
1047     if (is_causal) { window_size_right = 0; }
1048     auto dprops = at::cuda::getCurrentDeviceProperties();
1049     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
1050     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
1051     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
1052     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
1053     TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
1054     // We will support Turing in the near future
1055     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
1056     bool is_dropout = p_dropout > 0.0;
1057     auto stream = at::cuda::getCurrentCUDAStream().stream();
1058 
1059     auto q_dtype = q.dtype();
1060     TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
1061                 "FlashAttention only supports fp16 and bf16 data types");
1062     if (q_dtype == at::kBFloat16) {
1063         TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
1064     }
1065     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
1066     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
1067     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
1068     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
1069     TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
1070     TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
1071 
1072     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
1073     CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
1074     CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
1075 
1076     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1077     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1078     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1079     TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
1080     TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
1081     CHECK_CONTIGUOUS(cu_seqlens_q);
1082     CHECK_CONTIGUOUS(cu_seqlens_k);
1083 
1084     const auto sizes = q.sizes();
1085 
1086     const int total_q = sizes[0];
1087     const int batch_size = cu_seqlens_q.numel() - 1;
1088     const int num_heads = sizes[1];
1089     const int head_size_og = dout.size(2);
1090     const int head_size = sizes[2];
1091     const int total_k = k.size(0);
1092     const int num_heads_k = k.size(1);
1093     TORCH_CHECK(batch_size > 0, "batch size must be positive");
1094     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1095     TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8; this is ensured by padding!");
1096     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
1097     if (head_size > 192 && (head_size <= 224 || is_dropout)) {
1098         TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
1099     }
1100     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1101 
1102     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1103     const int head_size_rounded = round_multiple(head_size, 32);
1104     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
1105     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
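    // For illustration (hypothetical values): head_size = 72 rounds up to head_size_rounded = 96,
    // and max_seqlen_q = 300 rounds up to seqlen_q_rounded = 384.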
1106 
1107     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
1108 
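    // A window that spans the whole key sequence is equivalent to no windowing, which the kernels
    // encode as a negative window size.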
1109     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
1110     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
1111 
1112     CHECK_SHAPE(q, total_q, num_heads, head_size);
1113     CHECK_SHAPE(k, total_k, num_heads_k, head_size);
1114     CHECK_SHAPE(v, total_k, num_heads_k, head_size);
1115     CHECK_SHAPE(out, total_q, num_heads, head_size);
1116     CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
1117     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
1118     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
1119 
1120     at::Tensor dq, dk, dv;
1121     if (dq_.has_value()) {
1122         dq = dq_.value();
1123         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
1124         CHECK_DEVICE(dq);
1125         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
1126         CHECK_SHAPE(dq, total_q, num_heads, head_size);
1127     } else {
1128         dq = at::empty_like(q);
1129     }
1130     if (dk_.has_value()) {
1131         dk = dk_.value();
1132         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
1133         CHECK_DEVICE(dk);
1134         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
1135         CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
1136     } else {
1137         dk = at::empty_like(k);
1138     }
1139     if (dv_.has_value()) {
1140         dv = dv_.value();
1141         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
1142         CHECK_DEVICE(dv);
1143         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
1144         CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
1145     } else {
1146         dv = at::empty_like(v);
1147     }
1148 
1149     // bool loop = max_seqlen_k > blocksize_c;
1150     // TODO: change later, for now set to true for simplicity
1151     bool loop = true;
1152 
1153     // Otherwise the kernel will be launched from cuda:0 device
1154     // Cast to char to avoid compiler warning about narrowing
1155     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
1156 
1157     auto opts = q.options();
1158     auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1159     at::Tensor dq_accum;
1160     if (loop) {
1161         // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
1162         // because that would be too large if there is a very long sequence and the rest of the sequences are short.
1163         // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
1164         // Note that 128 is the max block size on the seqlen_q dimension.
1165         // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
1166         // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
1167         // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
1168         // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
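        // As a rough illustration (hypothetical helper, not used by the kernels): under this layout
        // the first row of sequence i inside dq_accum would be
        //     int64_t dq_accum_row_offset(const int* cu_seqlens_q, int i) {
        //         return int64_t(cu_seqlens_q[i]) + 128 * int64_t(i);
        //     }
        // which keeps consecutive sequences at least 128 rows apart.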
1169         if (!deterministic) {
1170             dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
1171         } else {
1172             const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
1173             dq_accum = at::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
1174         }
1175     }
1176 
1177     at::Tensor dk_expanded, dv_expanded;
1178     if (num_heads_k != num_heads) {  // MQA / GQA
1179         dk_expanded = at::empty({total_k, num_heads, head_size}, opts);
1180         dv_expanded = at::empty({total_k, num_heads, head_size}, opts);
1181     } else {
1182         dk_expanded = dk;
1183         dv_expanded = dv;
1184     }
1185 
1186     if (zero_tensors) {
1187         dq.zero_();
1188         dk_expanded.zero_();
1189         dv_expanded.zero_();
1190         softmax_d.zero_();
1191     }
1192 
1193     Flash_bwd_params params;
1194 
1195     set_params_dgrad(params,
1196                      batch_size,
1197                      max_seqlen_q, max_seqlen_k,
1198                      seqlen_q_rounded, seqlen_k_rounded,
1199                      num_heads, num_heads_k,
1200                      head_size, head_size_rounded,
1201                      q, k, v, out,
1202                      dout, dq, dk_expanded, dv_expanded,
1203                      cu_seqlens_q.data_ptr(),
1204                      cu_seqlens_k.data_ptr(),
1205                      loop ? dq_accum.data_ptr() : nullptr,
1206                      nullptr,
1207                      nullptr,
1208                      softmax_lse.data_ptr(),
1209                      softmax_d.data_ptr(),
1210                      p_dropout,
1211                      softmax_scale,
1212                      window_size_left,
1213                      window_size_right,
1214                      deterministic);
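    // In the deterministic path dq_accum has an extra leading "split" dimension (see its allocation
    // above), and the kernel needs that stride to address each split's buffer; in the
    // non-deterministic path there is no such dimension, so the stride is 0.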
1215     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
1216 
1217     auto launch = &run_mha_bwd;
1218 
1219     at::PhiloxCudaState philox_args;
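    // Outside of CUDA graph capture the seed/offset can be dereferenced on the host right away;
    // during capture they are passed as device pointers instead, so the RNG state is read when the
    // graph is replayed (hence the pointer-taking PhiloxCudaState constructor below).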
1220     if (is_dropout) {
1221         if (at::cuda::currentStreamCaptureStatus() ==
1222                 at::cuda::CaptureStatus::None)
1223         {
1224             philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
1225         } else { // dropout + capture
1226             philox_args = at::PhiloxCudaState(
1227                 philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
1228         }
1229     }
1230     params.philox_args = philox_args;
1231 
1232     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
1233 
1234     if (max_seqlen_q > 0) {
1235         launch(params, stream);
1236     } else {
1237         // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1238         dk_expanded.zero_();
1239         dv_expanded.zero_();
1240         softmax_d.zero_();
1241     }
1242 
1243     // For MQA/GQA we need to sum dK and dV across the groups
1244     if (num_heads_k != num_heads) {
1245         at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
1246         at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
1247     }
1248 
1249     return { dq, dk, dv, softmax_d };
1250 }
1251 
1252 std::tuple<at::Tensor, at::Tensor>
1253 mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
1254                 const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1255                 const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1256                 std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
1257                 std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
1258                 std::optional<const at::Tensor> &seqlens_k_, // batch_size
1259                 std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
1260                 std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
1261                 std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
1262                 std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
1263                 std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1264                 std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
1265                 const float softmax_scale,
1266                 bool is_causal,
1267                 int window_size_left,
1268                 int window_size_right,
1269                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
1270                 int num_splits
1271                 ) {
1272 
1273     auto dprops = at::cuda::getCurrentDeviceProperties();
1274     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
1275     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
1276     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
1277     TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
1278     // We will support Turing in the near future
1279     // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
1280 
1281     auto q_dtype = q.dtype();
1282     TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
1283                 "FlashAttention only supports fp16 and bf16 data types");
1284     if (q_dtype == at::kBFloat16) {
1285         TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
1286     }
1287     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
1288     TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
1289 
1290     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
1291 
1292     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1293     TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1294     TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1295 
1296     at::Tensor block_table;
1297     const bool paged_KV = block_table_.has_value();
1298     if (paged_KV) {
1299         TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
1300         block_table = block_table_.value();
1301         CHECK_DEVICE(block_table);
1302         TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32");
1303         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
1304     }
1305 
1306     const auto sizes = q.sizes();
1307 
1308     const int batch_size = sizes[0];
1309     int seqlen_q = sizes[1];
1310     int num_heads = sizes[2];
1311     const int head_size_og = sizes[3];
1312 
1313     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
1314     const int num_blocks = !paged_KV ? 0 : kcache.size(0);
1315     const int page_block_size = !paged_KV ? 1 : kcache.size(1);
1316     TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
1317     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
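    // For illustration (hypothetical values): page_block_size = 256 with max_num_blocks_per_seq = 8
    // gives an effective seqlen_k of 2048 in the paged path.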
1318     const int num_heads_k = kcache.size(2);
1319     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
1320     TORCH_CHECK(batch_size > 0, "batch size must be positive");
1321     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
1322     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1323 
1324     // causal=true is the same as causal=false in this case
1325     if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
1326     if (is_causal) { window_size_right = 0; }
1327 
1328     // Faster to transpose q from (b, 1, nheads_kv * ngroups, d) to (b, ngroups, nheads_kv, d) in this case
1329     // H/t Daniel Haziza
1330     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
1331     if (seqlenq_ngroups_swapped) {
1332         const int ngroups = num_heads / num_heads_k;
1333         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
1334         seqlen_q = ngroups;
1335         num_heads = num_heads_k;
1336     }
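    // For illustration (hypothetical values): num_heads = 32 and num_heads_k = 4 give ngroups = 8,
    // so q goes from (batch_size, 1, 32, head_size) to (batch_size, 8, 4, head_size) and the kernel
    // sees seqlen_q = 8 with num_heads = 4.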
1337 
1338     if (window_size_left >= seqlen_k) { window_size_left = -1; }
1339     if (window_size_right >= seqlen_k) { window_size_right = -1; }
1340 
1341     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
1342     if (!paged_KV) {
1343         CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
1344         CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
1345     } else {
1346         CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
1347         CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
1348         CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
1349     }
1350 
1351     at::Tensor q_padded, kcache_padded, vcache_padded;
1352     if (head_size_og % 8 != 0) {
1353         q_padded = at::pad(q, {0, 8 - head_size_og % 8});
1354         kcache_padded = at::pad(kcache, {0, 8 - head_size_og % 8});
1355         vcache_padded = at::pad(vcache, {0, 8 - head_size_og % 8});
1356         // q_padded = at::nn::functional::pad(q, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1357         // kcache_padded = at::nn::functional::pad(kcache, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1358         // vcache_padded = at::nn::functional::pad(vcache, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1359     } else {
1360         q_padded = q;
1361         kcache_padded = kcache;
1362         vcache_padded = vcache;
1363     }
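    // For illustration (hypothetical value): head_size_og = 100 is padded by 4 elements in the last
    // dimension, so the kernels operate on a head size of 104, the next multiple of 8.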
1364 
1365     at::Tensor out;
1366     if (out_.has_value()) {
1367         out = out_.value();
1368         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
1369         CHECK_DEVICE(out);
1370         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
1371         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
1372         if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
1373     } else {
1374         out = at::empty_like(q_padded);
1375     }
1376 
1377     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1378     const int head_size = round_multiple(head_size_og, 8);
1379     const int head_size_rounded = round_multiple(head_size, 32);
1380     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
1381     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
1382 
1383     // Otherwise the kernel will be launched from cuda:0 device
1384     // Cast to char to avoid compiler warning about narrowing
1385     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
1386 
1387     auto opts = q.options();
1388 
1389     auto softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
1390 
1391     Flash_fwd_params params;
1392     set_params_fprop(params,
1393                      batch_size,
1394                      seqlen_q, seqlen_k,
1395                      seqlen_q_rounded, seqlen_k_rounded,
1396                      num_heads, num_heads_k,
1397                      head_size, head_size_rounded,
1398                      q_padded, kcache_padded, vcache_padded, out,
1399                      /*cu_seqlens_q_d=*/nullptr,
1400                      /*cu_seqlens_k_d=*/nullptr,
1401                      /*seqused_k=*/nullptr,
1402                      /*p_ptr=*/nullptr,
1403                      softmax_lse.data_ptr(),
1404                      /*p_dropout=*/0.f,
1405                      softmax_scale,
1406                      window_size_left,
1407                      window_size_right);
1408 
1409     at::Tensor k, v, k_padded, v_padded;
1410     if (k_.has_value()) {
1411         TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
1412         TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
1413         TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, seqlen_q must be <= the seqlen of the KV cache");
1414         k = k_.value();
1415         v = v_.value();
1416         TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
1417         TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
1418         CHECK_DEVICE(k); CHECK_DEVICE(v);
1419         TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
1420         TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
1421         int seqlen_knew = k.size(1);
1422         CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
1423         CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
1424         if (head_size_og % 8 != 0) {
1425             k_padded = at::pad(k, {0, 8 - head_size_og % 8});
1426             v_padded = at::pad(v, {0, 8 - head_size_og % 8});
1427             // k_padded = at::nn::functional::pad(k, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1428             // v_padded = at::nn::functional::pad(v, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1429         } else {
1430             k_padded = k;
1431             v_padded = v;
1432         }
1433         params.seqlen_knew = seqlen_knew;
1434         params.knew_ptr = k_padded.data_ptr();
1435         params.vnew_ptr = v_padded.data_ptr();
1436         // All strides are in elements, not bytes.
1437         params.knew_batch_stride = k_padded.stride(0);
1438         params.vnew_batch_stride = v_padded.stride(0);
1439         params.knew_row_stride = k_padded.stride(-3);
1440         params.vnew_row_stride = v_padded.stride(-3);
1441         params.knew_head_stride = k_padded.stride(-2);
1442         params.vnew_head_stride = v_padded.stride(-2);
1443     }
1444 
1445     if (seqlens_k_.has_value()) {
1446         auto seqlens_k = seqlens_k_.value();
1447         TORCH_CHECK(seqlens_k.dtype() == at::kInt, "seqlens_k must have dtype int32");
1448         CHECK_DEVICE(seqlens_k);
1449         CHECK_CONTIGUOUS(seqlens_k);
1450         CHECK_SHAPE(seqlens_k, batch_size);
1451         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
1452     }
1453     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
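    // When seqlens_k is provided it holds one actual length per batch element (shape: batch_size),
    // not cumulative offsets as in the varlen path, so the kernel is told not to treat it as cumulative.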
1454 
1455     if (rotary_cos_.has_value()) {
1456         TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
1457         auto rotary_cos = rotary_cos_.value();
1458         CHECK_DEVICE(rotary_cos);
1459         params.rotary_dim = rotary_cos.size(1) * 2;
1460         TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
1461         TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
1462         const int seqlen_ro = rotary_cos.size(0);
1463         TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
1464         CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
1465         CHECK_CONTIGUOUS(rotary_cos);
1466         TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
1467 
1468         TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
1469         auto rotary_sin = rotary_sin_.value();
1470         CHECK_DEVICE(rotary_sin);
1471         CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
1472         CHECK_CONTIGUOUS(rotary_sin);
1473         TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
1474         params.rotary_cos_ptr = rotary_cos.data_ptr();
1475         params.rotary_sin_ptr = rotary_sin.data_ptr();
1476         params.is_rotary_interleaved = is_rotary_interleaved;
1477     } else {
1478         params.rotary_dim = 0;
1479     }
1480 
1481     if (cache_batch_idx_.has_value()) {
1482         auto cache_batch_idx = cache_batch_idx_.value();
1483         CHECK_DEVICE(cache_batch_idx);
1484         CHECK_CONTIGUOUS(cache_batch_idx);
1485         TORCH_CHECK(cache_batch_idx.scalar_type() == at::kInt, "cache_batch_idx must have dtype int32");
1486         params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
1487     }
1488 
1489     // Keep references to these tensors to extend their lifetime
1490     at::Tensor softmax_lse_accum, out_accum;
1491     std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
1492                        head_size, seqlen_k, seqlen_q,
1493                        head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
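    // softmax_lse_accum and out_accum hold per-split partial results when the split-KV kernel is
    // selected; the combine step reduces them into softmax_lse and out.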
1494 
1495     if (paged_KV) {
1496         params.block_table = block_table.data_ptr<int>();
1497         params.block_table_batch_stride = block_table.stride(0);
1498     }
1499     params.page_block_size = page_block_size;
1500 
1501 
1502     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
1503 
1504     auto stream = at::cuda::getCurrentCUDAStream().stream();
1505     // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
1506     // or paged KV cache
1507     run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
1508 
1509     if (head_size_og % 8 != 0) {
1510         // out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
1511         out = out.narrow(-1, 0, head_size_og);
1512         if (out_.has_value()) { out_.value().copy_(out); }
1513         if (k_.has_value()) {
1514             // It's expensive to copy the KV cache here for the case where the head size is not divisible by 8,
1515             // but we don't expect to get this case in practice. This is just so that the code works for that case.
1516             kcache.copy_(kcache_padded.narrow(-1, 0, head_size_og));
1517             vcache.copy_(vcache_padded.narrow(-1, 0, head_size_og));
1518             // kcache.copy_(kcache_padded.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}));
1519             // vcache.copy_(vcache_padded.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}));
1520         }
1521     }
1522 
1523     if (seqlenq_ngroups_swapped) {
1524         out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
1525         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
1526     }
1527     return {out, softmax_lse};
1528 }
1529 
1530 } // namespace pytorch_flash
1531 
1532 #endif
1533